mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
handshake: use new crypto/tls 0-RTT API (#4953)
* handshake: simplify method signature of cryptoSetup.handleEvent * use the new crypto/tls 0-RTT API
This commit is contained in:
parent
b32f1fa0e4
commit
bf28da8346
10 changed files with 182 additions and 380 deletions
|
@ -12,7 +12,6 @@ import (
|
|||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
|
@ -89,12 +88,13 @@ func NewCryptoSetupClient(
|
|||
|
||||
tlsConf = tlsConf.Clone()
|
||||
tlsConf.MinVersion = tls.VersionTLS13
|
||||
quicConf := &tls.QUICConfig{TLSConfig: tlsConf}
|
||||
qtls.SetupConfigForClient(quicConf, cs.marshalDataForSessionState, cs.handleDataFromSessionState)
|
||||
cs.tlsConf = tlsConf
|
||||
cs.allow0RTT = enable0RTT
|
||||
|
||||
cs.conn = tls.QUICClient(quicConf)
|
||||
cs.conn = tls.QUICClient(&tls.QUICConfig{
|
||||
TLSConfig: tlsConf,
|
||||
EnableSessionEvents: true,
|
||||
})
|
||||
cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient))
|
||||
|
||||
return cs
|
||||
|
@ -123,9 +123,13 @@ func NewCryptoSetupServer(
|
|||
)
|
||||
cs.allow0RTT = allow0RTT
|
||||
|
||||
tlsConf = qtls.SetupConfigForServer(tlsConf, localAddr, remoteAddr, cs.getDataForSessionTicket, cs.handleSessionTicket)
|
||||
tlsConf = setupConfigForServer(tlsConf, localAddr, remoteAddr)
|
||||
|
||||
cs.tlsConf = tlsConf
|
||||
cs.conn = tls.QUICServer(&tls.QUICConfig{TLSConfig: tlsConf})
|
||||
cs.conn = tls.QUICServer(&tls.QUICConfig{
|
||||
TLSConfig: tlsConf,
|
||||
EnableSessionEvents: true,
|
||||
})
|
||||
return cs
|
||||
}
|
||||
|
||||
|
@ -178,11 +182,10 @@ func (h *cryptoSetup) StartHandshake(ctx context.Context) error {
|
|||
}
|
||||
for {
|
||||
ev := h.conn.NextEvent()
|
||||
done, err := h.handleEvent(ev)
|
||||
if err != nil {
|
||||
if err := h.handleEvent(ev); err != nil {
|
||||
return wrapError(err)
|
||||
}
|
||||
if done {
|
||||
if ev.Kind == tls.QUICNoEvent {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -213,53 +216,78 @@ func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLev
|
|||
}
|
||||
|
||||
func (h *cryptoSetup) handleMessage(data []byte, encLevel protocol.EncryptionLevel) error {
|
||||
if err := h.conn.HandleData(qtls.ToTLSEncryptionLevel(encLevel), data); err != nil {
|
||||
if err := h.conn.HandleData(encLevel.ToTLSEncryptionLevel(), data); err != nil {
|
||||
return err
|
||||
}
|
||||
for {
|
||||
ev := h.conn.NextEvent()
|
||||
done, err := h.handleEvent(ev)
|
||||
if err != nil {
|
||||
if err := h.handleEvent(ev); err != nil {
|
||||
return err
|
||||
}
|
||||
if done {
|
||||
if ev.Kind == tls.QUICNoEvent {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handleEvent(ev tls.QUICEvent) (done bool, err error) {
|
||||
//nolint:exhaustive
|
||||
// Go 1.23 added new 0-RTT events, see https://github.com/quic-go/quic-go/issues/4272.
|
||||
// We will start using these events when dropping support for Go 1.22.
|
||||
func (h *cryptoSetup) handleEvent(ev tls.QUICEvent) (err error) {
|
||||
switch ev.Kind {
|
||||
case tls.QUICNoEvent:
|
||||
return true, nil
|
||||
return nil
|
||||
case tls.QUICSetReadSecret:
|
||||
h.setReadKey(ev.Level, ev.Suite, ev.Data)
|
||||
return false, nil
|
||||
return nil
|
||||
case tls.QUICSetWriteSecret:
|
||||
h.setWriteKey(ev.Level, ev.Suite, ev.Data)
|
||||
return false, nil
|
||||
return nil
|
||||
case tls.QUICTransportParameters:
|
||||
return false, h.handleTransportParameters(ev.Data)
|
||||
return h.handleTransportParameters(ev.Data)
|
||||
case tls.QUICTransportParametersRequired:
|
||||
h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective))
|
||||
return false, nil
|
||||
return nil
|
||||
case tls.QUICRejectedEarlyData:
|
||||
h.rejected0RTT()
|
||||
return false, nil
|
||||
return nil
|
||||
case tls.QUICWriteData:
|
||||
h.writeRecord(ev.Level, ev.Data)
|
||||
return false, nil
|
||||
return nil
|
||||
case tls.QUICHandshakeDone:
|
||||
h.handshakeComplete()
|
||||
return false, nil
|
||||
return nil
|
||||
case tls.QUICStoreSession:
|
||||
if h.perspective == protocol.PerspectiveServer {
|
||||
panic("cryptoSetup BUG: unexpected QUICStoreSession event for the server")
|
||||
}
|
||||
ev.SessionState.Extra = append(
|
||||
ev.SessionState.Extra,
|
||||
addSessionStateExtraPrefix(h.marshalDataForSessionState(ev.SessionState.EarlyData)),
|
||||
)
|
||||
return h.conn.StoreSession(ev.SessionState)
|
||||
case tls.QUICResumeSession:
|
||||
var allowEarlyData bool
|
||||
switch h.perspective {
|
||||
case protocol.PerspectiveClient:
|
||||
// for clients, this event occurs when a session ticket is selected
|
||||
allowEarlyData = h.handleDataFromSessionState(
|
||||
findSessionStateExtraData(ev.SessionState.Extra),
|
||||
ev.SessionState.EarlyData,
|
||||
)
|
||||
case protocol.PerspectiveServer:
|
||||
// for servers, this event occurs when receiving the client's session ticket
|
||||
allowEarlyData = h.handleSessionTicket(
|
||||
findSessionStateExtraData(ev.SessionState.Extra),
|
||||
ev.SessionState.EarlyData,
|
||||
)
|
||||
}
|
||||
if ev.SessionState.EarlyData {
|
||||
ev.SessionState.EarlyData = allowEarlyData
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
// Unknown events should be ignored.
|
||||
// crypto/tls will ensure that this is safe to do.
|
||||
// See the discussion following https://github.com/golang/go/issues/68124#issuecomment-2187042510 for details.
|
||||
return false, nil
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -350,7 +378,10 @@ func (h *cryptoSetup) getDataForSessionTicket() []byte {
|
|||
// Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection.
|
||||
// It is only valid for the server.
|
||||
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
|
||||
if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{EarlyData: h.allow0RTT}); err != nil {
|
||||
if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{
|
||||
EarlyData: h.allow0RTT,
|
||||
Extra: [][]byte{addSessionStateExtraPrefix(h.getDataForSessionTicket())},
|
||||
}); err != nil {
|
||||
// Session tickets might be disabled by tls.Config.SessionTicketsDisabled.
|
||||
// We can't check h.tlsConfig here, since the actual config might have been obtained from
|
||||
// the GetConfigForClient callback.
|
||||
|
@ -376,9 +407,9 @@ func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
|
|||
// It reads parameters from the session ticket and checks whether to accept 0-RTT if the session ticket enabled 0-RTT.
|
||||
// Note that the fact that the session ticket allows 0-RTT doesn't mean that the actual TLS handshake enables 0-RTT:
|
||||
// A client may use a 0-RTT enabled session to resume a TLS session without using 0-RTT.
|
||||
func (h *cryptoSetup) handleSessionTicket(sessionTicketData []byte, using0RTT bool) bool {
|
||||
func (h *cryptoSetup) handleSessionTicket(data []byte, using0RTT bool) (allowEarlyData bool) {
|
||||
var t sessionTicket
|
||||
if err := t.Unmarshal(sessionTicketData, using0RTT); err != nil {
|
||||
if err := t.Unmarshal(data, using0RTT); err != nil {
|
||||
h.logger.Debugf("Unmarshalling session ticket failed: %s", err.Error())
|
||||
return false
|
||||
}
|
||||
|
@ -446,7 +477,7 @@ func (h *cryptoSetup) setReadKey(el tls.QUICEncryptionLevel, suiteID uint16, tra
|
|||
}
|
||||
h.events = append(h.events, Event{Kind: EventReceivedReadKeys})
|
||||
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
|
||||
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite())
|
||||
h.tracer.UpdatedKeyFromTLS(protocol.FromTLSEncryptionLevel(el), h.perspective.Opposite())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -497,7 +528,7 @@ func (h *cryptoSetup) setWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, tr
|
|||
panic("unexpected write encryption level")
|
||||
}
|
||||
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
|
||||
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective)
|
||||
h.tracer.UpdatedKeyFromTLS(protocol.FromTLSEncryptionLevel(el), h.perspective)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package qtls
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"net"
|
|
@ -1,6 +1,7 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
@ -52,3 +53,20 @@ func (t *sessionTicket) Unmarshal(b []byte, using0RTT bool) error {
|
|||
t.RTT = time.Duration(rtt) * time.Microsecond
|
||||
return nil
|
||||
}
|
||||
|
||||
const extraPrefix = "quic-go1"
|
||||
|
||||
func addSessionStateExtraPrefix(b []byte) []byte {
|
||||
return append([]byte(extraPrefix), b...)
|
||||
}
|
||||
|
||||
func findSessionStateExtraData(extras [][]byte) []byte {
|
||||
prefix := []byte(extraPrefix)
|
||||
for _, extra := range extras {
|
||||
if len(extra) < len(prefix) || !bytes.Equal(prefix, extra[:len(prefix)]) {
|
||||
continue
|
||||
}
|
||||
return extra[len(prefix):]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
39
internal/handshake/tls_config.go
Normal file
39
internal/handshake/tls_config.go
Normal file
|
@ -0,0 +1,39 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
)
|
||||
|
||||
func setupConfigForServer(conf *tls.Config, localAddr, remoteAddr net.Addr) *tls.Config {
|
||||
// Workaround for https://github.com/golang/go/issues/60506.
|
||||
// This initializes the session tickets _before_ cloning the config.
|
||||
_, _ = conf.DecryptTicket(nil, tls.ConnectionState{})
|
||||
|
||||
conf = conf.Clone()
|
||||
conf.MinVersion = tls.VersionTLS13
|
||||
|
||||
// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
|
||||
// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn
|
||||
// that allows the caller to get the local and the remote address.
|
||||
if conf.GetConfigForClient != nil {
|
||||
gcfc := conf.GetConfigForClient
|
||||
conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
|
||||
c, err := gcfc(info)
|
||||
if c != nil {
|
||||
// we're returning a tls.Config here, so we need to apply this recursively
|
||||
c = setupConfigForServer(c, localAddr, remoteAddr)
|
||||
}
|
||||
return c, err
|
||||
}
|
||||
}
|
||||
if conf.GetCertificate != nil {
|
||||
gc := conf.GetCertificate
|
||||
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
|
||||
return gc(info)
|
||||
}
|
||||
}
|
||||
return conf
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package qtls
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
|
@ -6,52 +6,15 @@ import (
|
|||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEncryptionLevelConversion(t *testing.T) {
|
||||
testCases := []struct {
|
||||
quicLevel protocol.EncryptionLevel
|
||||
tlsLevel tls.QUICEncryptionLevel
|
||||
}{
|
||||
{protocol.EncryptionInitial, tls.QUICEncryptionLevelInitial},
|
||||
{protocol.EncryptionHandshake, tls.QUICEncryptionLevelHandshake},
|
||||
{protocol.Encryption1RTT, tls.QUICEncryptionLevelApplication},
|
||||
{protocol.Encryption0RTT, tls.QUICEncryptionLevelEarly},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.quicLevel.String(), func(t *testing.T) {
|
||||
// conversion from QUIC to TLS encryption level
|
||||
require.Equal(t, tc.tlsLevel, ToTLSEncryptionLevel(tc.quicLevel))
|
||||
// conversion from TLS to QUIC encryption level
|
||||
require.Equal(t, tc.quicLevel, FromTLSEncryptionLevel(tc.tlsLevel))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupSessionCache(t *testing.T) {
|
||||
// Test with a session cache present
|
||||
csc := tls.NewLRUClientSessionCache(1)
|
||||
confWithCache := &tls.QUICConfig{TLSConfig: &tls.Config{ClientSessionCache: csc}}
|
||||
SetupConfigForClient(confWithCache, nil, nil)
|
||||
require.NotNil(t, confWithCache.TLSConfig.ClientSessionCache)
|
||||
require.NotEqual(t, csc, confWithCache.TLSConfig.ClientSessionCache)
|
||||
|
||||
// Test without a session cache
|
||||
confWithoutCache := &tls.QUICConfig{TLSConfig: &tls.Config{}}
|
||||
SetupConfigForClient(confWithoutCache, nil, nil)
|
||||
require.Nil(t, confWithoutCache.TLSConfig.ClientSessionCache)
|
||||
}
|
||||
|
||||
func TestMinimumTLSVersion(t *testing.T) {
|
||||
local := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}
|
||||
remote := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||
|
||||
orig := &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
conf := SetupConfigForServer(orig, local, remote, nil, nil)
|
||||
conf := setupConfigForServer(orig, local, remote)
|
||||
require.EqualValues(t, tls.VersionTLS13, conf.MinVersion)
|
||||
// check that the original config wasn't modified
|
||||
require.EqualValues(t, tls.VersionTLS12, orig.MinVersion)
|
||||
|
@ -69,7 +32,7 @@ func TestServerConfigGetCertificate(t *testing.T) {
|
|||
return &tls.Certificate{}, nil
|
||||
},
|
||||
}
|
||||
conf := SetupConfigForServer(tlsConf, local, remote, nil, nil)
|
||||
conf := setupConfigForServer(tlsConf, local, remote)
|
||||
_, err := conf.GetCertificate(&tls.ClientHelloInfo{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, local, localAddr)
|
||||
|
@ -81,7 +44,7 @@ func TestServerConfigGetConfigForClient(t *testing.T) {
|
|||
remote := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||
|
||||
var localAddr, remoteAddr net.Addr
|
||||
tlsConf := SetupConfigForServer(
|
||||
tlsConf := setupConfigForServer(
|
||||
&tls.Config{
|
||||
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
localAddr = info.Conn.LocalAddr()
|
||||
|
@ -91,8 +54,6 @@ func TestServerConfigGetConfigForClient(t *testing.T) {
|
|||
},
|
||||
local,
|
||||
remote,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
|
||||
require.NoError(t, err)
|
||||
|
@ -121,7 +82,7 @@ func TestServerConfigGetConfigForClientRecursively(t *testing.T) {
|
|||
innerConf.GetCertificate = getCert
|
||||
return innerConf, nil
|
||||
}
|
||||
tlsConf = SetupConfigForServer(tlsConf, local, remote, nil, nil)
|
||||
tlsConf = setupConfigForServer(tlsConf, local, remote)
|
||||
conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conf)
|
|
@ -1,5 +1,10 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// EncryptionLevel is the encryption level
|
||||
// Default value is Unencrypted
|
||||
type EncryptionLevel uint8
|
||||
|
@ -28,3 +33,33 @@ func (e EncryptionLevel) String() string {
|
|||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func (e EncryptionLevel) ToTLSEncryptionLevel() tls.QUICEncryptionLevel {
|
||||
switch e {
|
||||
case EncryptionInitial:
|
||||
return tls.QUICEncryptionLevelInitial
|
||||
case EncryptionHandshake:
|
||||
return tls.QUICEncryptionLevelHandshake
|
||||
case Encryption1RTT:
|
||||
return tls.QUICEncryptionLevelApplication
|
||||
case Encryption0RTT:
|
||||
return tls.QUICEncryptionLevelEarly
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected encryption level: %s", e))
|
||||
}
|
||||
}
|
||||
|
||||
func FromTLSEncryptionLevel(e tls.QUICEncryptionLevel) EncryptionLevel {
|
||||
switch e {
|
||||
case tls.QUICEncryptionLevelInitial:
|
||||
return EncryptionInitial
|
||||
case tls.QUICEncryptionLevelHandshake:
|
||||
return EncryptionHandshake
|
||||
case tls.QUICEncryptionLevelApplication:
|
||||
return Encryption1RTT
|
||||
case tls.QUICEncryptionLevelEarly:
|
||||
return Encryption0RTT
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpect encryption level: %s", e))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -10,6 +11,27 @@ func TestEncryptionLevelNonZeroValue(t *testing.T) {
|
|||
require.NotZero(t, EncryptionInitial*EncryptionHandshake*Encryption0RTT*Encryption1RTT)
|
||||
}
|
||||
|
||||
func TestEncryptionLevelConversion(t *testing.T) {
|
||||
testCases := []struct {
|
||||
quicLevel EncryptionLevel
|
||||
tlsLevel tls.QUICEncryptionLevel
|
||||
}{
|
||||
{EncryptionInitial, tls.QUICEncryptionLevelInitial},
|
||||
{EncryptionHandshake, tls.QUICEncryptionLevelHandshake},
|
||||
{Encryption1RTT, tls.QUICEncryptionLevelApplication},
|
||||
{Encryption0RTT, tls.QUICEncryptionLevelEarly},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.quicLevel.String(), func(t *testing.T) {
|
||||
// conversion from QUIC to TLS encryption level
|
||||
require.Equal(t, tc.tlsLevel, tc.quicLevel.ToTLSEncryptionLevel())
|
||||
// conversion from TLS to QUIC encryption level
|
||||
require.Equal(t, tc.quicLevel, FromTLSEncryptionLevel(tc.tlsLevel))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptionLevelStringRepresentation(t *testing.T) {
|
||||
require.Equal(t, "Initial", EncryptionInitial.String())
|
||||
require.Equal(t, "Handshake", EncryptionHandshake.String())
|
||||
|
|
|
@ -1,70 +0,0 @@
|
|||
package qtls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type clientSessionCache struct {
|
||||
mx sync.Mutex
|
||||
getData func(earlyData bool) []byte
|
||||
setData func(data []byte, earlyData bool) (allowEarlyData bool)
|
||||
wrapped tls.ClientSessionCache
|
||||
}
|
||||
|
||||
var _ tls.ClientSessionCache = &clientSessionCache{}
|
||||
|
||||
func (c *clientSessionCache) Put(key string, cs *tls.ClientSessionState) {
|
||||
c.mx.Lock()
|
||||
defer c.mx.Unlock()
|
||||
|
||||
if cs == nil {
|
||||
c.wrapped.Put(key, nil)
|
||||
return
|
||||
}
|
||||
ticket, state, err := cs.ResumptionState()
|
||||
if err != nil || state == nil {
|
||||
c.wrapped.Put(key, cs)
|
||||
return
|
||||
}
|
||||
state.Extra = append(state.Extra, addExtraPrefix(c.getData(state.EarlyData)))
|
||||
newCS, err := tls.NewResumptionState(ticket, state)
|
||||
if err != nil {
|
||||
// It's not clear why this would error. Just save the original state.
|
||||
c.wrapped.Put(key, cs)
|
||||
return
|
||||
}
|
||||
c.wrapped.Put(key, newCS)
|
||||
}
|
||||
|
||||
func (c *clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) {
|
||||
c.mx.Lock()
|
||||
defer c.mx.Unlock()
|
||||
|
||||
cs, ok := c.wrapped.Get(key)
|
||||
if !ok || cs == nil {
|
||||
return cs, ok
|
||||
}
|
||||
ticket, state, err := cs.ResumptionState()
|
||||
if err != nil {
|
||||
// It's not clear why this would error.
|
||||
// Remove the ticket from the session cache, so we don't run into this error over and over again
|
||||
c.wrapped.Put(key, nil)
|
||||
return nil, false
|
||||
}
|
||||
// restore QUIC transport parameters and RTT stored in state.Extra
|
||||
if extra := findExtraData(state.Extra); extra != nil {
|
||||
earlyData := c.setData(extra, state.EarlyData)
|
||||
if state.EarlyData {
|
||||
state.EarlyData = earlyData
|
||||
}
|
||||
}
|
||||
session, err := tls.NewResumptionState(ticket, state)
|
||||
if err != nil {
|
||||
// It's not clear why this would error.
|
||||
// Remove the ticket from the session cache, so we don't run into this error over and over again
|
||||
c.wrapped.Put(key, nil)
|
||||
return nil, false
|
||||
}
|
||||
return session, true
|
||||
}
|
|
@ -1,84 +0,0 @@
|
|||
package qtls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/testdata"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClientSessionCacheAddAndRestoreData(t *testing.T) {
|
||||
ln, err := tls.Listen("tcp4", "localhost:0", testdata.GetTLSConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = conn.Read(make([]byte, 10))
|
||||
require.NoError(t, err)
|
||||
_, err = conn.Write([]byte("foobar"))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}()
|
||||
|
||||
restored := make(chan []byte, 1)
|
||||
clientConf := &tls.Config{
|
||||
RootCAs: testdata.GetRootCA(),
|
||||
ClientSessionCache: &clientSessionCache{
|
||||
wrapped: tls.NewLRUClientSessionCache(10),
|
||||
getData: func(bool) []byte { return []byte("session") },
|
||||
setData: func(data []byte, earlyData bool) bool {
|
||||
require.False(t, earlyData) // running on top of TCP, we can only test non-0-RTT here
|
||||
restored <- data
|
||||
return true
|
||||
},
|
||||
},
|
||||
}
|
||||
conn, err := tls.Dial(
|
||||
"tcp4",
|
||||
fmt.Sprintf("localhost:%d", ln.Addr().(*net.TCPAddr).Port),
|
||||
clientConf,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
_, err = conn.Write([]byte("foobar"))
|
||||
require.NoError(t, err)
|
||||
require.False(t, conn.ConnectionState().DidResume)
|
||||
require.Len(t, restored, 0)
|
||||
_, err = conn.Read(make([]byte, 10))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, conn.Close())
|
||||
|
||||
// make sure the cache can deal with nonsensical inputs
|
||||
clientConf.ClientSessionCache.Put("foo", nil)
|
||||
clientConf.ClientSessionCache.Put("bar", &tls.ClientSessionState{})
|
||||
|
||||
conn, err = tls.Dial(
|
||||
"tcp4",
|
||||
fmt.Sprintf("localhost:%d", ln.Addr().(*net.TCPAddr).Port),
|
||||
clientConf,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
_, err = conn.Write([]byte("foobar"))
|
||||
require.NoError(t, err)
|
||||
require.True(t, conn.ConnectionState().DidResume)
|
||||
var restoredData []byte
|
||||
select {
|
||||
case restoredData = <-restored:
|
||||
default:
|
||||
t.Fatal("no data restored")
|
||||
}
|
||||
require.Equal(t, []byte("session"), restoredData)
|
||||
require.NoError(t, conn.Close())
|
||||
|
||||
require.NoError(t, ln.Close())
|
||||
<-done
|
||||
}
|
|
@ -1,150 +0,0 @@
|
|||
package qtls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
func SetupConfigForServer(
|
||||
conf *tls.Config,
|
||||
localAddr, remoteAddr net.Addr,
|
||||
getData func() []byte,
|
||||
handleSessionTicket func([]byte, bool) bool,
|
||||
) *tls.Config {
|
||||
// Workaround for https://github.com/golang/go/issues/60506.
|
||||
// This initializes the session tickets _before_ cloning the config.
|
||||
_, _ = conf.DecryptTicket(nil, tls.ConnectionState{})
|
||||
|
||||
conf = conf.Clone()
|
||||
conf.MinVersion = tls.VersionTLS13
|
||||
|
||||
// add callbacks to save transport parameters into the session ticket
|
||||
origWrapSession := conf.WrapSession
|
||||
conf.WrapSession = func(cs tls.ConnectionState, state *tls.SessionState) ([]byte, error) {
|
||||
// Add QUIC session ticket
|
||||
state.Extra = append(state.Extra, addExtraPrefix(getData()))
|
||||
|
||||
if origWrapSession != nil {
|
||||
return origWrapSession(cs, state)
|
||||
}
|
||||
b, err := conf.EncryptTicket(cs, state)
|
||||
return b, err
|
||||
}
|
||||
origUnwrapSession := conf.UnwrapSession
|
||||
// UnwrapSession might be called multiple times, as the client can use multiple session tickets.
|
||||
// However, using 0-RTT is only possible with the first session ticket.
|
||||
// crypto/tls guarantees that this callback is called in the same order as the session ticket in the ClientHello.
|
||||
var unwrapCount int
|
||||
conf.UnwrapSession = func(identity []byte, connState tls.ConnectionState) (*tls.SessionState, error) {
|
||||
unwrapCount++
|
||||
var state *tls.SessionState
|
||||
var err error
|
||||
if origUnwrapSession != nil {
|
||||
state, err = origUnwrapSession(identity, connState)
|
||||
} else {
|
||||
state, err = conf.DecryptTicket(identity, connState)
|
||||
}
|
||||
if err != nil || state == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
extra := findExtraData(state.Extra)
|
||||
if extra != nil {
|
||||
state.EarlyData = handleSessionTicket(extra, state.EarlyData && unwrapCount == 1)
|
||||
} else {
|
||||
state.EarlyData = false
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
|
||||
// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn
|
||||
// that allows the caller to get the local and the remote address.
|
||||
if conf.GetConfigForClient != nil {
|
||||
gcfc := conf.GetConfigForClient
|
||||
conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
|
||||
c, err := gcfc(info)
|
||||
if c != nil {
|
||||
// We're returning a tls.Config here, so we need to apply this recursively.
|
||||
c = SetupConfigForServer(c, localAddr, remoteAddr, getData, handleSessionTicket)
|
||||
}
|
||||
return c, err
|
||||
}
|
||||
}
|
||||
if conf.GetCertificate != nil {
|
||||
gc := conf.GetCertificate
|
||||
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
|
||||
return gc(info)
|
||||
}
|
||||
}
|
||||
return conf
|
||||
}
|
||||
|
||||
func SetupConfigForClient(
|
||||
qconf *tls.QUICConfig,
|
||||
getData func(earlyData bool) []byte,
|
||||
setData func(data []byte, earlyData bool) (allowEarlyData bool),
|
||||
) {
|
||||
conf := qconf.TLSConfig
|
||||
if conf.ClientSessionCache != nil {
|
||||
origCache := conf.ClientSessionCache
|
||||
conf.ClientSessionCache = &clientSessionCache{
|
||||
wrapped: origCache,
|
||||
getData: getData,
|
||||
setData: setData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ToTLSEncryptionLevel(e protocol.EncryptionLevel) tls.QUICEncryptionLevel {
|
||||
switch e {
|
||||
case protocol.EncryptionInitial:
|
||||
return tls.QUICEncryptionLevelInitial
|
||||
case protocol.EncryptionHandshake:
|
||||
return tls.QUICEncryptionLevelHandshake
|
||||
case protocol.Encryption1RTT:
|
||||
return tls.QUICEncryptionLevelApplication
|
||||
case protocol.Encryption0RTT:
|
||||
return tls.QUICEncryptionLevelEarly
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected encryption level: %s", e))
|
||||
}
|
||||
}
|
||||
|
||||
func FromTLSEncryptionLevel(e tls.QUICEncryptionLevel) protocol.EncryptionLevel {
|
||||
switch e {
|
||||
case tls.QUICEncryptionLevelInitial:
|
||||
return protocol.EncryptionInitial
|
||||
case tls.QUICEncryptionLevelHandshake:
|
||||
return protocol.EncryptionHandshake
|
||||
case tls.QUICEncryptionLevelApplication:
|
||||
return protocol.Encryption1RTT
|
||||
case tls.QUICEncryptionLevelEarly:
|
||||
return protocol.Encryption0RTT
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpect encryption level: %s", e))
|
||||
}
|
||||
}
|
||||
|
||||
const extraPrefix = "quic-go1"
|
||||
|
||||
func addExtraPrefix(b []byte) []byte {
|
||||
return append([]byte(extraPrefix), b...)
|
||||
}
|
||||
|
||||
func findExtraData(extras [][]byte) []byte {
|
||||
prefix := []byte(extraPrefix)
|
||||
for _, extra := range extras {
|
||||
if len(extra) < len(prefix) || !bytes.Equal(prefix, extra[:len(prefix)]) {
|
||||
continue
|
||||
}
|
||||
return extra[len(prefix):]
|
||||
}
|
||||
return nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue