From a4b02656282d0321ee33ec6b3d27616b66369219 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 27 Mar 2023 00:51:56 +1100 Subject: [PATCH] use a chan instead of a context in Connection.HandshakeComplete (#3709) --- client.go | 2 +- client_test.go | 24 ++++++++-------- connection.go | 6 ++-- connection_test.go | 10 +++---- http3/client.go | 4 +-- http3/client_test.go | 36 ++++++++++++------------ integrationtests/self/early_data_test.go | 2 +- integrationtests/self/zero_rtt_test.go | 4 +-- interface.go | 7 +++-- internal/mocks/quic/early_conn.go | 4 +-- interop/http09/client.go | 2 +- mock_quic_conn_test.go | 4 +-- server.go | 2 +- server_test.go | 24 ++++++++-------- 14 files changed, 66 insertions(+), 65 deletions(-) diff --git a/client.go b/client.go index b05f0ab2..bea685b7 100644 --- a/client.go +++ b/client.go @@ -325,7 +325,7 @@ func (c *client) dial(ctx context.Context) error { case <-earlyConnChan: // ready to send 0-RTT data return nil - case <-c.conn.HandshakeComplete().Done(): + case <-c.conn.HandshakeComplete(): // handshake successfully completed return nil } diff --git a/client_test.go b/client_test.go index e9c323c6..04874471 100644 --- a/client_test.go +++ b/client_test.go @@ -130,7 +130,7 @@ var _ = Describe("Client", func() { remoteAddrChan <- sconn.RemoteAddr().String() conn := NewMockQuicConn(mockCtrl) conn.EXPECT().run() - conn.EXPECT().HandshakeComplete().Return(context.Background()) + conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) return conn } _, err := DialAddr("localhost:17890", tlsConf, &Config{HandshakeIdleTimeout: time.Millisecond}) @@ -163,7 +163,7 @@ var _ = Describe("Client", func() { hostnameChan <- tlsConf.ServerName conn := NewMockQuicConn(mockCtrl) conn.EXPECT().run() - conn.EXPECT().HandshakeComplete().Return(context.Background()) + conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) return conn } tlsConf.ServerName = "foobar" @@ -195,7 +195,7 @@ var _ = Describe("Client", func() { ) quicConn { hostnameChan <- tlsConf.ServerName conn := NewMockQuicConn(mockCtrl) - conn.EXPECT().HandshakeComplete().Return(context.Background()) + conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) conn.EXPECT().run() return conn } @@ -235,9 +235,9 @@ var _ = Describe("Client", func() { Expect(enable0RTT).To(BeFalse()) conn := NewMockQuicConn(mockCtrl) conn.EXPECT().run().Do(func() { close(run) }) - ctx, cancel := context.WithCancel(context.Background()) - cancel() - conn.EXPECT().HandshakeComplete().Return(ctx) + c := make(chan struct{}) + close(c) + conn.EXPECT().HandshakeComplete().Return(c) return conn } tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) @@ -278,7 +278,7 @@ var _ = Describe("Client", func() { Expect(enable0RTT).To(BeTrue()) conn := NewMockQuicConn(mockCtrl) conn.EXPECT().run().Do(func() { <-done }) - conn.EXPECT().HandshakeComplete().Return(context.Background()) + conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) conn.EXPECT().earlyConnReady().Return(readyChan) return conn } @@ -325,7 +325,7 @@ var _ = Describe("Client", func() { ) quicConn { conn := NewMockQuicConn(mockCtrl) conn.EXPECT().run().Return(testErr) - conn.EXPECT().HandshakeComplete().Return(context.Background()) + conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) return conn } tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) @@ -350,7 +350,7 @@ var _ = Describe("Client", func() { conn.EXPECT().run().Do(func() { <-connRunning }) - conn.EXPECT().HandshakeComplete().Return(context.Background()) + conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) newClientConnection = func( _ sendConn, _ connRunner, @@ -425,7 +425,7 @@ var _ = Describe("Client", func() { conn.EXPECT().run().Do(func() { <-run }) - conn.EXPECT().HandshakeComplete().Return(context.Background()) + conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) done := make(chan struct{}) go func() { @@ -546,7 +546,7 @@ var _ = Describe("Client", func() { // TODO: check connection IDs? conn := NewMockQuicConn(mockCtrl) conn.EXPECT().run() - conn.EXPECT().HandshakeComplete().Return(context.Background()) + conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) return conn } _, err := Dial(packetConn, addr, "localhost:1337", tlsConf, config) @@ -580,7 +580,7 @@ var _ = Describe("Client", func() { versionP protocol.VersionNumber, ) quicConn { conn := NewMockQuicConn(mockCtrl) - conn.EXPECT().HandshakeComplete().Return(context.Background()) + conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) if counter == 0 { Expect(pn).To(BeZero()) Expect(hasNegotiatedVersion).To(BeFalse()) diff --git a/connection.go b/connection.go index 50db2964..eb16ece5 100644 --- a/connection.go +++ b/connection.go @@ -696,8 +696,8 @@ func (s *connection) earlyConnReady() <-chan struct{} { return s.earlyConnReadyChan } -func (s *connection) HandshakeComplete() context.Context { - return s.handshakeCtx +func (s *connection) HandshakeComplete() <-chan struct{} { + return s.handshakeCtx.Done() } func (s *connection) Context() context.Context { @@ -2195,7 +2195,7 @@ func (s *connection) GetVersion() protocol.VersionNumber { } func (s *connection) NextConnection() Connection { - <-s.HandshakeComplete().Done() + <-s.HandshakeComplete() s.streamsMap.UseResetMaps() return s } diff --git a/connection_test.go b/connection_test.go index 1e16e5c2..27c476df 100644 --- a/connection_test.go +++ b/connection_test.go @@ -1834,9 +1834,9 @@ var _ = Describe("Connection", func() { conn.run() }() handshakeCtx := conn.HandshakeComplete() - Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) + Consistently(handshakeCtx).ShouldNot(BeClosed()) close(finishHandshake) - Eventually(handshakeCtx.Done()).Should(BeClosed()) + Eventually(handshakeCtx).Should(BeClosed()) // make sure the go routine returns streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() @@ -1865,7 +1865,7 @@ var _ = Describe("Connection", func() { }() handshakeCtx := conn.HandshakeComplete() - Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) + Consistently(handshakeCtx).ShouldNot(BeClosed()) close(finishHandshake) var frames []*ackhandler.Frame Eventually(func() []*ackhandler.Frame { @@ -1908,10 +1908,10 @@ var _ = Describe("Connection", func() { conn.run() }() handshakeCtx := conn.HandshakeComplete() - Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) + Consistently(handshakeCtx).ShouldNot(BeClosed()) mconn.EXPECT().Write(gomock.Any()) conn.closeLocal(errors.New("handshake error")) - Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) + Consistently(handshakeCtx).ShouldNot(BeClosed()) Eventually(conn.Context().Done()).Should(BeClosed()) }) diff --git a/http3/client.go b/http3/client.go index c63505e1..63200db8 100644 --- a/http3/client.go +++ b/http3/client.go @@ -266,7 +266,7 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon } else { // wait for the handshake to complete select { - case <-conn.HandshakeComplete().Done(): + case <-conn.HandshakeComplete(): case <-req.Context().Done(): return nil, req.Context().Err() } @@ -449,7 +449,7 @@ func (c *client) HandshakeComplete() bool { return false } select { - case <-(*conn).HandshakeComplete().Done(): + case <-(*conn).HandshakeComplete(): return true default: return false diff --git a/http3/client_test.go b/http3/client_test.go index 6851e84c..ebea0047 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -26,10 +26,10 @@ import ( var _ = Describe("Client", func() { var ( - cl *client - req *http.Request - origDialAddr = dialAddr - handshakeCtx context.Context // an already canceled context + cl *client + req *http.Request + origDialAddr = dialAddr + handshakeChan <-chan struct{} // a closed chan ) BeforeEach(func() { @@ -43,9 +43,9 @@ var _ = Describe("Client", func() { req, err = http.NewRequest("GET", "https://localhost:1337", nil) Expect(err).ToNot(HaveOccurred()) - ctx, cancel := context.WithCancel(context.Background()) - cancel() - handshakeCtx = ctx + ch := make(chan struct{}) + close(ch) + handshakeChan = ch }) AfterEach(func() { @@ -202,7 +202,7 @@ var _ = Describe("Client", func() { }) conn = mockquic.NewMockEarlyConnection(mockCtrl) conn.EXPECT().OpenUniStream().Return(controlStr, nil) - conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().HandshakeComplete().Return(handshakeChan) conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("done")).AnyTimes() dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { @@ -328,7 +328,7 @@ var _ = Describe("Client", func() { }) conn = mockquic.NewMockEarlyConnection(mockCtrl) conn.EXPECT().OpenUniStream().Return(controlStr, nil) - conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().HandshakeComplete().Return(handshakeChan) conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { return conn, nil @@ -434,7 +434,7 @@ var _ = Describe("Client", func() { }) conn = mockquic.NewMockEarlyConnection(mockCtrl) conn.EXPECT().OpenUniStream().Return(controlStr, nil) - conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().HandshakeComplete().Return(handshakeChan) conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { return conn, nil @@ -696,7 +696,7 @@ var _ = Describe("Client", func() { testErr := errors.New("stream open error") conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) - conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().HandshakeComplete().Return(handshakeChan) _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError(testErr)) }) @@ -721,7 +721,7 @@ var _ = Describe("Client", func() { It("returns a response", func() { rspBuf := bytes.NewBuffer(getResponse(418)) gomock.InOrder( - conn.EXPECT().HandshakeComplete().Return(handshakeCtx), + conn.EXPECT().HandshakeComplete().Return(handshakeChan), conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}), ) @@ -739,7 +739,7 @@ var _ = Describe("Client", func() { It("doesn't close the request stream, with DontCloseRequestStream set", func() { rspBuf := bytes.NewBuffer(getResponse(418)) gomock.InOrder( - conn.EXPECT().HandshakeComplete().Return(handshakeCtx), + conn.EXPECT().HandshakeComplete().Return(handshakeChan), conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}), ) @@ -758,7 +758,7 @@ var _ = Describe("Client", func() { BeforeEach(func() { strBuf = &bytes.Buffer{} gomock.InOrder( - conn.EXPECT().HandshakeComplete().Return(handshakeCtx), + conn.EXPECT().HandshakeComplete().Return(handshakeChan), conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), ) body := &mockBody{} @@ -864,7 +864,7 @@ var _ = Describe("Client", func() { It("cancels a request while waiting for the handshake to complete", func() { ctx, cancel := context.WithCancel(context.Background()) req := req.WithContext(ctx) - conn.EXPECT().HandshakeComplete().Return(context.Background()) + conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) errChan := make(chan error) go func() { @@ -879,7 +879,7 @@ var _ = Describe("Client", func() { It("cancels a request while the request is still in flight", func() { ctx, cancel := context.WithCancel(context.Background()) req := req.WithContext(ctx) - conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().HandshakeComplete().Return(handshakeChan) conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) buf := &bytes.Buffer{} str.EXPECT().Close().MaxTimes(1) @@ -910,7 +910,7 @@ var _ = Describe("Client", func() { ctx, cancel := context.WithCancel(context.Background()) req := req.WithContext(ctx) - conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().HandshakeComplete().Return(handshakeChan) conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) buf := &bytes.Buffer{} @@ -930,7 +930,7 @@ var _ = Describe("Client", func() { Context("gzip compression", func() { BeforeEach(func() { - conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().HandshakeComplete().Return(handshakeChan) }) It("adds the gzip header to requests", func() { diff --git a/integrationtests/self/early_data_test.go b/integrationtests/self/early_data_test.go index 69ed0484..063f8622 100644 --- a/integrationtests/self/early_data_test.go +++ b/integrationtests/self/early_data_test.go @@ -40,7 +40,7 @@ var _ = Describe("early data", func() { Expect(err).ToNot(HaveOccurred()) Expect(str.Close()).To(Succeed()) // make sure the Write finished before the handshake completed - Expect(conn.HandshakeComplete().Done()).ToNot(BeClosed()) + Expect(conn.HandshakeComplete()).ToNot(BeClosed()) Eventually(conn.Context().Done()).Should(BeClosed()) }() serverPort := ln.Addr().(*net.UDPAddr).Port diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index ae99aa9e..8f6813ed 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -142,7 +142,7 @@ var _ = Describe("0-RTT", func() { _, err = str.Write(testdata) Expect(err).ToNot(HaveOccurred()) Expect(str.Close()).To(Succeed()) - <-conn.HandshakeComplete().Done() + <-conn.HandshakeComplete() Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue()) io.ReadAll(str) // wait for the EOF from the server to arrive before closing the conn conn.CloseWithError(0, "") @@ -302,7 +302,7 @@ var _ = Describe("0-RTT", func() { Expect(firstStr.Close()).To(Succeed()) // wait for the handshake to complete - Eventually(conn.HandshakeComplete().Done()).Should(BeClosed()) + Eventually(conn.HandshakeComplete()).Should(BeClosed()) str, err := conn.OpenUniStream() Expect(err).ToNot(HaveOccurred()) _, err = str.Write(PRData) diff --git a/interface.go b/interface.go index 3eebe4a6..b700e7c1 100644 --- a/interface.go +++ b/interface.go @@ -197,9 +197,10 @@ type EarlyConnection interface { Connection // HandshakeComplete blocks until the handshake completes (or fails). - // Data sent before completion of the handshake is encrypted with 1-RTT keys. - // Note that the client's identity hasn't been verified yet. - HandshakeComplete() context.Context + // For the client, data sent before completion of the handshake is encrypted with 0-RTT keys. + // For the serfer, data sent before completion of the handshake is encrypted with 1-RTT keys, + // however the client's identity is only verified once the handshake completes. + HandshakeComplete() <-chan struct{} NextConnection() Connection } diff --git a/internal/mocks/quic/early_conn.go b/internal/mocks/quic/early_conn.go index 7915144f..174c70de 100644 --- a/internal/mocks/quic/early_conn.go +++ b/internal/mocks/quic/early_conn.go @@ -110,10 +110,10 @@ func (mr *MockEarlyConnectionMockRecorder) Context() *gomock.Call { } // HandshakeComplete mocks base method. -func (m *MockEarlyConnection) HandshakeComplete() context.Context { +func (m *MockEarlyConnection) HandshakeComplete() <-chan struct{} { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "HandshakeComplete") - ret0, _ := ret[0].(context.Context) + ret0, _ := ret[0].(<-chan struct{}) return ret0 } diff --git a/interop/http09/client.go b/interop/http09/client.go index 4f2be836..6cb326bb 100644 --- a/interop/http09/client.go +++ b/interop/http09/client.go @@ -96,7 +96,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { return nil, c.dialErr } if req.Method != MethodGet0RTT { - <-c.conn.HandshakeComplete().Done() + <-c.conn.HandshakeComplete() } return c.doRequest(req) } diff --git a/mock_quic_conn_test.go b/mock_quic_conn_test.go index e4e75c0f..c51ce776 100644 --- a/mock_quic_conn_test.go +++ b/mock_quic_conn_test.go @@ -123,10 +123,10 @@ func (mr *MockQuicConnMockRecorder) GetVersion() *gomock.Call { } // HandshakeComplete mocks base method. -func (m *MockQuicConn) HandshakeComplete() context.Context { +func (m *MockQuicConn) HandshakeComplete() <-chan struct{} { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "HandshakeComplete") - ret0, _ := ret[0].(context.Context) + ret0, _ := ret[0].(<-chan struct{}) return ret0 } diff --git a/server.go b/server.go index 734d617f..7c0ec099 100644 --- a/server.go +++ b/server.go @@ -536,7 +536,7 @@ func (s *baseServer) handleNewConn(conn quicConn) { } else { // wait until the handshake is complete (or fails) select { - case <-conn.HandshakeComplete().Done(): + case <-conn.HandshakeComplete(): case <-connCtx.Done(): return } diff --git a/server_test.go b/server_test.go index 6d4217be..28e59f87 100644 --- a/server_test.go +++ b/server_test.go @@ -298,7 +298,7 @@ var _ = Describe("Server", func() { conn.EXPECT().handlePacket(p) conn.EXPECT().run().Do(func() { close(run) }) conn.EXPECT().Context().Return(context.Background()) - conn.EXPECT().HandshakeComplete().Return(context.Background()) + conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) return conn } @@ -494,7 +494,7 @@ var _ = Describe("Server", func() { conn.EXPECT().handlePacket(p) conn.EXPECT().run().Do(func() { close(run) }) conn.EXPECT().Context().Return(context.Background()) - conn.EXPECT().HandshakeComplete().Return(context.Background()) + conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) return conn } @@ -546,7 +546,7 @@ var _ = Describe("Server", func() { conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1) conn.EXPECT().run().MaxTimes(1) conn.EXPECT().Context().Return(context.Background()).MaxTimes(1) - conn.EXPECT().HandshakeComplete().Return(context.Background()).MaxTimes(1) + conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1) return conn } @@ -626,9 +626,9 @@ var _ = Describe("Server", func() { conn.EXPECT().handlePacket(gomock.Any()) conn.EXPECT().run() conn.EXPECT().Context().Return(context.Background()) - ctx, cancel := context.WithCancel(context.Background()) - cancel() - conn.EXPECT().HandshakeComplete().Return(ctx) + c := make(chan struct{}) + close(c) + conn.EXPECT().HandshakeComplete().Return(c) return conn } @@ -695,9 +695,9 @@ var _ = Describe("Server", func() { conn.EXPECT().handlePacket(p) conn.EXPECT().run() conn.EXPECT().Context().Return(ctx) - ctx, cancel := context.WithCancel(context.Background()) - cancel() - conn.EXPECT().HandshakeComplete().Return(ctx) + c := make(chan struct{}) + close(c) + conn.EXPECT().HandshakeComplete().Return(c) close(connCreated) return conn } @@ -972,7 +972,7 @@ var _ = Describe("Server", func() { close(done) }() - ctx, cancel := context.WithCancel(context.Background()) // handshake context + handshakeChan := make(chan struct{}) serv.newConn = func( _ sendConn, runner connRunner, @@ -992,7 +992,7 @@ var _ = Describe("Server", func() { _ protocol.VersionNumber, ) quicConn { conn.EXPECT().handlePacket(gomock.Any()) - conn.EXPECT().HandshakeComplete().Return(ctx) + conn.EXPECT().HandshakeComplete().Return(handshakeChan) conn.EXPECT().run().Do(func() {}) conn.EXPECT().Context().Return(context.Background()) return conn @@ -1008,7 +1008,7 @@ var _ = Describe("Server", func() { &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, ) Consistently(done).ShouldNot(BeClosed()) - cancel() // complete the handshake + close(handshakeChan) // complete the handshake Eventually(done).Should(BeClosed()) }) })