diff --git a/http3/client.go b/http3/client.go index 9bacd679..d77efc04 100644 --- a/http3/client.go +++ b/http3/client.go @@ -250,7 +250,7 @@ func (c *client) doRequest( return nil, newConnError(errorGeneralProtocolError, err) } - connState := qtls.ToTLSConnectionState(c.session.ConnectionState()) + connState := qtls.ToTLSConnectionState(c.session.ConnectionState().TLS) res := &http.Response{ Proto: "HTTP/3", ProtoMajor: 3, diff --git a/http3/client_test.go b/http3/client_test.go index c107a9cb..29b6909e 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -16,7 +16,6 @@ import ( mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qtls" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/marten-seemann/qpack" @@ -240,7 +239,7 @@ var _ = Describe("Client", func() { gomock.InOrder( sess.EXPECT().HandshakeComplete().Return(handshakeCtx), sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), - sess.EXPECT().ConnectionState().Return(qtls.ConnectionState{}), + sess.EXPECT().ConnectionState().Return(quic.ConnectionState{}), ) str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) str.EXPECT().Close() @@ -410,7 +409,7 @@ var _ = Describe("Client", func() { req := request.WithContext(ctx) sess.EXPECT().HandshakeComplete().Return(handshakeCtx) sess.EXPECT().OpenStreamSync(ctx).Return(str, nil) - sess.EXPECT().ConnectionState().Return(qtls.ConnectionState{}) + sess.EXPECT().ConnectionState().Return(quic.ConnectionState{}) buf := &bytes.Buffer{} str.EXPECT().Close().MaxTimes(1) @@ -473,7 +472,7 @@ var _ = Describe("Client", func() { It("decompresses the response", func() { sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) - sess.EXPECT().ConnectionState().Return(qtls.ConnectionState{}) + sess.EXPECT().ConnectionState().Return(quic.ConnectionState{}) buf := &bytes.Buffer{} rw := newResponseWriter(buf, utils.DefaultLogger) rw.Header().Set("Content-Encoding", "gzip") @@ -499,7 +498,7 @@ var _ = Describe("Client", func() { It("only decompresses the response if the response contains the right content-encoding header", func() { sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) - sess.EXPECT().ConnectionState().Return(qtls.ConnectionState{}) + sess.EXPECT().ConnectionState().Return(quic.ConnectionState{}) buf := &bytes.Buffer{} rw := newResponseWriter(buf, utils.DefaultLogger) rw.Write([]byte("not gzipped")) diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index d1c5a969..cd387044 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -170,7 +170,7 @@ var _ = Describe("Handshake tests", func() { data, err := ioutil.ReadAll(str) Expect(err).ToNot(HaveOccurred()) Expect(data).To(Equal(PRData)) - Expect(sess.ConnectionState().CipherSuite).To(Equal(suiteID)) + Expect(sess.ConnectionState().TLS.CipherSuite).To(Equal(suiteID)) Expect(sess.CloseWithError(0, "")).To(Succeed()) }) } @@ -369,7 +369,7 @@ var _ = Describe("Handshake tests", func() { sess, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) cs := sess.ConnectionState() - Expect(cs.NegotiatedProtocol).To(Equal(alpn)) + Expect(cs.TLS.NegotiatedProtocol).To(Equal(alpn)) close(done) }() @@ -381,7 +381,7 @@ var _ = Describe("Handshake tests", func() { Expect(err).ToNot(HaveOccurred()) defer sess.CloseWithError(0, "") cs := sess.ConnectionState() - Expect(cs.NegotiatedProtocol).To(Equal(alpn)) + Expect(cs.TLS.NegotiatedProtocol).To(Equal(alpn)) Eventually(done).Should(BeClosed()) Expect(ln.Close()).To(Succeed()) }) diff --git a/integrationtests/self/resumption_test.go b/integrationtests/self/resumption_test.go index 0ab8b434..fbccd4a1 100644 --- a/integrationtests/self/resumption_test.go +++ b/integrationtests/self/resumption_test.go @@ -65,11 +65,11 @@ var _ = Describe("TLS session resumption", func() { Expect(err).ToNot(HaveOccurred()) var sessionKey string Eventually(puts).Should(Receive(&sessionKey)) - Expect(sess.ConnectionState().DidResume).To(BeFalse()) + Expect(sess.ConnectionState().TLS.DidResume).To(BeFalse()) serverSess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(serverSess.ConnectionState().DidResume).To(BeFalse()) + Expect(serverSess.ConnectionState().TLS.DidResume).To(BeFalse()) sess, err = quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), @@ -78,11 +78,11 @@ var _ = Describe("TLS session resumption", func() { ) Expect(err).ToNot(HaveOccurred()) Expect(gets).To(Receive(Equal(sessionKey))) - Expect(sess.ConnectionState().DidResume).To(BeTrue()) + Expect(sess.ConnectionState().TLS.DidResume).To(BeTrue()) serverSess, err = server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(serverSess.ConnectionState().DidResume).To(BeTrue()) + Expect(serverSess.ConnectionState().TLS.DidResume).To(BeTrue()) }) It("doesn't use session resumption, if the config disables it", func() { @@ -104,11 +104,11 @@ var _ = Describe("TLS session resumption", func() { ) Expect(err).ToNot(HaveOccurred()) Consistently(puts).ShouldNot(Receive()) - Expect(sess.ConnectionState().DidResume).To(BeFalse()) + Expect(sess.ConnectionState().TLS.DidResume).To(BeFalse()) serverSess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(serverSess.ConnectionState().DidResume).To(BeFalse()) + Expect(serverSess.ConnectionState().TLS.DidResume).To(BeFalse()) sess, err = quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), @@ -116,10 +116,10 @@ var _ = Describe("TLS session resumption", func() { nil, ) Expect(err).ToNot(HaveOccurred()) - Expect(sess.ConnectionState().DidResume).To(BeFalse()) + Expect(sess.ConnectionState().TLS.DidResume).To(BeFalse()) serverSess, err = server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(serverSess.ConnectionState().DidResume).To(BeFalse()) + Expect(serverSess.ConnectionState().TLS.DidResume).To(BeFalse()) }) }) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index b4d8ca3a..3ab95efb 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -86,7 +86,7 @@ var _ = Describe("0-RTT", func() { data, err := ioutil.ReadAll(str) Expect(err).ToNot(HaveOccurred()) Expect(data).To(Equal(testdata)) - Expect(sess.ConnectionState().Used0RTT).To(Equal(expect0RTT)) + Expect(sess.ConnectionState().TLS.Used0RTT).To(Equal(expect0RTT)) close(done) }() @@ -101,7 +101,7 @@ var _ = Describe("0-RTT", func() { _, err = str.Write(testdata) Expect(err).ToNot(HaveOccurred()) Expect(str.Close()).To(Succeed()) - Expect(sess.ConnectionState().Used0RTT).To(Equal(expect0RTT)) + Expect(sess.ConnectionState().TLS.Used0RTT).To(Equal(expect0RTT)) Eventually(done).Should(BeClosed()) } diff --git a/interface.go b/interface.go index a3846d11..19d38c74 100644 --- a/interface.go +++ b/interface.go @@ -141,8 +141,6 @@ type StreamError interface { ErrorCode() ErrorCode } -type ConnectionState = handshake.ConnectionState - // A Session is a QUIC connection between two peers. type Session interface { // AcceptStream returns the next stream opened by the peer, blocking until one is available. @@ -264,6 +262,11 @@ type Config struct { Tracer logging.Tracer } +// ConnectionState records basic details about a QUIC connection +type ConnectionState struct { + TLS handshake.ConnectionState +} + // A Listener for incoming QUIC connections type Listener interface { // Close the server. All active sessions will be closed. diff --git a/session.go b/session.go index 7b532913..384530d9 100644 --- a/session.go +++ b/session.go @@ -633,7 +633,9 @@ func (s *session) Context() context.Context { } func (s *session) ConnectionState() ConnectionState { - return s.cryptoStreamHandler.ConnectionState() + return ConnectionState{ + TLS: s.cryptoStreamHandler.ConnectionState(), + } } // Time when the next keep-alive packet should be sent.