diff --git a/connection.go b/connection.go index e01e89af..ed01defa 100644 --- a/connection.go +++ b/connection.go @@ -218,6 +218,9 @@ type connection struct { datagramQueue *datagramQueue + connStateMutex sync.Mutex + connState ConnectionState + logID string tracer logging.ConnectionTracer logger utils.Logger @@ -545,6 +548,7 @@ func (s *connection) preSetup() { s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.connFlowController, s.framer.QueueControlFrame) s.datagramQueue = newDatagramQueue(s.scheduleSending, s.logger) + s.connState.Version = s.version } // run the connection main loop @@ -738,11 +742,10 @@ func (s *connection) supportsDatagrams() bool { } func (s *connection) ConnectionState() ConnectionState { - return ConnectionState{ - TLS: s.cryptoStreamHandler.ConnectionState(), - SupportsDatagrams: s.supportsDatagrams(), - Version: s.version, - } + s.connStateMutex.Lock() + defer s.connStateMutex.Unlock() + s.connState.TLS = s.cryptoStreamHandler.ConnectionState() + return s.connState } // Time when the next keep-alive packet should be sent. @@ -1678,6 +1681,9 @@ func (s *connection) restoreTransportParameters(params *wire.TransportParameters s.connIDGenerator.SetMaxActiveConnIDs(params.ActiveConnectionIDLimit) s.connFlowController.UpdateSendWindow(params.InitialMaxData) s.streamsMap.UpdateLimits(params) + s.connStateMutex.Lock() + s.connState.SupportsDatagrams = s.supportsDatagrams() + s.connStateMutex.Unlock() } func (s *connection) handleTransportParameters(params *wire.TransportParameters) { @@ -1696,6 +1702,10 @@ func (s *connection) handleTransportParameters(params *wire.TransportParameters) // the client's transport parameters. close(s.earlyConnReadyChan) } + + s.connStateMutex.Lock() + s.connState.SupportsDatagrams = s.supportsDatagrams() + s.connStateMutex.Unlock() } func (s *connection) checkTransportParameters(params *wire.TransportParameters) error { diff --git a/go.mod b/go.mod index 91f8b7f8..2fa87b72 100644 --- a/go.mod +++ b/go.mod @@ -6,8 +6,8 @@ require ( github.com/francoispqt/gojay v1.2.13 github.com/golang/mock v1.6.0 github.com/marten-seemann/qpack v0.3.0 - github.com/marten-seemann/qtls-go1-18 v0.1.3 - github.com/marten-seemann/qtls-go1-19 v0.1.1 + github.com/marten-seemann/qtls-go1-18 v0.1.4 + github.com/marten-seemann/qtls-go1-19 v0.1.2 github.com/onsi/ginkgo/v2 v2.2.0 github.com/onsi/gomega v1.20.1 golang.org/x/crypto v0.4.0 diff --git a/go.sum b/go.sum index 1099f435..016c4dce 100644 --- a/go.sum +++ b/go.sum @@ -70,10 +70,10 @@ github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/marten-seemann/qpack v0.3.0 h1:UiWstOgT8+znlkDPOg2+3rIuYXJ2CnGDkGUXN6ki6hE= github.com/marten-seemann/qpack v0.3.0/go.mod h1:cGfKPBiP4a9EQdxCwEwI/GEeWAsjSekBvx/X8mh58+g= -github.com/marten-seemann/qtls-go1-18 v0.1.3 h1:R4H2Ks8P6pAtUagjFty2p7BVHn3XiwDAl7TTQf5h7TI= -github.com/marten-seemann/qtls-go1-18 v0.1.3/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= -github.com/marten-seemann/qtls-go1-19 v0.1.1 h1:mnbxeq3oEyQxQXwI4ReCgW9DPoPR94sNlqWoDZnjRIE= -github.com/marten-seemann/qtls-go1-19 v0.1.1/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= +github.com/marten-seemann/qtls-go1-18 v0.1.4 h1:ogomB+lWV3Vmwiu6RTwDVTMGx+9j7SEi98e8QB35Its= +github.com/marten-seemann/qtls-go1-18 v0.1.4/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= +github.com/marten-seemann/qtls-go1-19 v0.1.2 h1:ZevAEqKXH0bZmoOBPiqX2h5rhQ7cbZi+X+rlq2JUbCE= +github.com/marten-seemann/qtls-go1-19 v0.1.2/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= diff --git a/http3/request.go b/http3/request.go index 0b9a7278..f5a8381a 100644 --- a/http3/request.go +++ b/http3/request.go @@ -1,7 +1,6 @@ package http3 import ( - "crypto/tls" "errors" "net/http" "net/url" @@ -101,7 +100,6 @@ func requestFromHeaders(headers []qpack.HeaderField) (*http.Request, error) { ContentLength: contentLength, Host: authority, RequestURI: requestURI, - TLS: &tls.ConnectionState{}, }, nil } diff --git a/http3/request_test.go b/http3/request_test.go index 46a8a93e..d9d57998 100644 --- a/http3/request_test.go +++ b/http3/request_test.go @@ -30,7 +30,6 @@ var _ = Describe("Request", func() { Expect(req.Body).To(BeNil()) Expect(req.Host).To(Equal("quic.clemente.io")) Expect(req.RequestURI).To(Equal("/foo")) - Expect(req.TLS).ToNot(BeNil()) }) It("parses path with leading double slashes", func() { diff --git a/http3/server.go b/http3/server.go index 0455895e..ac38dd51 100644 --- a/http3/server.go +++ b/http3/server.go @@ -272,7 +272,7 @@ func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error { baseConf := ConfigureTLSConfig(tlsConf) quicConf := s.QuicConfig if quicConf == nil { - quicConf = &quic.Config{} + quicConf = &quic.Config{Allow0RTT: func(net.Addr) bool { return true }} } else { quicConf = s.QuicConfig.Clone() } @@ -570,6 +570,8 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q return newStreamError(errorGeneralProtocolError, err) } + connState := conn.ConnectionState().TLS.ConnectionState + req.TLS = &connState req.RemoteAddr = conn.RemoteAddr().String() body := newRequestBody(newStream(str, onFrameError)) req.Body = body diff --git a/http3/server_test.go b/http3/server_test.go index 72758981..4e46ecfb 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -163,6 +163,7 @@ var _ = Describe("Server", func() { addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() conn.EXPECT().LocalAddr().AnyTimes() + conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes() }) It("calls the HTTP handler function", func() { @@ -632,6 +633,7 @@ var _ = Describe("Server", func() { conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() conn.EXPECT().LocalAddr().AnyTimes() + conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes() }) AfterEach(func() { testDone <- struct{}{} }) diff --git a/integrationtests/gomodvendor/go.sum b/integrationtests/gomodvendor/go.sum index 1b8df6ac..a8141f85 100644 --- a/integrationtests/gomodvendor/go.sum +++ b/integrationtests/gomodvendor/go.sum @@ -81,10 +81,10 @@ github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/marten-seemann/qpack v0.3.0 h1:UiWstOgT8+znlkDPOg2+3rIuYXJ2CnGDkGUXN6ki6hE= github.com/marten-seemann/qpack v0.3.0/go.mod h1:cGfKPBiP4a9EQdxCwEwI/GEeWAsjSekBvx/X8mh58+g= -github.com/marten-seemann/qtls-go1-18 v0.1.3 h1:R4H2Ks8P6pAtUagjFty2p7BVHn3XiwDAl7TTQf5h7TI= -github.com/marten-seemann/qtls-go1-18 v0.1.3/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= -github.com/marten-seemann/qtls-go1-19 v0.1.1 h1:mnbxeq3oEyQxQXwI4ReCgW9DPoPR94sNlqWoDZnjRIE= -github.com/marten-seemann/qtls-go1-19 v0.1.1/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= +github.com/marten-seemann/qtls-go1-18 v0.1.4 h1:ogomB+lWV3Vmwiu6RTwDVTMGx+9j7SEi98e8QB35Its= +github.com/marten-seemann/qtls-go1-18 v0.1.4/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= +github.com/marten-seemann/qtls-go1-19 v0.1.2 h1:ZevAEqKXH0bZmoOBPiqX2h5rhQ7cbZi+X+rlq2JUbCE= +github.com/marten-seemann/qtls-go1-19 v0.1.2/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= diff --git a/integrationtests/self/handshake_rtt_test.go b/integrationtests/self/handshake_rtt_test.go index 60f173ca..1c965299 100644 --- a/integrationtests/self/handshake_rtt_test.go +++ b/integrationtests/self/handshake_rtt_test.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "fmt" + "io" "net" "time" @@ -18,54 +19,33 @@ import ( var _ = Describe("Handshake RTT tests", func() { var ( proxy *quicproxy.QuicProxy - server quic.Listener serverConfig *quic.Config serverTLSConfig *tls.Config - testStartedAt time.Time - acceptStopped chan struct{} ) - rtt := 400 * time.Millisecond + const rtt = 400 * time.Millisecond BeforeEach(func() { - acceptStopped = make(chan struct{}) serverConfig = getQuicConfig(nil) serverTLSConfig = getTLSConfig() }) AfterEach(func() { Expect(proxy.Close()).To(Succeed()) - Expect(server.Close()).To(Succeed()) - <-acceptStopped }) - runServerAndProxy := func() { + runProxy := func(serverAddr net.Addr) { var err error - // start the server - server, err = quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig) - Expect(err).ToNot(HaveOccurred()) // start the proxy proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: server.Addr().String(), + RemoteAddr: serverAddr.String(), DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 }, }) Expect(err).ToNot(HaveOccurred()) - - testStartedAt = time.Now() - - go func() { - defer GinkgoRecover() - defer close(acceptStopped) - for { - if _, err := server.Accept(context.Background()); err != nil { - return - } - } - }() } - expectDurationInRTTs := func(num int) { - testDuration := time.Since(testStartedAt) + expectDurationInRTTs := func(startTime time.Time, num int) { + testDuration := time.Since(startTime) rtts := float32(testDuration) / float32(rtt) Expect(rtts).To(SatisfyAll( BeNumerically(">=", num), @@ -78,15 +58,19 @@ var _ = Describe("Handshake RTT tests", func() { Skip("Test requires at least 2 supported versions.") } serverConfig.Versions = protocol.SupportedVersions[:1] - runServerAndProxy() - clientConfig := getQuicConfig(&quic.Config{Versions: protocol.SupportedVersions[1:2]}) - _, err := quic.DialAddr( + ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + runProxy(ln.Addr()) + startTime := time.Now() + _, err = quic.DialAddr( proxy.LocalAddr().String(), getTLSClientConfig(), - clientConfig, + getQuicConfig(&quic.Config{Versions: protocol.SupportedVersions[1:2]}), ) Expect(err).To(HaveOccurred()) - expectDurationInRTTs(1) + expectDurationInRTTs(startTime, 1) }) var clientConfig *quic.Config @@ -102,36 +86,114 @@ var _ = Describe("Handshake RTT tests", func() { // 1 RTT for the TLS handshake It("is forward-secure after 2 RTTs", func() { serverConfig.RequireAddressValidation = func(net.Addr) bool { return true } - runServerAndProxy() - _, err := quic.DialAddr( + ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + runProxy(ln.Addr()) + startTime := time.Now() + _, err = quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), clientConfig, ) Expect(err).ToNot(HaveOccurred()) - expectDurationInRTTs(2) + expectDurationInRTTs(startTime, 2) }) It("establishes a connection in 1 RTT when the server doesn't require a token", func() { - runServerAndProxy() - _, err := quic.DialAddr( + ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + runProxy(ln.Addr()) + startTime := time.Now() + _, err = quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), clientConfig, ) Expect(err).ToNot(HaveOccurred()) - expectDurationInRTTs(1) + expectDurationInRTTs(startTime, 1) }) It("establishes a connection in 2 RTTs if a HelloRetryRequest is performed", func() { serverTLSConfig.CurvePreferences = []tls.CurveID{tls.CurveP384} - runServerAndProxy() - _, err := quic.DialAddr( + ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + runProxy(ln.Addr()) + startTime := time.Now() + _, err = quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), clientConfig, ) Expect(err).ToNot(HaveOccurred()) - expectDurationInRTTs(2) + expectDurationInRTTs(startTime, 2) + }) + + It("receives the first message from the server after 2 RTTs, when the server uses ListenAddr", func() { + ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig) + Expect(err).ToNot(HaveOccurred()) + go func() { + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + }() + defer ln.Close() + + runProxy(ln.Addr()) + startTime := time.Now() + conn, err := quic.DialAddr( + fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), + getTLSClientConfig(), + clientConfig, + ) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("foobar"))) + expectDurationInRTTs(startTime, 2) + }) + + It("receives the first message from the server after 1 RTT, when the server uses ListenAddrEarly", func() { + ln, err := quic.ListenAddrEarly("localhost:0", serverTLSConfig, serverConfig) + Expect(err).ToNot(HaveOccurred()) + go func() { + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + // Check the ALPN now. This is probably what an application would do. + // It makes sure that ConnectionState does not block until the handshake completes. + Expect(conn.ConnectionState().TLS.NegotiatedProtocol).To(Equal(alpn)) + str, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + }() + defer ln.Close() + + runProxy(ln.Addr()) + startTime := time.Now() + conn, err := quic.DialAddr( + fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), + getTLSClientConfig(), + clientConfig, + ) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("foobar"))) + expectDurationInRTTs(startTime, 1) }) }) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index bd4e9342..fee9594c 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -110,19 +110,20 @@ var _ = Describe("0-RTT", func() { clientConf *quic.Config, testdata []byte, // data to transfer ) { - // now dial the second connection, and use 0-RTT to send some data + // accept the second connection, and receive the data sent in 0-RTT done := make(chan struct{}) go func() { defer GinkgoRecover() conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - str, err := conn.AcceptUniStream(context.Background()) + str, err := conn.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(str) Expect(err).ToNot(HaveOccurred()) Expect(data).To(Equal(testdata)) + Expect(str.Close()).To(Succeed()) Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue()) - Expect(conn.CloseWithError(0, "")).To(Succeed()) + <-conn.Context().Done() close(done) }() @@ -136,13 +137,15 @@ var _ = Describe("0-RTT", func() { ) Expect(err).ToNot(HaveOccurred()) defer conn.CloseWithError(0, "") - str, err := conn.OpenUniStream() + str, err := conn.OpenStream() Expect(err).ToNot(HaveOccurred()) _, err = str.Write(testdata) Expect(err).ToNot(HaveOccurred()) Expect(str.Close()).To(Succeed()) <-conn.HandshakeComplete().Done() Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue()) + io.ReadAll(str) // wait for the EOF from the server to arrive before closing the conn + conn.CloseWithError(0, "") Eventually(done).Should(BeClosed()) Eventually(conn.Context().Done()).Should(BeClosed()) }