use a chan instead of a context in Connection.HandshakeComplete (#3709)

This commit is contained in:
Marten Seemann 2023-03-27 00:51:56 +11:00 committed by GitHub
parent 41ddaa0262
commit a4b0265628
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 66 additions and 65 deletions

View file

@ -325,7 +325,7 @@ func (c *client) dial(ctx context.Context) error {
case <-earlyConnChan:
// ready to send 0-RTT data
return nil
case <-c.conn.HandshakeComplete().Done():
case <-c.conn.HandshakeComplete():
// handshake successfully completed
return nil
}

View file

@ -130,7 +130,7 @@ var _ = Describe("Client", func() {
remoteAddrChan <- sconn.RemoteAddr().String()
conn := NewMockQuicConn(mockCtrl)
conn.EXPECT().run()
conn.EXPECT().HandshakeComplete().Return(context.Background())
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
return conn
}
_, err := DialAddr("localhost:17890", tlsConf, &Config{HandshakeIdleTimeout: time.Millisecond})
@ -163,7 +163,7 @@ var _ = Describe("Client", func() {
hostnameChan <- tlsConf.ServerName
conn := NewMockQuicConn(mockCtrl)
conn.EXPECT().run()
conn.EXPECT().HandshakeComplete().Return(context.Background())
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
return conn
}
tlsConf.ServerName = "foobar"
@ -195,7 +195,7 @@ var _ = Describe("Client", func() {
) quicConn {
hostnameChan <- tlsConf.ServerName
conn := NewMockQuicConn(mockCtrl)
conn.EXPECT().HandshakeComplete().Return(context.Background())
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
conn.EXPECT().run()
return conn
}
@ -235,9 +235,9 @@ var _ = Describe("Client", func() {
Expect(enable0RTT).To(BeFalse())
conn := NewMockQuicConn(mockCtrl)
conn.EXPECT().run().Do(func() { close(run) })
ctx, cancel := context.WithCancel(context.Background())
cancel()
conn.EXPECT().HandshakeComplete().Return(ctx)
c := make(chan struct{})
close(c)
conn.EXPECT().HandshakeComplete().Return(c)
return conn
}
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
@ -278,7 +278,7 @@ var _ = Describe("Client", func() {
Expect(enable0RTT).To(BeTrue())
conn := NewMockQuicConn(mockCtrl)
conn.EXPECT().run().Do(func() { <-done })
conn.EXPECT().HandshakeComplete().Return(context.Background())
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
conn.EXPECT().earlyConnReady().Return(readyChan)
return conn
}
@ -325,7 +325,7 @@ var _ = Describe("Client", func() {
) quicConn {
conn := NewMockQuicConn(mockCtrl)
conn.EXPECT().run().Return(testErr)
conn.EXPECT().HandshakeComplete().Return(context.Background())
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
return conn
}
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
@ -350,7 +350,7 @@ var _ = Describe("Client", func() {
conn.EXPECT().run().Do(func() {
<-connRunning
})
conn.EXPECT().HandshakeComplete().Return(context.Background())
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
newClientConnection = func(
_ sendConn,
_ connRunner,
@ -425,7 +425,7 @@ var _ = Describe("Client", func() {
conn.EXPECT().run().Do(func() {
<-run
})
conn.EXPECT().HandshakeComplete().Return(context.Background())
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
done := make(chan struct{})
go func() {
@ -546,7 +546,7 @@ var _ = Describe("Client", func() {
// TODO: check connection IDs?
conn := NewMockQuicConn(mockCtrl)
conn.EXPECT().run()
conn.EXPECT().HandshakeComplete().Return(context.Background())
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
return conn
}
_, err := Dial(packetConn, addr, "localhost:1337", tlsConf, config)
@ -580,7 +580,7 @@ var _ = Describe("Client", func() {
versionP protocol.VersionNumber,
) quicConn {
conn := NewMockQuicConn(mockCtrl)
conn.EXPECT().HandshakeComplete().Return(context.Background())
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
if counter == 0 {
Expect(pn).To(BeZero())
Expect(hasNegotiatedVersion).To(BeFalse())

View file

@ -696,8 +696,8 @@ func (s *connection) earlyConnReady() <-chan struct{} {
return s.earlyConnReadyChan
}
func (s *connection) HandshakeComplete() context.Context {
return s.handshakeCtx
func (s *connection) HandshakeComplete() <-chan struct{} {
return s.handshakeCtx.Done()
}
func (s *connection) Context() context.Context {
@ -2195,7 +2195,7 @@ func (s *connection) GetVersion() protocol.VersionNumber {
}
func (s *connection) NextConnection() Connection {
<-s.HandshakeComplete().Done()
<-s.HandshakeComplete()
s.streamsMap.UseResetMaps()
return s
}

View file

@ -1834,9 +1834,9 @@ var _ = Describe("Connection", func() {
conn.run()
}()
handshakeCtx := conn.HandshakeComplete()
Consistently(handshakeCtx.Done()).ShouldNot(BeClosed())
Consistently(handshakeCtx).ShouldNot(BeClosed())
close(finishHandshake)
Eventually(handshakeCtx.Done()).Should(BeClosed())
Eventually(handshakeCtx).Should(BeClosed())
// make sure the go routine returns
streamManager.EXPECT().CloseWithError(gomock.Any())
expectReplaceWithClosed()
@ -1865,7 +1865,7 @@ var _ = Describe("Connection", func() {
}()
handshakeCtx := conn.HandshakeComplete()
Consistently(handshakeCtx.Done()).ShouldNot(BeClosed())
Consistently(handshakeCtx).ShouldNot(BeClosed())
close(finishHandshake)
var frames []*ackhandler.Frame
Eventually(func() []*ackhandler.Frame {
@ -1908,10 +1908,10 @@ var _ = Describe("Connection", func() {
conn.run()
}()
handshakeCtx := conn.HandshakeComplete()
Consistently(handshakeCtx.Done()).ShouldNot(BeClosed())
Consistently(handshakeCtx).ShouldNot(BeClosed())
mconn.EXPECT().Write(gomock.Any())
conn.closeLocal(errors.New("handshake error"))
Consistently(handshakeCtx.Done()).ShouldNot(BeClosed())
Consistently(handshakeCtx).ShouldNot(BeClosed())
Eventually(conn.Context().Done()).Should(BeClosed())
})

View file

@ -266,7 +266,7 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
} else {
// wait for the handshake to complete
select {
case <-conn.HandshakeComplete().Done():
case <-conn.HandshakeComplete():
case <-req.Context().Done():
return nil, req.Context().Err()
}
@ -449,7 +449,7 @@ func (c *client) HandshakeComplete() bool {
return false
}
select {
case <-(*conn).HandshakeComplete().Done():
case <-(*conn).HandshakeComplete():
return true
default:
return false

View file

@ -26,10 +26,10 @@ import (
var _ = Describe("Client", func() {
var (
cl *client
req *http.Request
origDialAddr = dialAddr
handshakeCtx context.Context // an already canceled context
cl *client
req *http.Request
origDialAddr = dialAddr
handshakeChan <-chan struct{} // a closed chan
)
BeforeEach(func() {
@ -43,9 +43,9 @@ var _ = Describe("Client", func() {
req, err = http.NewRequest("GET", "https://localhost:1337", nil)
Expect(err).ToNot(HaveOccurred())
ctx, cancel := context.WithCancel(context.Background())
cancel()
handshakeCtx = ctx
ch := make(chan struct{})
close(ch)
handshakeChan = ch
})
AfterEach(func() {
@ -202,7 +202,7 @@ var _ = Describe("Client", func() {
})
conn = mockquic.NewMockEarlyConnection(mockCtrl)
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
conn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("done")).AnyTimes()
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
@ -328,7 +328,7 @@ var _ = Describe("Client", func() {
})
conn = mockquic.NewMockEarlyConnection(mockCtrl)
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
return conn, nil
@ -434,7 +434,7 @@ var _ = Describe("Client", func() {
})
conn = mockquic.NewMockEarlyConnection(mockCtrl)
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
return conn, nil
@ -696,7 +696,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("stream open error")
conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError(testErr))
})
@ -721,7 +721,7 @@ var _ = Describe("Client", func() {
It("returns a response", func() {
rspBuf := bytes.NewBuffer(getResponse(418))
gomock.InOrder(
conn.EXPECT().HandshakeComplete().Return(handshakeCtx),
conn.EXPECT().HandshakeComplete().Return(handshakeChan),
conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
)
@ -739,7 +739,7 @@ var _ = Describe("Client", func() {
It("doesn't close the request stream, with DontCloseRequestStream set", func() {
rspBuf := bytes.NewBuffer(getResponse(418))
gomock.InOrder(
conn.EXPECT().HandshakeComplete().Return(handshakeCtx),
conn.EXPECT().HandshakeComplete().Return(handshakeChan),
conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
)
@ -758,7 +758,7 @@ var _ = Describe("Client", func() {
BeforeEach(func() {
strBuf = &bytes.Buffer{}
gomock.InOrder(
conn.EXPECT().HandshakeComplete().Return(handshakeCtx),
conn.EXPECT().HandshakeComplete().Return(handshakeChan),
conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
)
body := &mockBody{}
@ -864,7 +864,7 @@ var _ = Describe("Client", func() {
It("cancels a request while waiting for the handshake to complete", func() {
ctx, cancel := context.WithCancel(context.Background())
req := req.WithContext(ctx)
conn.EXPECT().HandshakeComplete().Return(context.Background())
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
errChan := make(chan error)
go func() {
@ -879,7 +879,7 @@ var _ = Describe("Client", func() {
It("cancels a request while the request is still in flight", func() {
ctx, cancel := context.WithCancel(context.Background())
req := req.WithContext(ctx)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Close().MaxTimes(1)
@ -910,7 +910,7 @@ var _ = Describe("Client", func() {
ctx, cancel := context.WithCancel(context.Background())
req := req.WithContext(ctx)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
buf := &bytes.Buffer{}
@ -930,7 +930,7 @@ var _ = Describe("Client", func() {
Context("gzip compression", func() {
BeforeEach(func() {
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
})
It("adds the gzip header to requests", func() {

View file

@ -40,7 +40,7 @@ var _ = Describe("early data", func() {
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
// make sure the Write finished before the handshake completed
Expect(conn.HandshakeComplete().Done()).ToNot(BeClosed())
Expect(conn.HandshakeComplete()).ToNot(BeClosed())
Eventually(conn.Context().Done()).Should(BeClosed())
}()
serverPort := ln.Addr().(*net.UDPAddr).Port

View file

@ -142,7 +142,7 @@ var _ = Describe("0-RTT", func() {
_, err = str.Write(testdata)
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
<-conn.HandshakeComplete().Done()
<-conn.HandshakeComplete()
Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue())
io.ReadAll(str) // wait for the EOF from the server to arrive before closing the conn
conn.CloseWithError(0, "")
@ -302,7 +302,7 @@ var _ = Describe("0-RTT", func() {
Expect(firstStr.Close()).To(Succeed())
// wait for the handshake to complete
Eventually(conn.HandshakeComplete().Done()).Should(BeClosed())
Eventually(conn.HandshakeComplete()).Should(BeClosed())
str, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write(PRData)

View file

@ -197,9 +197,10 @@ type EarlyConnection interface {
Connection
// HandshakeComplete blocks until the handshake completes (or fails).
// Data sent before completion of the handshake is encrypted with 1-RTT keys.
// Note that the client's identity hasn't been verified yet.
HandshakeComplete() context.Context
// For the client, data sent before completion of the handshake is encrypted with 0-RTT keys.
// For the serfer, data sent before completion of the handshake is encrypted with 1-RTT keys,
// however the client's identity is only verified once the handshake completes.
HandshakeComplete() <-chan struct{}
NextConnection() Connection
}

View file

@ -110,10 +110,10 @@ func (mr *MockEarlyConnectionMockRecorder) Context() *gomock.Call {
}
// HandshakeComplete mocks base method.
func (m *MockEarlyConnection) HandshakeComplete() context.Context {
func (m *MockEarlyConnection) HandshakeComplete() <-chan struct{} {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HandshakeComplete")
ret0, _ := ret[0].(context.Context)
ret0, _ := ret[0].(<-chan struct{})
return ret0
}

View file

@ -96,7 +96,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, c.dialErr
}
if req.Method != MethodGet0RTT {
<-c.conn.HandshakeComplete().Done()
<-c.conn.HandshakeComplete()
}
return c.doRequest(req)
}

View file

@ -123,10 +123,10 @@ func (mr *MockQuicConnMockRecorder) GetVersion() *gomock.Call {
}
// HandshakeComplete mocks base method.
func (m *MockQuicConn) HandshakeComplete() context.Context {
func (m *MockQuicConn) HandshakeComplete() <-chan struct{} {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HandshakeComplete")
ret0, _ := ret[0].(context.Context)
ret0, _ := ret[0].(<-chan struct{})
return ret0
}

View file

@ -536,7 +536,7 @@ func (s *baseServer) handleNewConn(conn quicConn) {
} else {
// wait until the handshake is complete (or fails)
select {
case <-conn.HandshakeComplete().Done():
case <-conn.HandshakeComplete():
case <-connCtx.Done():
return
}

View file

@ -298,7 +298,7 @@ var _ = Describe("Server", func() {
conn.EXPECT().handlePacket(p)
conn.EXPECT().run().Do(func() { close(run) })
conn.EXPECT().Context().Return(context.Background())
conn.EXPECT().HandshakeComplete().Return(context.Background())
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
return conn
}
@ -494,7 +494,7 @@ var _ = Describe("Server", func() {
conn.EXPECT().handlePacket(p)
conn.EXPECT().run().Do(func() { close(run) })
conn.EXPECT().Context().Return(context.Background())
conn.EXPECT().HandshakeComplete().Return(context.Background())
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
return conn
}
@ -546,7 +546,7 @@ var _ = Describe("Server", func() {
conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1)
conn.EXPECT().run().MaxTimes(1)
conn.EXPECT().Context().Return(context.Background()).MaxTimes(1)
conn.EXPECT().HandshakeComplete().Return(context.Background()).MaxTimes(1)
conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1)
return conn
}
@ -626,9 +626,9 @@ var _ = Describe("Server", func() {
conn.EXPECT().handlePacket(gomock.Any())
conn.EXPECT().run()
conn.EXPECT().Context().Return(context.Background())
ctx, cancel := context.WithCancel(context.Background())
cancel()
conn.EXPECT().HandshakeComplete().Return(ctx)
c := make(chan struct{})
close(c)
conn.EXPECT().HandshakeComplete().Return(c)
return conn
}
@ -695,9 +695,9 @@ var _ = Describe("Server", func() {
conn.EXPECT().handlePacket(p)
conn.EXPECT().run()
conn.EXPECT().Context().Return(ctx)
ctx, cancel := context.WithCancel(context.Background())
cancel()
conn.EXPECT().HandshakeComplete().Return(ctx)
c := make(chan struct{})
close(c)
conn.EXPECT().HandshakeComplete().Return(c)
close(connCreated)
return conn
}
@ -972,7 +972,7 @@ var _ = Describe("Server", func() {
close(done)
}()
ctx, cancel := context.WithCancel(context.Background()) // handshake context
handshakeChan := make(chan struct{})
serv.newConn = func(
_ sendConn,
runner connRunner,
@ -992,7 +992,7 @@ var _ = Describe("Server", func() {
_ protocol.VersionNumber,
) quicConn {
conn.EXPECT().handlePacket(gomock.Any())
conn.EXPECT().HandshakeComplete().Return(ctx)
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
conn.EXPECT().run().Do(func() {})
conn.EXPECT().Context().Return(context.Background())
return conn
@ -1008,7 +1008,7 @@ var _ = Describe("Server", func() {
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
)
Consistently(done).ShouldNot(BeClosed())
cancel() // complete the handshake
close(handshakeChan) // complete the handshake
Eventually(done).Should(BeClosed())
})
})