only enable 0-RTT when using ListenEarly and DialEarly

This commit is contained in:
Marten Seemann 2020-01-17 10:47:50 +07:00
parent 39efdfe695
commit eeba3951ae
7 changed files with 57 additions and 30 deletions

View file

@ -249,7 +249,26 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config {
func (c *client) dial(ctx context.Context) error { func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
c.createNewTLSSession(c.version)
c.mutex.Lock()
c.session = newClientSession(
c.conn,
c.packetHandlers,
c.destConnID,
c.srcConnID,
c.config,
c.tlsConf,
c.initialPacketNumber,
c.initialVersion,
c.use0RTT,
c.logger,
c.version,
)
c.mutex.Unlock()
// It's not possible to use the stateless reset token for the client's (first) connection ID,
// since there's no way to securely communicate it to the server.
c.packetHandlers.Add(c.srcConnID, c)
err := c.establishSecureConnection(ctx) err := c.establishSecureConnection(ctx)
if err == errCloseForRecreating { if err == errCloseForRecreating {
return c.dial(ctx) return c.dial(ctx)
@ -354,26 +373,6 @@ func (c *client) handleVersionNegotiationPacket(p *receivedPacket) {
c.initialPacketNumber = c.session.closeForRecreating() c.initialPacketNumber = c.session.closeForRecreating()
} }
func (c *client) createNewTLSSession(_ protocol.VersionNumber) {
c.mutex.Lock()
c.session = newClientSession(
c.conn,
c.packetHandlers,
c.destConnID,
c.srcConnID,
c.config,
c.tlsConf,
c.initialPacketNumber,
c.initialVersion,
c.logger,
c.version,
)
c.mutex.Unlock()
// It's not possible to use the stateless reset token for the client's (first) connection ID,
// since there's no way to securely communicate it to the server.
c.packetHandlers.Add(c.srcConnID, c)
}
func (c *client) Close() error { func (c *client) Close() error {
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() defer c.mutex.Unlock()

View file

@ -38,6 +38,7 @@ var _ = Describe("Client", func() {
tlsConf *tls.Config, tlsConf *tls.Config,
initialPacketNumber protocol.PacketNumber, initialPacketNumber protocol.PacketNumber,
initialVersion protocol.VersionNumber, initialVersion protocol.VersionNumber,
enable0RTT bool,
logger utils.Logger, logger utils.Logger,
v protocol.VersionNumber, v protocol.VersionNumber,
) quicSession ) quicSession
@ -140,6 +141,7 @@ var _ = Describe("Client", func() {
_ *tls.Config, _ *tls.Config,
_ protocol.PacketNumber, _ protocol.PacketNumber,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ bool,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) quicSession { ) quicSession {
@ -170,6 +172,7 @@ var _ = Describe("Client", func() {
tlsConf *tls.Config, tlsConf *tls.Config,
_ protocol.PacketNumber, _ protocol.PacketNumber,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ bool,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) quicSession { ) quicSession {
@ -200,6 +203,7 @@ var _ = Describe("Client", func() {
tlsConf *tls.Config, tlsConf *tls.Config,
_ protocol.PacketNumber, _ protocol.PacketNumber,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ bool,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) quicSession { ) quicSession {
@ -235,9 +239,11 @@ var _ = Describe("Client", func() {
_ *tls.Config, _ *tls.Config,
_ protocol.PacketNumber, _ protocol.PacketNumber,
_ protocol.VersionNumber, _ protocol.VersionNumber,
enable0RTT bool,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) quicSession { ) quicSession {
Expect(enable0RTT).To(BeFalse())
sess := NewMockQuicSession(mockCtrl) sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run().Do(func() { close(run) }) sess.EXPECT().run().Do(func() { close(run) })
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -273,9 +279,11 @@ var _ = Describe("Client", func() {
_ *tls.Config, _ *tls.Config,
_ protocol.PacketNumber, _ protocol.PacketNumber,
_ protocol.VersionNumber, _ protocol.VersionNumber,
enable0RTT bool,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) quicSession { ) quicSession {
Expect(enable0RTT).To(BeTrue())
sess := NewMockQuicSession(mockCtrl) sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run().Do(func() { <-done }) sess.EXPECT().run().Do(func() { <-done })
sess.EXPECT().HandshakeComplete().Return(context.Background()) sess.EXPECT().HandshakeComplete().Return(context.Background())
@ -316,6 +324,7 @@ var _ = Describe("Client", func() {
_ *tls.Config, _ *tls.Config,
_ protocol.PacketNumber, _ protocol.PacketNumber,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ bool,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) quicSession { ) quicSession {
@ -356,6 +365,7 @@ var _ = Describe("Client", func() {
_ *tls.Config, _ *tls.Config,
_ protocol.PacketNumber, _ protocol.PacketNumber,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ bool,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) quicSession { ) quicSession {
@ -404,6 +414,7 @@ var _ = Describe("Client", func() {
_ *tls.Config, _ *tls.Config,
_ protocol.PacketNumber, _ protocol.PacketNumber,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ bool,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) quicSession { ) quicSession {
@ -523,6 +534,7 @@ var _ = Describe("Client", func() {
_ *tls.Config, _ *tls.Config,
_ protocol.PacketNumber, _ protocol.PacketNumber,
_ protocol.VersionNumber, /* initial version */ _ protocol.VersionNumber, /* initial version */
_ bool,
_ utils.Logger, _ utils.Logger,
versionP protocol.VersionNumber, versionP protocol.VersionNumber,
) quicSession { ) quicSession {
@ -571,6 +583,7 @@ var _ = Describe("Client", func() {
_ *tls.Config, _ *tls.Config,
_ protocol.PacketNumber, _ protocol.PacketNumber,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ bool,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) quicSession { ) quicSession {

View file

@ -44,7 +44,7 @@ var _ = Describe("0-RTT", func() {
return proxy, &num0RTTPackets return proxy, &num0RTTPackets
} }
dialAndReceiveSessionTicket := func(ln quic.Listener, proxyPort int) *tls.Config { dialAndReceiveSessionTicket := func(ln quic.EarlyListener, proxyPort int) *tls.Config {
// dial the first session in order to receive a session ticket // dial the first session in order to receive a session ticket
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -68,7 +68,7 @@ var _ = Describe("0-RTT", func() {
return clientConf return clientConf
} }
transfer0RTTData := func(ln quic.Listener, proxyPort int, clientConf *tls.Config, testdata []byte) { transfer0RTTData := func(ln quic.EarlyListener, proxyPort int, clientConf *tls.Config, testdata []byte) {
// now dial the second session, and use 0-RTT to send some data // now dial the second session, and use 0-RTT to send some data
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
@ -98,7 +98,7 @@ var _ = Describe("0-RTT", func() {
} }
It("transfers 0-RTT data", func() { It("transfers 0-RTT data", func() {
ln, err := quic.ListenAddr( ln, err := quic.ListenAddrEarly(
"localhost:0", "localhost:0",
getTLSConfig(), getTLSConfig(),
&quic.Config{ &quic.Config{
@ -122,7 +122,7 @@ var _ = Describe("0-RTT", func() {
// Test that data intended to be sent with 1-RTT protection is not sent in 0-RTT packets. // Test that data intended to be sent with 1-RTT protection is not sent in 0-RTT packets.
It("waits until a session until the handshake is done", func() { It("waits until a session until the handshake is done", func() {
ln, err := quic.ListenAddr( ln, err := quic.ListenAddrEarly(
"localhost:0", "localhost:0",
getTLSConfig(), getTLSConfig(),
&quic.Config{ &quic.Config{
@ -199,7 +199,7 @@ var _ = Describe("0-RTT", func() {
num0RTTDropped uint32 num0RTTDropped uint32
) )
ln, err := quic.ListenAddr( ln, err := quic.ListenAddrEarly(
"localhost:0", "localhost:0",
getTLSConfig(), getTLSConfig(),
&quic.Config{ &quic.Config{
@ -253,7 +253,7 @@ var _ = Describe("0-RTT", func() {
var firstConnID, secondConnID protocol.ConnectionID var firstConnID, secondConnID protocol.ConnectionID
var firstCounter, secondCounter int var firstCounter, secondCounter int
ln, err := quic.ListenAddr( ln, err := quic.ListenAddrEarly(
"localhost:0", "localhost:0",
getTLSConfig(), getTLSConfig(),
&quic.Config{Versions: []protocol.VersionNumber{version}}, &quic.Config{Versions: []protocol.VersionNumber{version}},

View file

@ -72,7 +72,7 @@ type baseServer struct {
sessionHandler packetHandlerManager sessionHandler packetHandlerManager
// set as a member, so they can be set in the tests // set as a member, so they can be set in the tests
newSession func(connection, sessionRunner, protocol.ConnectionID /* original connection ID */, protocol.ConnectionID /* client dest connection ID */, protocol.ConnectionID /* destination connection ID */, protocol.ConnectionID /* source connection ID */, [16]byte, *Config, *tls.Config, *handshake.TokenGenerator, utils.Logger, protocol.VersionNumber) quicSession newSession func(connection, sessionRunner, protocol.ConnectionID /* original connection ID */, protocol.ConnectionID /* client dest connection ID */, protocol.ConnectionID /* destination connection ID */, protocol.ConnectionID /* source connection ID */, [16]byte, *Config, *tls.Config, *handshake.TokenGenerator, bool /* enable 0-RTT */, utils.Logger, protocol.VersionNumber) quicSession
serverError error serverError error
errorChan chan struct{} errorChan chan struct{}
@ -450,6 +450,7 @@ func (s *baseServer) createNewSession(
s.config, s.config,
s.tlsConf, s.tlsConf,
s.tokenGenerator, s.tokenGenerator,
s.acceptEarlySessions,
s.logger, s.logger,
version, version,
) )

View file

@ -331,9 +331,11 @@ var _ = Describe("Server", func() {
_ *Config, _ *Config,
_ *tls.Config, _ *tls.Config,
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
enable0RTT bool,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) quicSession { ) quicSession {
Expect(enable0RTT).To(BeFalse())
Expect(origConnID).To(Equal(hdr.DestConnectionID)) Expect(origConnID).To(Equal(hdr.DestConnectionID))
Expect(destConnID).To(Equal(hdr.SrcConnectionID)) Expect(destConnID).To(Equal(hdr.SrcConnectionID))
// make sure we're using a server-generated connection ID // make sure we're using a server-generated connection ID
@ -381,6 +383,7 @@ var _ = Describe("Server", func() {
_ *Config, _ *Config,
_ *tls.Config, _ *tls.Config,
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
_ bool,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) quicSession { ) quicSession {
@ -409,6 +412,7 @@ var _ = Describe("Server", func() {
_ *Config, _ *Config,
_ *tls.Config, _ *tls.Config,
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
_ bool,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) quicSession { ) quicSession {
@ -469,6 +473,7 @@ var _ = Describe("Server", func() {
_ *Config, _ *Config,
_ *tls.Config, _ *tls.Config,
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
_ bool,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) quicSession { ) quicSession {
@ -572,6 +577,7 @@ var _ = Describe("Server", func() {
_ *Config, _ *Config,
_ *tls.Config, _ *tls.Config,
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
_ bool,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) quicSession { ) quicSession {
@ -624,9 +630,11 @@ var _ = Describe("Server", func() {
_ *Config, _ *Config,
_ *tls.Config, _ *tls.Config,
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
enable0RTT bool,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) quicSession { ) quicSession {
Expect(enable0RTT).To(BeTrue())
sess.EXPECT().run().Do(func() {}) sess.EXPECT().run().Do(func() {})
sess.EXPECT().earlySessionReady().Return(ready) sess.EXPECT().earlySessionReady().Return(ready)
sess.EXPECT().Context().Return(context.Background()) sess.EXPECT().Context().Return(context.Background())
@ -653,6 +661,7 @@ var _ = Describe("Server", func() {
_ *Config, _ *Config,
_ *tls.Config, _ *tls.Config,
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
_ bool,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) quicSession { ) quicSession {
@ -709,6 +718,7 @@ var _ = Describe("Server", func() {
_ *Config, _ *Config,
_ *tls.Config, _ *tls.Config,
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
_ bool,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) quicSession { ) quicSession {

View file

@ -202,6 +202,7 @@ var newSession = func(
conf *Config, conf *Config,
tlsConf *tls.Config, tlsConf *tls.Config,
tokenGenerator *handshake.TokenGenerator, tokenGenerator *handshake.TokenGenerator,
enable0RTT bool,
logger utils.Logger, logger utils.Logger,
v protocol.VersionNumber, v protocol.VersionNumber,
) quicSession { ) quicSession {
@ -274,7 +275,7 @@ var newSession = func(
}, },
}, },
tlsConf, tlsConf,
true, // TODO: make 0-RTT support configurable enable0RTT,
s.rttStats, s.rttStats,
logger, logger,
) )
@ -308,6 +309,7 @@ var newClientSession = func(
tlsConf *tls.Config, tlsConf *tls.Config,
initialPacketNumber protocol.PacketNumber, initialPacketNumber protocol.PacketNumber,
initialVersion protocol.VersionNumber, initialVersion protocol.VersionNumber,
enable0RTT bool,
logger utils.Logger, logger utils.Logger,
v protocol.VersionNumber, v protocol.VersionNumber,
) quicSession { ) quicSession {
@ -371,7 +373,7 @@ var newClientSession = func(
onHandshakeComplete: func() { close(s.handshakeCompleteChan) }, onHandshakeComplete: func() { close(s.handshakeCompleteChan) },
}, },
tlsConf, tlsConf,
true, // TODO: make 0-RTT support configurable enable0RTT,
s.rttStats, s.rttStats,
logger, logger,
) )

View file

@ -122,6 +122,7 @@ var _ = Describe("Session", func() {
populateServerConfig(&Config{}), populateServerConfig(&Config{}),
nil, // tls.Config nil, // tls.Config
tokenGenerator, tokenGenerator,
false,
utils.DefaultLogger, utils.DefaultLogger,
protocol.VersionTLS, protocol.VersionTLS,
).(*session) ).(*session)
@ -1658,6 +1659,7 @@ var _ = Describe("Client Session", func() {
tlsConf, tlsConf,
42, // initial packet number 42, // initial packet number
protocol.VersionTLS, protocol.VersionTLS,
false,
utils.DefaultLogger, utils.DefaultLogger,
protocol.VersionTLS, protocol.VersionTLS,
).(*session) ).(*session)