rename the EarlySession to EarlyConnection

This commit is contained in:
Marten Seemann 2022-03-26 15:12:28 +01:00
parent fda9f72161
commit 6aaa9a817c
17 changed files with 232 additions and 228 deletions

View file

@ -34,7 +34,7 @@ var defaultQuicConfig = &quic.Config{
Versions: []protocol.VersionNumber{protocol.VersionTLS},
}
type dialFunc func(ctx context.Context, network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
type dialFunc func(ctx context.Context, network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error)
var dialAddr = quic.DialAddrEarlyContext
@ -59,7 +59,7 @@ type client struct {
decoder *qpack.Decoder
hostname string
session quic.EarlySession
session quic.EarlyConnection
logger utils.Logger
}

View file

@ -65,7 +65,7 @@ var _ = Describe("Client", func() {
client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
Expect(err).ToNot(HaveOccurred())
var dialAddrCalled bool
dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) {
dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
Expect(quicConf).To(Equal(defaultQuicConfig))
Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3}))
Expect(quicConf.Versions).To(Equal([]protocol.VersionNumber{protocol.Version1}))
@ -80,7 +80,7 @@ var _ = Describe("Client", func() {
client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
Expect(err).ToNot(HaveOccurred())
var dialAddrCalled bool
dialAddr = func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
dialAddr = func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
Expect(hostname).To(Equal("quic.clemente.io:443"))
dialAddrCalled = true
return nil, errors.New("test done")
@ -100,7 +100,7 @@ var _ = Describe("Client", func() {
client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil)
Expect(err).ToNot(HaveOccurred())
var dialAddrCalled bool
dialAddr = func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) {
dialAddr = func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
Expect(host).To(Equal("localhost:1337"))
Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3}))
@ -121,7 +121,7 @@ var _ = Describe("Client", func() {
ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
defer cancel()
var dialerCalled bool
dialer := func(ctxP context.Context, network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) {
dialer := func(ctxP context.Context, network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
Expect(ctxP).To(Equal(ctx))
Expect(network).To(Equal("udp"))
Expect(address).To(Equal("localhost:1337"))
@ -141,7 +141,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("handshake error")
client, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil)
Expect(err).ToNot(HaveOccurred())
dialAddr = func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) {
dialAddr = func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
Expect(quicConf.EnableDatagrams).To(BeTrue())
return nil, testErr
}
@ -153,14 +153,14 @@ var _ = Describe("Client", func() {
testErr := errors.New("handshake error")
client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
Expect(err).ToNot(HaveOccurred())
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) {
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
return nil, testErr
}
_, err = client.RoundTrip(req)
Expect(err).To(MatchError(testErr))
})
It("closes correctly if session was not created", func() {
It("closes correctly if connection was not created", func() {
client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(client.Close()).To(Succeed())
@ -178,7 +178,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("handshake error")
req, err := http.NewRequest("masque", "masque://quic.clemente.io:1337/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) {
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
return nil, testErr
}
_, err = client.RoundTrip(req)
@ -189,7 +189,7 @@ var _ = Describe("Client", func() {
Context("control stream handling", func() {
var (
request *http.Request
sess *mockquic.MockEarlySession
conn *mockquic.MockEarlyConnection
settingsFrameWritten chan struct{}
)
testDone := make(chan struct{})
@ -201,11 +201,13 @@ var _ = Describe("Client", func() {
defer GinkgoRecover()
close(settingsFrameWritten)
})
sess = mockquic.NewMockEarlySession(mockCtrl)
sess.EXPECT().OpenUniStream().Return(controlStr, nil)
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
sess.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) { return sess, nil }
conn = mockquic.NewMockEarlyConnection(mockCtrl)
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
return conn, nil
}
var err error
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
Expect(err).ToNot(HaveOccurred())
@ -222,16 +224,16 @@ var _ = Describe("Client", func() {
(&settingsFrame{}).Write(buf)
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return controlStr, nil
})
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTrip(request)
Expect(err).To(MatchError("done"))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to sess.CloseWithError
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} {
@ -247,10 +249,10 @@ var _ = Describe("Client", func() {
str := mockquic.NewMockStream(mockCtrl)
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return str, nil
})
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
@ -270,10 +272,10 @@ var _ = Describe("Client", func() {
close(done)
})
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return str, nil
})
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
@ -288,15 +290,15 @@ var _ = Describe("Client", func() {
(&dataFrame{}).Write(buf)
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return controlStr, nil
})
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
done := make(chan struct{})
sess.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
defer GinkgoRecover()
Expect(code).To(BeEquivalentTo(errorMissingSettings))
close(done)
@ -314,15 +316,15 @@ var _ = Describe("Client", func() {
buf.Write(b.Bytes()[:b.Len()-1])
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return controlStr, nil
})
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
done := make(chan struct{})
sess.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
defer GinkgoRecover()
Expect(code).To(BeEquivalentTo(errorFrameError))
close(done)
@ -337,15 +339,15 @@ var _ = Describe("Client", func() {
quicvarint.Write(buf, streamTypePushStream)
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return controlStr, nil
})
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
done := make(chan struct{})
sess.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
defer GinkgoRecover()
Expect(code).To(BeEquivalentTo(errorIDError))
close(done)
@ -362,16 +364,16 @@ var _ = Describe("Client", func() {
(&settingsFrame{Datagram: true}).Write(buf)
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return controlStr, nil
})
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
sess.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false})
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false})
done := make(chan struct{})
sess.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, reason string) {
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, reason string) {
defer GinkgoRecover()
Expect(code).To(BeEquivalentTo(errorSettingsError))
Expect(reason).To(Equal("missing QUIC Datagram support"))
@ -387,7 +389,7 @@ var _ = Describe("Client", func() {
var (
request *http.Request
str *mockquic.MockStream
sess *mockquic.MockEarlySession
conn *mockquic.MockEarlyConnection
settingsFrameWritten chan struct{}
)
testDone := make(chan struct{})
@ -446,13 +448,15 @@ var _ = Describe("Client", func() {
close(settingsFrameWritten)
}) // SETTINGS frame
str = mockquic.NewMockStream(mockCtrl)
sess = mockquic.NewMockEarlySession(mockCtrl)
sess.EXPECT().OpenUniStream().Return(controlStr, nil)
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn = mockquic.NewMockEarlyConnection(mockCtrl)
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) { return sess, nil }
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
return conn, nil
}
var err error
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
Expect(err).ToNot(HaveOccurred())
@ -465,9 +469,9 @@ var _ = Describe("Client", func() {
It("errors if it can't open a stream", func() {
testErr := errors.New("stream open error")
sess.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
sess.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
_, err := client.RoundTrip(request)
Expect(err).To(MatchError(testErr))
})
@ -476,7 +480,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("stream open error")
request.Method = MethodGet0RTT
// don't EXPECT any calls to HandshakeComplete()
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
str.EXPECT().Close()
@ -492,9 +496,9 @@ var _ = Describe("Client", func() {
It("returns a response", func() {
rspBuf := bytes.NewBuffer(getResponse(418))
gomock.InOrder(
sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
sess.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
conn.EXPECT().HandshakeComplete().Return(handshakeCtx),
conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
)
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Close()
@ -512,8 +516,8 @@ var _ = Describe("Client", func() {
BeforeEach(func() {
strBuf = &bytes.Buffer{}
gomock.InOrder(
sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
conn.EXPECT().HandshakeComplete().Return(handshakeCtx),
conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
)
body := &mockBody{}
body.SetData([]byte("request body"))
@ -570,7 +574,7 @@ var _ = Describe("Client", func() {
(&dataFrame{Length: 0x6}).Write(buf)
buf.Write([]byte("foobar"))
str.EXPECT().Close().Do(func() { close(done) })
sess.EXPECT().ConnectionState().Return(quic.ConnectionState{})
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) // when reading the response errors
// the response body is sent asynchronously, while already reading the response
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
@ -583,7 +587,7 @@ var _ = Describe("Client", func() {
It("closes the connection when the first frame is not a HEADERS frame", func() {
buf := &bytes.Buffer{}
(&dataFrame{Length: 0x42}).Write(buf)
sess.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any())
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any())
closed := make(chan struct{})
str.EXPECT().Close().Do(func() { close(closed) })
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
@ -609,7 +613,7 @@ var _ = Describe("Client", func() {
It("cancels a request while waiting for the handshake to complete", func() {
ctx, cancel := context.WithCancel(context.Background())
req := request.WithContext(ctx)
sess.EXPECT().HandshakeComplete().Return(context.Background())
conn.EXPECT().HandshakeComplete().Return(context.Background())
errChan := make(chan error)
go func() {
@ -624,8 +628,8 @@ var _ = Describe("Client", func() {
It("cancels a request while the request is still in flight", func() {
ctx, cancel := context.WithCancel(context.Background())
req := request.WithContext(ctx)
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
sess.EXPECT().OpenStreamSync(ctx).Return(str, nil)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Close().MaxTimes(1)
@ -653,9 +657,9 @@ var _ = Describe("Client", func() {
ctx, cancel := context.WithCancel(context.Background())
req := request.WithContext(ctx)
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
sess.EXPECT().OpenStreamSync(ctx).Return(str, nil)
sess.EXPECT().ConnectionState().Return(quic.ConnectionState{})
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
buf := &bytes.Buffer{}
str.EXPECT().Close().MaxTimes(1)
@ -673,11 +677,11 @@ var _ = Describe("Client", func() {
Context("gzip compression", func() {
BeforeEach(func() {
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
})
It("adds the gzip header to requests", func() {
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
gomock.InOrder(
@ -694,7 +698,7 @@ var _ = Describe("Client", func() {
It("doesn't add gzip if the header disable it", func() {
client, err := newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil)
Expect(err).ToNot(HaveOccurred())
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
gomock.InOrder(
@ -709,8 +713,8 @@ var _ = Describe("Client", func() {
})
It("decompresses the response", func() {
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
sess.EXPECT().ConnectionState().Return(quic.ConnectionState{})
conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
buf := &bytes.Buffer{}
rstr := mockquic.NewMockStream(mockCtrl)
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
@ -735,8 +739,8 @@ var _ = Describe("Client", func() {
})
It("only decompresses the response if the response contains the right content-encoding header", func() {
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
sess.EXPECT().ConnectionState().Return(quic.ConnectionState{})
conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
buf := &bytes.Buffer{}
rstr := mockquic.NewMockStream(mockCtrl)
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()

View file

@ -50,7 +50,7 @@ type RoundTripper struct {
// Dial specifies an optional dial function for creating QUIC
// connections for requests.
// If Dial is nil, quic.DialAddrEarlyContext will be used.
Dial func(ctx context.Context, network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
Dial func(ctx context.Context, network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error)
// MaxResponseHeaderBytes specifies a limit on how many response bytes are
// allowed in the server's response header.

View file

@ -61,7 +61,7 @@ var _ = Describe("RoundTripper", func() {
var (
rt *RoundTripper
req1 *http.Request
session *mockquic.MockEarlySession
conn *mockquic.MockEarlyConnection
handshakeCtx context.Context // an already canceled context
)
@ -80,12 +80,12 @@ var _ = Describe("RoundTripper", func() {
origDialAddr := dialAddr
BeforeEach(func() {
session = mockquic.NewMockEarlySession(mockCtrl)
conn = mockquic.NewMockEarlyConnection(mockCtrl)
origDialAddr = dialAddr
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) {
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
// return an error when trying to open a stream
// we don't want to test all the dial logic here, just that dialing happens at all
return session, nil
return conn, nil
}
})
@ -98,14 +98,14 @@ var _ = Describe("RoundTripper", func() {
testErr := errors.New("test err")
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
session.EXPECT().HandshakeComplete().Return(handshakeCtx)
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
session.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-closed
return nil, errors.New("test done")
}).MaxTimes(1)
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) { close(closed) })
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) { close(closed) })
_, err = rt.RoundTrip(req)
Expect(err).To(MatchError(testErr))
Expect(rt.clients).To(HaveLen(1))
@ -115,7 +115,7 @@ var _ = Describe("RoundTripper", func() {
It("uses the quic.Config, if provided", func() {
config := &quic.Config{HandshakeIdleTimeout: time.Millisecond}
var receivedConfig *quic.Config
dialAddr = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlySession, error) {
dialAddr = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlyConnection, error) {
receivedConfig = config
return nil, errors.New("handshake error")
}
@ -127,7 +127,7 @@ var _ = Describe("RoundTripper", func() {
It("uses the custom dialer, if provided", func() {
var dialed bool
dialer := func(_ context.Context, _, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
dialer := func(_ context.Context, _, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
dialed = true
return nil, errors.New("handshake error")
}
@ -140,14 +140,14 @@ var _ = Describe("RoundTripper", func() {
It("reuses existing clients", func() {
closed := make(chan struct{})
testErr := errors.New("test err")
session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
session.EXPECT().HandshakeComplete().Return(handshakeCtx).Times(2)
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2)
session.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx).Times(2)
conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2)
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-closed
return nil, errors.New("test done")
}).MaxTimes(1)
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) { close(closed) })
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) { close(closed) })
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTrip(req)

View file

@ -334,7 +334,7 @@ func (s *Server) removeListener(l *quic.EarlyListener) {
s.mutex.Unlock()
}
func (s *Server) handleConn(sess quic.EarlySession) {
func (s *Server) handleConn(sess quic.EarlyConnection) {
decoder := qpack.NewDecoder(nil)
// send a SETTINGS frame
@ -381,7 +381,7 @@ func (s *Server) handleConn(sess quic.EarlySession) {
}
}
func (s *Server) handleUnidirectionalStreams(sess quic.EarlySession) {
func (s *Server) handleUnidirectionalStreams(sess quic.EarlyConnection) {
for {
str, err := sess.AcceptUniStream(context.Background())
if err != nil {

View file

@ -107,7 +107,7 @@ var _ = Describe("Server", func() {
var (
qpackDecoder *qpack.Decoder
str *mockquic.MockStream
sess *mockquic.MockEarlySession
conn *mockquic.MockEarlyConnection
exampleGetRequest *http.Request
examplePostRequest *http.Request
)
@ -164,10 +164,10 @@ var _ = Describe("Server", func() {
qpackDecoder = qpack.NewDecoder(nil)
str = mockquic.NewMockStream(mockCtrl)
sess = mockquic.NewMockEarlySession(mockCtrl)
conn = mockquic.NewMockEarlyConnection(mockCtrl)
addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
sess.EXPECT().RemoteAddr().Return(addr).AnyTimes()
sess.EXPECT().LocalAddr().AnyTimes()
conn.EXPECT().RemoteAddr().Return(addr).AnyTimes()
conn.EXPECT().LocalAddr().AnyTimes()
})
It("calls the HTTP handler function", func() {
@ -183,7 +183,7 @@ var _ = Describe("Server", func() {
}).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
Expect(s.handleRequest(sess, str, qpackDecoder, nil)).To(Equal(requestError{}))
Expect(s.handleRequest(conn, str, qpackDecoder, nil)).To(Equal(requestError{}))
var req *http.Request
Eventually(requestChan).Should(Receive(&req))
Expect(req.Host).To(Equal("www.example.com"))
@ -200,7 +200,7 @@ var _ = Describe("Server", func() {
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
serr := s.handleRequest(sess, str, qpackDecoder, nil)
serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
@ -217,7 +217,7 @@ var _ = Describe("Server", func() {
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
serr := s.handleRequest(sess, str, qpackDecoder, nil)
serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"500"}))
@ -234,22 +234,22 @@ var _ = Describe("Server", func() {
str.EXPECT().Write([]byte("foobar"))
// don't EXPECT CancelRead()
serr := s.handleRequest(sess, str, qpackDecoder, nil)
serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
})
Context("control stream handling", func() {
var sess *mockquic.MockEarlySession
var conn *mockquic.MockEarlyConnection
testDone := make(chan struct{})
BeforeEach(func() {
sess = mockquic.NewMockEarlySession(mockCtrl)
conn = mockquic.NewMockEarlyConnection(mockCtrl)
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Write(gomock.Any())
sess.EXPECT().OpenUniStream().Return(controlStr, nil)
sess.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
sess.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
sess.EXPECT().LocalAddr().AnyTimes()
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
conn.EXPECT().LocalAddr().AnyTimes()
})
AfterEach(func() { testDone <- struct{}{} })
@ -260,15 +260,15 @@ var _ = Describe("Server", func() {
(&settingsFrame{}).Write(buf)
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return controlStr, nil
})
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
s.handleConn(sess)
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to sess.CloseWithError
s.handleConn(conn)
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} {
@ -284,14 +284,14 @@ var _ = Describe("Server", func() {
str := mockquic.NewMockStream(mockCtrl)
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return str, nil
})
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
s.handleConn(sess)
s.handleConn(conn)
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead
})
}
@ -306,14 +306,14 @@ var _ = Describe("Server", func() {
close(done)
})
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return str, nil
})
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
s.handleConn(sess)
s.handleConn(conn)
Eventually(done).Should(BeClosed())
})
@ -323,20 +323,20 @@ var _ = Describe("Server", func() {
(&dataFrame{}).Write(buf)
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return controlStr, nil
})
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
done := make(chan struct{})
sess.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
defer GinkgoRecover()
Expect(code).To(BeEquivalentTo(errorMissingSettings))
close(done)
})
s.handleConn(sess)
s.handleConn(conn)
Eventually(done).Should(BeClosed())
})
@ -348,20 +348,20 @@ var _ = Describe("Server", func() {
buf.Write(b.Bytes()[:b.Len()-1])
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return controlStr, nil
})
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
done := make(chan struct{})
sess.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
defer GinkgoRecover()
Expect(code).To(BeEquivalentTo(errorFrameError))
close(done)
})
s.handleConn(sess)
s.handleConn(conn)
Eventually(done).Should(BeClosed())
})
@ -371,20 +371,20 @@ var _ = Describe("Server", func() {
(&dataFrame{}).Write(buf)
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return controlStr, nil
})
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
done := make(chan struct{})
sess.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
defer GinkgoRecover()
Expect(code).To(BeEquivalentTo(errorStreamCreationError))
close(done)
})
s.handleConn(sess)
s.handleConn(conn)
Eventually(done).Should(BeClosed())
})
@ -395,45 +395,45 @@ var _ = Describe("Server", func() {
(&settingsFrame{Datagram: true}).Write(buf)
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return controlStr, nil
})
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
sess.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false})
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false})
done := make(chan struct{})
sess.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, reason string) {
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, reason string) {
defer GinkgoRecover()
Expect(code).To(BeEquivalentTo(errorSettingsError))
Expect(reason).To(Equal("missing QUIC Datagram support"))
close(done)
})
s.handleConn(sess)
s.handleConn(conn)
Eventually(done).Should(BeClosed())
})
})
Context("stream- and connection-level errors", func() {
var sess *mockquic.MockEarlySession
var conn *mockquic.MockEarlyConnection
testDone := make(chan struct{})
BeforeEach(func() {
testDone = make(chan struct{})
addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
sess = mockquic.NewMockEarlySession(mockCtrl)
conn = mockquic.NewMockEarlyConnection(mockCtrl)
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Write(gomock.Any())
sess.EXPECT().OpenUniStream().Return(controlStr, nil)
sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
sess.EXPECT().AcceptStream(gomock.Any()).Return(str, nil)
sess.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
sess.EXPECT().RemoteAddr().Return(addr).AnyTimes()
sess.EXPECT().LocalAddr().AnyTimes()
conn.EXPECT().AcceptStream(gomock.Any()).Return(str, nil)
conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
conn.EXPECT().RemoteAddr().Return(addr).AnyTimes()
conn.EXPECT().LocalAddr().AnyTimes()
})
AfterEach(func() { testDone <- struct{}{} })
@ -456,7 +456,7 @@ var _ = Describe("Server", func() {
str.EXPECT().CancelRead(quic.StreamErrorCode(errorNoError))
str.EXPECT().Close().Do(func() { close(done) })
s.handleConn(sess)
s.handleConn(conn)
Eventually(done).Should(BeClosed())
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
@ -478,7 +478,7 @@ var _ = Describe("Server", func() {
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelWrite(quic.StreamErrorCode(errorFrameError)).Do(func(quic.StreamErrorCode) { close(done) })
s.handleConn(sess)
s.handleConn(conn)
Eventually(done).Should(BeClosed())
})
@ -493,7 +493,7 @@ var _ = Describe("Server", func() {
str.EXPECT().Read(gomock.Any()).Return(0, testErr)
str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) })
s.handleConn(sess)
s.handleConn(conn)
Consistently(handlerCalled).ShouldNot(BeClosed())
})
@ -511,11 +511,11 @@ var _ = Describe("Server", func() {
}).AnyTimes()
done := make(chan struct{})
sess.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
Expect(code).To(Equal(quic.ApplicationErrorCode(errorFrameUnexpected)))
close(done)
})
s.handleConn(sess)
s.handleConn(conn)
Eventually(done).Should(BeClosed())
})
@ -538,7 +538,7 @@ var _ = Describe("Server", func() {
done := make(chan struct{})
str.EXPECT().CancelWrite(quic.StreamErrorCode(errorFrameError)).Do(func(quic.StreamErrorCode) { close(done) })
s.handleConn(sess)
s.handleConn(conn)
Eventually(done).Should(BeClosed())
})
})
@ -560,7 +560,7 @@ var _ = Describe("Server", func() {
}).AnyTimes()
str.EXPECT().CancelRead(quic.StreamErrorCode(errorNoError))
serr := s.handleRequest(sess, str, qpackDecoder, nil)
serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
Eventually(handlerCalled).Should(BeClosed())
})
@ -583,7 +583,7 @@ var _ = Describe("Server", func() {
}).AnyTimes()
str.EXPECT().CancelRead(quic.StreamErrorCode(errorNoError))
serr := s.handleRequest(sess, str, qpackDecoder, nil)
serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
Eventually(handlerCalled).Should(BeClosed())
})