diff --git a/handshake_client.go b/handshake_client.go index 553d2dd..5025657 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -478,7 +478,9 @@ func (c *Conn) loadSession(hello *clientHelloMsg) ( } if c.quic != nil { - c.quicResumeSession(session) + if c.quic.enableSessionEvents { + c.quicResumeSession(session) + } // For 0-RTT, the cipher suite has to match exactly, and we need to be // offering the same ALPN. diff --git a/handshake_client_tls13.go b/handshake_client_tls13.go index 6744e71..db5e35d 100644 --- a/handshake_client_tls13.go +++ b/handshake_client_tls13.go @@ -900,7 +900,7 @@ func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error { session.ageAdd = msg.ageAdd session.EarlyData = c.quic != nil && msg.maxEarlyData == 0xffffffff // RFC 9001, Section 4.6.1 session.ticket = msg.label - if c.quic != nil && c.quic.enableStoreSessionEvent { + if c.quic != nil && c.quic.enableSessionEvents { c.quicStoreSession(session) return nil } diff --git a/handshake_server_tls13.go b/handshake_server_tls13.go index f24c267..503a732 100644 --- a/handshake_server_tls13.go +++ b/handshake_server_tls13.go @@ -377,7 +377,7 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error { continue } - if c.quic != nil { + if c.quic != nil && c.quic.enableSessionEvents { if err := c.quicResumeSession(sessionState); err != nil { return err } diff --git a/quic.go b/quic.go index 8e722c6..9dd6168 100644 --- a/quic.go +++ b/quic.go @@ -50,12 +50,12 @@ type QUICConn struct { type QUICConfig struct { TLSConfig *Config - // EnableStoreSessionEvent may be set to true to enable the - // [QUICStoreSession] event for client connections. + // EnableSessionEvents may be set to true to enable the + // [QUICStoreSession] and [QUICResumeSession] events for client connections. // When this event is enabled, sessions are not automatically // stored in the client session cache. // The application should use [QUICConn.StoreSession] to store sessions. - EnableStoreSessionEvent bool + EnableSessionEvents bool } // A QUICEventKind is a type of operation on a QUIC connection. @@ -113,7 +113,7 @@ const ( // QUICStoreSession indicates that the server has provided state permitting // the client to resume the session. // [QUICEvent.SessionState] is set. - // The application should use [QUICConn.Store] session to store the [SessionState]. + // The application should use [QUICConn.StoreSession] session to store the [SessionState]. // The application may modify the [SessionState] before storing it. // This event only occurs on client connections. QUICStoreSession @@ -165,7 +165,7 @@ type quicState struct { transportParams []byte // to send to the peer - enableStoreSessionEvent bool + enableSessionEvents bool } // QUICClient returns a new TLS client side connection using QUICTransport as the @@ -186,9 +186,9 @@ func QUICServer(config *QUICConfig) *QUICConn { func newQUICConn(conn *Conn, config *QUICConfig) *QUICConn { conn.quic = &quicState{ - signalc: make(chan struct{}), - blockedc: make(chan struct{}), - enableStoreSessionEvent: config.EnableStoreSessionEvent, + signalc: make(chan struct{}), + blockedc: make(chan struct{}), + enableSessionEvents: config.EnableSessionEvents, } conn.quic.events = conn.quic.eventArr[:0] return &QUICConn{ diff --git a/quic_test.go b/quic_test.go index 5a6f66e..1bb2e55 100644 --- a/quic_test.go +++ b/quic_test.go @@ -24,22 +24,22 @@ type testQUICConn struct { complete bool } -func newTestQUICClient(t *testing.T, config *Config) *testQUICConn { - q := &testQUICConn{t: t} - q.conn = QUICClient(&QUICConfig{ - TLSConfig: config, - }) +func newTestQUICClient(t *testing.T, config *QUICConfig) *testQUICConn { + q := &testQUICConn{ + t: t, + conn: QUICClient(config), + } t.Cleanup(func() { q.conn.Close() }) return q } -func newTestQUICServer(t *testing.T, config *Config) *testQUICConn { - q := &testQUICConn{t: t} - q.conn = QUICServer(&QUICConfig{ - TLSConfig: config, - }) +func newTestQUICServer(t *testing.T, config *QUICConfig) *testQUICConn { + q := &testQUICConn{ + t: t, + conn: QUICServer(config), + } t.Cleanup(func() { q.conn.Close() }) @@ -140,6 +140,11 @@ func runTestQUICConnection(ctx context.Context, cli, srv *testQUICConn, onEvent return err } } + case QUICStoreSession: + if a != cli { + return errors.New("unexpected QUICStoreSession event received by server") + } + a.conn.StoreSession(e.SessionState) case QUICResumeSession: if a.onResumeSession != nil { a.onResumeSession(e.SessionState) @@ -154,8 +159,8 @@ func runTestQUICConnection(ctx context.Context, cli, srv *testQUICConn, onEvent } func TestQUICConnection(t *testing.T) { - config := testConfig.Clone() - config.MinVersion = VersionTLS13 + config := &QUICConfig{TLSConfig: testConfig.Clone()} + config.TLSConfig.MinVersion = VersionTLS13 cli := newTestQUICClient(t, config) cli.conn.SetTransportParameters(nil) @@ -196,13 +201,13 @@ func TestQUICConnection(t *testing.T) { } func TestQUICSessionResumption(t *testing.T) { - clientConfig := testConfig.Clone() - clientConfig.MinVersion = VersionTLS13 - clientConfig.ClientSessionCache = NewLRUClientSessionCache(1) - clientConfig.ServerName = "example.go.dev" + clientConfig := &QUICConfig{TLSConfig: testConfig.Clone()} + clientConfig.TLSConfig.MinVersion = VersionTLS13 + clientConfig.TLSConfig.ClientSessionCache = NewLRUClientSessionCache(1) + clientConfig.TLSConfig.ServerName = "example.go.dev" - serverConfig := testConfig.Clone() - serverConfig.MinVersion = VersionTLS13 + serverConfig := &QUICConfig{TLSConfig: testConfig.Clone()} + serverConfig.TLSConfig.MinVersion = VersionTLS13 cli := newTestQUICClient(t, clientConfig) cli.conn.SetTransportParameters(nil) @@ -228,13 +233,13 @@ func TestQUICSessionResumption(t *testing.T) { } func TestQUICFragmentaryData(t *testing.T) { - clientConfig := testConfig.Clone() - clientConfig.MinVersion = VersionTLS13 - clientConfig.ClientSessionCache = NewLRUClientSessionCache(1) - clientConfig.ServerName = "example.go.dev" + clientConfig := &QUICConfig{TLSConfig: testConfig.Clone()} + clientConfig.TLSConfig.MinVersion = VersionTLS13 + clientConfig.TLSConfig.ClientSessionCache = NewLRUClientSessionCache(1) + clientConfig.TLSConfig.ServerName = "example.go.dev" - serverConfig := testConfig.Clone() - serverConfig.MinVersion = VersionTLS13 + serverConfig := &QUICConfig{TLSConfig: testConfig.Clone()} + serverConfig.TLSConfig.MinVersion = VersionTLS13 cli := newTestQUICClient(t, clientConfig) cli.conn.SetTransportParameters(nil) @@ -260,8 +265,8 @@ func TestQUICFragmentaryData(t *testing.T) { func TestQUICPostHandshakeClientAuthentication(t *testing.T) { // RFC 9001, Section 4.4. - config := testConfig.Clone() - config.MinVersion = VersionTLS13 + config := &QUICConfig{TLSConfig: testConfig.Clone()} + config.TLSConfig.MinVersion = VersionTLS13 cli := newTestQUICClient(t, config) cli.conn.SetTransportParameters(nil) srv := newTestQUICServer(t, config) @@ -288,8 +293,8 @@ func TestQUICPostHandshakeClientAuthentication(t *testing.T) { func TestQUICPostHandshakeKeyUpdate(t *testing.T) { // RFC 9001, Section 6. - config := testConfig.Clone() - config.MinVersion = VersionTLS13 + config := &QUICConfig{TLSConfig: testConfig.Clone()} + config.TLSConfig.MinVersion = VersionTLS13 cli := newTestQUICClient(t, config) cli.conn.SetTransportParameters(nil) srv := newTestQUICServer(t, config) @@ -312,8 +317,8 @@ func TestQUICPostHandshakeKeyUpdate(t *testing.T) { } func TestQUICPostHandshakeMessageTooLarge(t *testing.T) { - config := testConfig.Clone() - config.MinVersion = VersionTLS13 + config := &QUICConfig{TLSConfig: testConfig.Clone()} + config.TLSConfig.MinVersion = VersionTLS13 cli := newTestQUICClient(t, config) cli.conn.SetTransportParameters(nil) srv := newTestQUICServer(t, config) @@ -334,13 +339,13 @@ func TestQUICPostHandshakeMessageTooLarge(t *testing.T) { } func TestQUICHandshakeError(t *testing.T) { - clientConfig := testConfig.Clone() - clientConfig.MinVersion = VersionTLS13 - clientConfig.InsecureSkipVerify = false - clientConfig.ServerName = "name" + clientConfig := &QUICConfig{TLSConfig: testConfig.Clone()} + clientConfig.TLSConfig.MinVersion = VersionTLS13 + clientConfig.TLSConfig.InsecureSkipVerify = false + clientConfig.TLSConfig.ServerName = "name" - serverConfig := testConfig.Clone() - serverConfig.MinVersion = VersionTLS13 + serverConfig := &QUICConfig{TLSConfig: testConfig.Clone()} + serverConfig.TLSConfig.MinVersion = VersionTLS13 cli := newTestQUICClient(t, clientConfig) cli.conn.SetTransportParameters(nil) @@ -360,9 +365,9 @@ func TestQUICHandshakeError(t *testing.T) { // and that it reports the application protocol as soon as it has been // negotiated. func TestQUICConnectionState(t *testing.T) { - config := testConfig.Clone() - config.MinVersion = VersionTLS13 - config.NextProtos = []string{"h3"} + config := &QUICConfig{TLSConfig: testConfig.Clone()} + config.TLSConfig.MinVersion = VersionTLS13 + config.TLSConfig.NextProtos = []string{"h3"} cli := newTestQUICClient(t, config) cli.conn.SetTransportParameters(nil) srv := newTestQUICServer(t, config) @@ -391,10 +396,10 @@ func TestQUICStartContextPropagation(t *testing.T) { const key = "key" const value = "value" ctx := context.WithValue(context.Background(), key, value) - config := testConfig.Clone() - config.MinVersion = VersionTLS13 + config := &QUICConfig{TLSConfig: testConfig.Clone()} + config.TLSConfig.MinVersion = VersionTLS13 calls := 0 - config.GetConfigForClient = func(info *ClientHelloInfo) (*Config, error) { + config.TLSConfig.GetConfigForClient = func(info *ClientHelloInfo) (*Config, error) { calls++ got, _ := info.Context().Value(key).(string) if got != value { @@ -415,13 +420,13 @@ func TestQUICStartContextPropagation(t *testing.T) { } func TestQUICDelayedTransportParameters(t *testing.T) { - clientConfig := testConfig.Clone() - clientConfig.MinVersion = VersionTLS13 - clientConfig.ClientSessionCache = NewLRUClientSessionCache(1) - clientConfig.ServerName = "example.go.dev" + clientConfig := &QUICConfig{TLSConfig: testConfig.Clone()} + clientConfig.TLSConfig.MinVersion = VersionTLS13 + clientConfig.TLSConfig.ClientSessionCache = NewLRUClientSessionCache(1) + clientConfig.TLSConfig.ServerName = "example.go.dev" - serverConfig := testConfig.Clone() - serverConfig.MinVersion = VersionTLS13 + serverConfig := &QUICConfig{TLSConfig: testConfig.Clone()} + serverConfig.TLSConfig.MinVersion = VersionTLS13 cliParams := "client params" srvParams := "server params" @@ -449,8 +454,8 @@ func TestQUICDelayedTransportParameters(t *testing.T) { } func TestQUICEmptyTransportParameters(t *testing.T) { - config := testConfig.Clone() - config.MinVersion = VersionTLS13 + config := &QUICConfig{TLSConfig: testConfig.Clone()} + config.TLSConfig.MinVersion = VersionTLS13 cli := newTestQUICClient(t, config) cli.conn.SetTransportParameters(nil) @@ -475,8 +480,8 @@ func TestQUICEmptyTransportParameters(t *testing.T) { } func TestQUICCanceledWaitingForData(t *testing.T) { - config := testConfig.Clone() - config.MinVersion = VersionTLS13 + config := &QUICConfig{TLSConfig: testConfig.Clone()} + config.TLSConfig.MinVersion = VersionTLS13 cli := newTestQUICClient(t, config) cli.conn.SetTransportParameters(nil) cli.conn.Start(context.Background()) @@ -489,8 +494,8 @@ func TestQUICCanceledWaitingForData(t *testing.T) { } func TestQUICCanceledWaitingForTransportParams(t *testing.T) { - config := testConfig.Clone() - config.MinVersion = VersionTLS13 + config := &QUICConfig{TLSConfig: testConfig.Clone()} + config.TLSConfig.MinVersion = VersionTLS13 cli := newTestQUICClient(t, config) cli.conn.Start(context.Background()) for cli.conn.NextEvent().Kind != QUICTransportParametersRequired { @@ -502,15 +507,15 @@ func TestQUICCanceledWaitingForTransportParams(t *testing.T) { } func TestQUICEarlyData(t *testing.T) { - clientConfig := testConfig.Clone() - clientConfig.MinVersion = VersionTLS13 - clientConfig.ClientSessionCache = NewLRUClientSessionCache(1) - clientConfig.ServerName = "example.go.dev" - clientConfig.NextProtos = []string{"h3"} + clientConfig := &QUICConfig{TLSConfig: testConfig.Clone()} + clientConfig.TLSConfig.MinVersion = VersionTLS13 + clientConfig.TLSConfig.ClientSessionCache = NewLRUClientSessionCache(1) + clientConfig.TLSConfig.ServerName = "example.go.dev" + clientConfig.TLSConfig.NextProtos = []string{"h3"} - serverConfig := testConfig.Clone() - serverConfig.MinVersion = VersionTLS13 - serverConfig.NextProtos = []string{"h3"} + serverConfig := &QUICConfig{TLSConfig: testConfig.Clone()} + serverConfig.TLSConfig.MinVersion = VersionTLS13 + serverConfig.TLSConfig.NextProtos = []string{"h3"} cli := newTestQUICClient(t, clientConfig) cli.conn.SetTransportParameters(nil) @@ -528,7 +533,14 @@ func TestQUICEarlyData(t *testing.T) { cli2.conn.SetTransportParameters(nil) srv2 := newTestQUICServer(t, serverConfig) srv2.conn.SetTransportParameters(nil) - if err := runTestQUICConnection(context.Background(), cli2, srv2, nil); err != nil { + onEvent := func(e QUICEvent, src, dst *testQUICConn) bool { + switch e.Kind { + case QUICStoreSession, QUICResumeSession: + t.Errorf("with EnableSessionEvents=false, got unexpected event %v", e.Kind) + } + return false + } + if err := runTestQUICConnection(context.Background(), cli2, srv2, onEvent); err != nil { t.Fatalf("error during second connection handshake: %v", err) } if !cli2.conn.ConnectionState().DidResume { @@ -557,15 +569,17 @@ func TestQUICEarlyDataDeclined(t *testing.T) { } func testQUICEarlyDataDeclined(t *testing.T, server bool) { - clientConfig := testConfig.Clone() - clientConfig.MinVersion = VersionTLS13 - clientConfig.ClientSessionCache = NewLRUClientSessionCache(1) - clientConfig.ServerName = "example.go.dev" - clientConfig.NextProtos = []string{"h3"} + clientConfig := &QUICConfig{TLSConfig: testConfig.Clone()} + clientConfig.EnableSessionEvents = true + clientConfig.TLSConfig.MinVersion = VersionTLS13 + clientConfig.TLSConfig.ClientSessionCache = NewLRUClientSessionCache(1) + clientConfig.TLSConfig.ServerName = "example.go.dev" + clientConfig.TLSConfig.NextProtos = []string{"h3"} - serverConfig := testConfig.Clone() - serverConfig.MinVersion = VersionTLS13 - serverConfig.NextProtos = []string{"h3"} + serverConfig := &QUICConfig{TLSConfig: testConfig.Clone()} + serverConfig.EnableSessionEvents = true + serverConfig.TLSConfig.MinVersion = VersionTLS13 + serverConfig.TLSConfig.NextProtos = []string{"h3"} cli := newTestQUICClient(t, clientConfig) cli.conn.SetTransportParameters(nil)