diff --git a/client.go b/client.go index 2431d9be..f66ef510 100644 --- a/client.go +++ b/client.go @@ -249,7 +249,26 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config { func (c *client) dial(ctx context.Context) error { c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) - c.createNewTLSSession(c.version) + + c.mutex.Lock() + c.session = newClientSession( + c.conn, + c.packetHandlers, + c.destConnID, + c.srcConnID, + c.config, + c.tlsConf, + c.initialPacketNumber, + c.initialVersion, + c.use0RTT, + c.logger, + c.version, + ) + c.mutex.Unlock() + // It's not possible to use the stateless reset token for the client's (first) connection ID, + // since there's no way to securely communicate it to the server. + c.packetHandlers.Add(c.srcConnID, c) + err := c.establishSecureConnection(ctx) if err == errCloseForRecreating { return c.dial(ctx) @@ -354,26 +373,6 @@ func (c *client) handleVersionNegotiationPacket(p *receivedPacket) { c.initialPacketNumber = c.session.closeForRecreating() } -func (c *client) createNewTLSSession(_ protocol.VersionNumber) { - c.mutex.Lock() - c.session = newClientSession( - c.conn, - c.packetHandlers, - c.destConnID, - c.srcConnID, - c.config, - c.tlsConf, - c.initialPacketNumber, - c.initialVersion, - c.logger, - c.version, - ) - c.mutex.Unlock() - // It's not possible to use the stateless reset token for the client's (first) connection ID, - // since there's no way to securely communicate it to the server. - c.packetHandlers.Add(c.srcConnID, c) -} - func (c *client) Close() error { c.mutex.Lock() defer c.mutex.Unlock() diff --git a/client_test.go b/client_test.go index 709bda2c..42bad682 100644 --- a/client_test.go +++ b/client_test.go @@ -38,6 +38,7 @@ var _ = Describe("Client", func() { tlsConf *tls.Config, initialPacketNumber protocol.PacketNumber, initialVersion protocol.VersionNumber, + enable0RTT bool, logger utils.Logger, v protocol.VersionNumber, ) quicSession @@ -140,6 +141,7 @@ var _ = Describe("Client", func() { _ *tls.Config, _ protocol.PacketNumber, _ protocol.VersionNumber, + _ bool, _ utils.Logger, _ protocol.VersionNumber, ) quicSession { @@ -170,6 +172,7 @@ var _ = Describe("Client", func() { tlsConf *tls.Config, _ protocol.PacketNumber, _ protocol.VersionNumber, + _ bool, _ utils.Logger, _ protocol.VersionNumber, ) quicSession { @@ -200,6 +203,7 @@ var _ = Describe("Client", func() { tlsConf *tls.Config, _ protocol.PacketNumber, _ protocol.VersionNumber, + _ bool, _ utils.Logger, _ protocol.VersionNumber, ) quicSession { @@ -235,9 +239,11 @@ var _ = Describe("Client", func() { _ *tls.Config, _ protocol.PacketNumber, _ protocol.VersionNumber, + enable0RTT bool, _ utils.Logger, _ protocol.VersionNumber, ) quicSession { + Expect(enable0RTT).To(BeFalse()) sess := NewMockQuicSession(mockCtrl) sess.EXPECT().run().Do(func() { close(run) }) ctx, cancel := context.WithCancel(context.Background()) @@ -273,9 +279,11 @@ var _ = Describe("Client", func() { _ *tls.Config, _ protocol.PacketNumber, _ protocol.VersionNumber, + enable0RTT bool, _ utils.Logger, _ protocol.VersionNumber, ) quicSession { + Expect(enable0RTT).To(BeTrue()) sess := NewMockQuicSession(mockCtrl) sess.EXPECT().run().Do(func() { <-done }) sess.EXPECT().HandshakeComplete().Return(context.Background()) @@ -316,6 +324,7 @@ var _ = Describe("Client", func() { _ *tls.Config, _ protocol.PacketNumber, _ protocol.VersionNumber, + _ bool, _ utils.Logger, _ protocol.VersionNumber, ) quicSession { @@ -356,6 +365,7 @@ var _ = Describe("Client", func() { _ *tls.Config, _ protocol.PacketNumber, _ protocol.VersionNumber, + _ bool, _ utils.Logger, _ protocol.VersionNumber, ) quicSession { @@ -404,6 +414,7 @@ var _ = Describe("Client", func() { _ *tls.Config, _ protocol.PacketNumber, _ protocol.VersionNumber, + _ bool, _ utils.Logger, _ protocol.VersionNumber, ) quicSession { @@ -523,6 +534,7 @@ var _ = Describe("Client", func() { _ *tls.Config, _ protocol.PacketNumber, _ protocol.VersionNumber, /* initial version */ + _ bool, _ utils.Logger, versionP protocol.VersionNumber, ) quicSession { @@ -571,6 +583,7 @@ var _ = Describe("Client", func() { _ *tls.Config, _ protocol.PacketNumber, _ protocol.VersionNumber, + _ bool, _ utils.Logger, _ protocol.VersionNumber, ) quicSession { diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index a62b036e..e4b38a61 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -44,7 +44,7 @@ var _ = Describe("0-RTT", func() { return proxy, &num0RTTPackets } - dialAndReceiveSessionTicket := func(ln quic.Listener, proxyPort int) *tls.Config { + dialAndReceiveSessionTicket := func(ln quic.EarlyListener, proxyPort int) *tls.Config { // dial the first session in order to receive a session ticket go func() { defer GinkgoRecover() @@ -68,7 +68,7 @@ var _ = Describe("0-RTT", func() { return clientConf } - transfer0RTTData := func(ln quic.Listener, proxyPort int, clientConf *tls.Config, testdata []byte) { + transfer0RTTData := func(ln quic.EarlyListener, proxyPort int, clientConf *tls.Config, testdata []byte) { // now dial the second session, and use 0-RTT to send some data done := make(chan struct{}) go func() { @@ -98,7 +98,7 @@ var _ = Describe("0-RTT", func() { } It("transfers 0-RTT data", func() { - ln, err := quic.ListenAddr( + ln, err := quic.ListenAddrEarly( "localhost:0", getTLSConfig(), &quic.Config{ @@ -122,7 +122,7 @@ var _ = Describe("0-RTT", func() { // Test that data intended to be sent with 1-RTT protection is not sent in 0-RTT packets. It("waits until a session until the handshake is done", func() { - ln, err := quic.ListenAddr( + ln, err := quic.ListenAddrEarly( "localhost:0", getTLSConfig(), &quic.Config{ @@ -199,7 +199,7 @@ var _ = Describe("0-RTT", func() { num0RTTDropped uint32 ) - ln, err := quic.ListenAddr( + ln, err := quic.ListenAddrEarly( "localhost:0", getTLSConfig(), &quic.Config{ @@ -253,7 +253,7 @@ var _ = Describe("0-RTT", func() { var firstConnID, secondConnID protocol.ConnectionID var firstCounter, secondCounter int - ln, err := quic.ListenAddr( + ln, err := quic.ListenAddrEarly( "localhost:0", getTLSConfig(), &quic.Config{Versions: []protocol.VersionNumber{version}}, diff --git a/server.go b/server.go index 7638d037..1b8d0877 100644 --- a/server.go +++ b/server.go @@ -72,7 +72,7 @@ type baseServer struct { sessionHandler packetHandlerManager // set as a member, so they can be set in the tests - newSession func(connection, sessionRunner, protocol.ConnectionID /* original connection ID */, protocol.ConnectionID /* client dest connection ID */, protocol.ConnectionID /* destination connection ID */, protocol.ConnectionID /* source connection ID */, [16]byte, *Config, *tls.Config, *handshake.TokenGenerator, utils.Logger, protocol.VersionNumber) quicSession + newSession func(connection, sessionRunner, protocol.ConnectionID /* original connection ID */, protocol.ConnectionID /* client dest connection ID */, protocol.ConnectionID /* destination connection ID */, protocol.ConnectionID /* source connection ID */, [16]byte, *Config, *tls.Config, *handshake.TokenGenerator, bool /* enable 0-RTT */, utils.Logger, protocol.VersionNumber) quicSession serverError error errorChan chan struct{} @@ -450,6 +450,7 @@ func (s *baseServer) createNewSession( s.config, s.tlsConf, s.tokenGenerator, + s.acceptEarlySessions, s.logger, version, ) diff --git a/server_test.go b/server_test.go index 5858fe9b..7d360c8a 100644 --- a/server_test.go +++ b/server_test.go @@ -331,9 +331,11 @@ var _ = Describe("Server", func() { _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, + enable0RTT bool, _ utils.Logger, _ protocol.VersionNumber, ) quicSession { + Expect(enable0RTT).To(BeFalse()) Expect(origConnID).To(Equal(hdr.DestConnectionID)) Expect(destConnID).To(Equal(hdr.SrcConnectionID)) // make sure we're using a server-generated connection ID @@ -381,6 +383,7 @@ var _ = Describe("Server", func() { _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, + _ bool, _ utils.Logger, _ protocol.VersionNumber, ) quicSession { @@ -409,6 +412,7 @@ var _ = Describe("Server", func() { _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, + _ bool, _ utils.Logger, _ protocol.VersionNumber, ) quicSession { @@ -469,6 +473,7 @@ var _ = Describe("Server", func() { _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, + _ bool, _ utils.Logger, _ protocol.VersionNumber, ) quicSession { @@ -572,6 +577,7 @@ var _ = Describe("Server", func() { _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, + _ bool, _ utils.Logger, _ protocol.VersionNumber, ) quicSession { @@ -624,9 +630,11 @@ var _ = Describe("Server", func() { _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, + enable0RTT bool, _ utils.Logger, _ protocol.VersionNumber, ) quicSession { + Expect(enable0RTT).To(BeTrue()) sess.EXPECT().run().Do(func() {}) sess.EXPECT().earlySessionReady().Return(ready) sess.EXPECT().Context().Return(context.Background()) @@ -653,6 +661,7 @@ var _ = Describe("Server", func() { _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, + _ bool, _ utils.Logger, _ protocol.VersionNumber, ) quicSession { @@ -709,6 +718,7 @@ var _ = Describe("Server", func() { _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, + _ bool, _ utils.Logger, _ protocol.VersionNumber, ) quicSession { diff --git a/session.go b/session.go index e34c8973..68d304a0 100644 --- a/session.go +++ b/session.go @@ -202,6 +202,7 @@ var newSession = func( conf *Config, tlsConf *tls.Config, tokenGenerator *handshake.TokenGenerator, + enable0RTT bool, logger utils.Logger, v protocol.VersionNumber, ) quicSession { @@ -274,7 +275,7 @@ var newSession = func( }, }, tlsConf, - true, // TODO: make 0-RTT support configurable + enable0RTT, s.rttStats, logger, ) @@ -308,6 +309,7 @@ var newClientSession = func( tlsConf *tls.Config, initialPacketNumber protocol.PacketNumber, initialVersion protocol.VersionNumber, + enable0RTT bool, logger utils.Logger, v protocol.VersionNumber, ) quicSession { @@ -371,7 +373,7 @@ var newClientSession = func( onHandshakeComplete: func() { close(s.handshakeCompleteChan) }, }, tlsConf, - true, // TODO: make 0-RTT support configurable + enable0RTT, s.rttStats, logger, ) diff --git a/session_test.go b/session_test.go index 9a842e1a..0402e414 100644 --- a/session_test.go +++ b/session_test.go @@ -122,6 +122,7 @@ var _ = Describe("Session", func() { populateServerConfig(&Config{}), nil, // tls.Config tokenGenerator, + false, utils.DefaultLogger, protocol.VersionTLS, ).(*session) @@ -1658,6 +1659,7 @@ var _ = Describe("Client Session", func() { tlsConf, 42, // initial packet number protocol.VersionTLS, + false, utils.DefaultLogger, protocol.VersionTLS, ).(*session)