mirror of
https://github.com/refraction-networking/utls.git
synced 2025-04-03 20:17:36 +03:00
crypto/tls: support QUIC as a transport
Add a QUICConn type for use by QUIC implementations. A QUICConn provides unencrypted handshake bytes and connection secrets to the QUIC layer, and receives handshake bytes. For #44886 Change-Id: I859dda4cc6d466a1df2fb863a69d3a2a069110d5 Reviewed-on: https://go-review.googlesource.com/c/go/+/493655 TryBot-Result: Gopher Robot <gobot@golang.org> Reviewed-by: Filippo Valsorda <filippo@golang.org> Run-TryBot: Damien Neil <dneil@google.com> Reviewed-by: Matthew Dempsky <mdempsky@google.com> Reviewed-by: Marten Seemann <martenseemann@gmail.com>
This commit is contained in:
parent
32e60edd6d
commit
b7691e8126
11 changed files with 1077 additions and 50 deletions
10
alert.go
10
alert.go
|
@ -6,6 +6,16 @@ package tls
|
||||||
|
|
||||||
import "strconv"
|
import "strconv"
|
||||||
|
|
||||||
|
// An AlertError is a TLS alert.
|
||||||
|
//
|
||||||
|
// When using a QUIC transport, QUICConn methods will return an error
|
||||||
|
// which wraps AlertError rather than sending a TLS alert.
|
||||||
|
type AlertError uint8
|
||||||
|
|
||||||
|
func (e AlertError) Error() string {
|
||||||
|
return alert(e).String()
|
||||||
|
}
|
||||||
|
|
||||||
type alert uint8
|
type alert uint8
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
|
@ -99,6 +99,7 @@ const (
|
||||||
extensionCertificateAuthorities uint16 = 47
|
extensionCertificateAuthorities uint16 = 47
|
||||||
extensionSignatureAlgorithmsCert uint16 = 50
|
extensionSignatureAlgorithmsCert uint16 = 50
|
||||||
extensionKeyShare uint16 = 51
|
extensionKeyShare uint16 = 51
|
||||||
|
extensionQUICTransportParameters uint16 = 57
|
||||||
extensionRenegotiationInfo uint16 = 0xff01
|
extensionRenegotiationInfo uint16 = 0xff01
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
119
conn.go
119
conn.go
|
@ -29,6 +29,7 @@ type Conn struct {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
isClient bool
|
isClient bool
|
||||||
handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
|
handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
|
||||||
|
quic *quicState // nil for non-QUIC connections
|
||||||
|
|
||||||
// isHandshakeComplete is true if the connection is currently transferring
|
// isHandshakeComplete is true if the connection is currently transferring
|
||||||
// application data (i.e. is not currently processing a handshake).
|
// application data (i.e. is not currently processing a handshake).
|
||||||
|
@ -176,7 +177,8 @@ type halfConn struct {
|
||||||
nextCipher any // next encryption state
|
nextCipher any // next encryption state
|
||||||
nextMac hash.Hash // next MAC algorithm
|
nextMac hash.Hash // next MAC algorithm
|
||||||
|
|
||||||
trafficSecret []byte // current TLS 1.3 traffic secret
|
level QUICEncryptionLevel // current QUIC encryption level
|
||||||
|
trafficSecret []byte // current TLS 1.3 traffic secret
|
||||||
}
|
}
|
||||||
|
|
||||||
type permanentError struct {
|
type permanentError struct {
|
||||||
|
@ -221,8 +223,9 @@ func (hc *halfConn) changeCipherSpec() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) {
|
func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) {
|
||||||
hc.trafficSecret = secret
|
hc.trafficSecret = secret
|
||||||
|
hc.level = level
|
||||||
key, iv := suite.trafficKey(secret)
|
key, iv := suite.trafficKey(secret)
|
||||||
hc.cipher = suite.aead(key, iv)
|
hc.cipher = suite.aead(key, iv)
|
||||||
for i := range hc.seq {
|
for i := range hc.seq {
|
||||||
|
@ -613,6 +616,10 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
|
||||||
}
|
}
|
||||||
c.input.Reset(nil)
|
c.input.Reset(nil)
|
||||||
|
|
||||||
|
if c.quic != nil {
|
||||||
|
return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with QUIC transport"))
|
||||||
|
}
|
||||||
|
|
||||||
// Read header, payload.
|
// Read header, payload.
|
||||||
if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
|
if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
|
||||||
// RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify
|
// RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify
|
||||||
|
@ -702,6 +709,9 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
|
||||||
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||||
|
|
||||||
case recordTypeAlert:
|
case recordTypeAlert:
|
||||||
|
if c.quic != nil {
|
||||||
|
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||||
|
}
|
||||||
if len(data) != 2 {
|
if len(data) != 2 {
|
||||||
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||||
}
|
}
|
||||||
|
@ -819,6 +829,10 @@ func (c *Conn) readFromUntil(r io.Reader, n int) error {
|
||||||
|
|
||||||
// sendAlertLocked sends a TLS alert message.
|
// sendAlertLocked sends a TLS alert message.
|
||||||
func (c *Conn) sendAlertLocked(err alert) error {
|
func (c *Conn) sendAlertLocked(err alert) error {
|
||||||
|
if c.quic != nil {
|
||||||
|
return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
|
||||||
|
}
|
||||||
|
|
||||||
switch err {
|
switch err {
|
||||||
case alertNoRenegotiation, alertCloseNotify:
|
case alertNoRenegotiation, alertCloseNotify:
|
||||||
c.tmp[0] = alertLevelWarning
|
c.tmp[0] = alertLevelWarning
|
||||||
|
@ -953,6 +967,19 @@ var outBufPool = sync.Pool{
|
||||||
// writeRecordLocked writes a TLS record with the given type and payload to the
|
// writeRecordLocked writes a TLS record with the given type and payload to the
|
||||||
// connection and updates the record layer state.
|
// connection and updates the record layer state.
|
||||||
func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
|
func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
|
||||||
|
if c.quic != nil {
|
||||||
|
if typ != recordTypeHandshake {
|
||||||
|
return 0, errors.New("tls: internal error: sending non-handshake message to QUIC transport")
|
||||||
|
}
|
||||||
|
c.quicWriteCryptoData(c.out.level, data)
|
||||||
|
if !c.buffering {
|
||||||
|
if _, err := c.flush(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
outBufPtr := outBufPool.Get().(*[]byte)
|
outBufPtr := outBufPool.Get().(*[]byte)
|
||||||
outBuf := *outBufPtr
|
outBuf := *outBufPtr
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -1037,28 +1064,40 @@ func (c *Conn) writeChangeCipherRecord() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// readHandshakeBytes reads handshake data until c.hand contains at least n bytes.
|
||||||
|
func (c *Conn) readHandshakeBytes(n int) error {
|
||||||
|
if c.quic != nil {
|
||||||
|
return c.quicReadHandshakeBytes(n)
|
||||||
|
}
|
||||||
|
for c.hand.Len() < n {
|
||||||
|
if err := c.readRecord(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// readHandshake reads the next handshake message from
|
// readHandshake reads the next handshake message from
|
||||||
// the record layer. If transcript is non-nil, the message
|
// the record layer. If transcript is non-nil, the message
|
||||||
// is written to the passed transcriptHash.
|
// is written to the passed transcriptHash.
|
||||||
func (c *Conn) readHandshake(transcript transcriptHash) (any, error) {
|
func (c *Conn) readHandshake(transcript transcriptHash) (any, error) {
|
||||||
for c.hand.Len() < 4 {
|
if err := c.readHandshakeBytes(4); err != nil {
|
||||||
if err := c.readRecord(); err != nil {
|
return nil, err
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
data := c.hand.Bytes()
|
data := c.hand.Bytes()
|
||||||
n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
|
n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
|
||||||
if n > maxHandshake {
|
if n > maxHandshake {
|
||||||
c.sendAlertLocked(alertInternalError)
|
c.sendAlertLocked(alertInternalError)
|
||||||
return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
|
return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
|
||||||
}
|
}
|
||||||
for c.hand.Len() < 4+n {
|
if err := c.readHandshakeBytes(4 + n); err != nil {
|
||||||
if err := c.readRecord(); err != nil {
|
return nil, err
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
data = c.hand.Next(4 + n)
|
data = c.hand.Next(4 + n)
|
||||||
|
return c.unmarshalHandshakeMessage(data, transcript)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript transcriptHash) (handshakeMessage, error) {
|
||||||
var m handshakeMessage
|
var m handshakeMessage
|
||||||
switch data[0] {
|
switch data[0] {
|
||||||
case typeHelloRequest:
|
case typeHelloRequest:
|
||||||
|
@ -1249,7 +1288,6 @@ func (c *Conn) handlePostHandshakeMessage() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.retryCount++
|
c.retryCount++
|
||||||
if c.retryCount > maxUselessRecords {
|
if c.retryCount > maxUselessRecords {
|
||||||
c.sendAlert(alertUnexpectedMessage)
|
c.sendAlert(alertUnexpectedMessage)
|
||||||
|
@ -1261,20 +1299,28 @@ func (c *Conn) handlePostHandshakeMessage() error {
|
||||||
return c.handleNewSessionTicket(msg)
|
return c.handleNewSessionTicket(msg)
|
||||||
case *keyUpdateMsg:
|
case *keyUpdateMsg:
|
||||||
return c.handleKeyUpdate(msg)
|
return c.handleKeyUpdate(msg)
|
||||||
default:
|
|
||||||
c.sendAlert(alertUnexpectedMessage)
|
|
||||||
return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
|
|
||||||
}
|
}
|
||||||
|
// The QUIC layer is supposed to treat an unexpected post-handshake CertificateRequest
|
||||||
|
// as a QUIC-level PROTOCOL_VIOLATION error (RFC 9001, Section 4.4). Returning an
|
||||||
|
// unexpected_message alert here doesn't provide it with enough information to distinguish
|
||||||
|
// this condition from other unexpected messages. This is probably fine.
|
||||||
|
c.sendAlert(alertUnexpectedMessage)
|
||||||
|
return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
|
func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
|
||||||
|
if c.quic != nil {
|
||||||
|
c.sendAlert(alertUnexpectedMessage)
|
||||||
|
return c.in.setErrorLocked(errors.New("tls: received unexpected key update message"))
|
||||||
|
}
|
||||||
|
|
||||||
cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
|
cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
|
||||||
if cipherSuite == nil {
|
if cipherSuite == nil {
|
||||||
return c.in.setErrorLocked(c.sendAlert(alertInternalError))
|
return c.in.setErrorLocked(c.sendAlert(alertInternalError))
|
||||||
}
|
}
|
||||||
|
|
||||||
newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
|
newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
|
||||||
c.in.setTrafficSecret(cipherSuite, newSecret)
|
c.in.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
|
||||||
|
|
||||||
if keyUpdate.updateRequested {
|
if keyUpdate.updateRequested {
|
||||||
c.out.Lock()
|
c.out.Lock()
|
||||||
|
@ -1293,7 +1339,7 @@ func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
|
newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
|
||||||
c.out.setTrafficSecret(cipherSuite, newSecret)
|
c.out.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -1454,12 +1500,15 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
|
||||||
// this cancellation. In the former case, we need to close the connection.
|
// this cancellation. In the former case, we need to close the connection.
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Start the "interrupter" goroutine, if this context might be canceled.
|
if c.quic != nil {
|
||||||
// (The background context cannot).
|
c.quic.cancelc = handshakeCtx.Done()
|
||||||
//
|
c.quic.cancel = cancel
|
||||||
// The interrupter goroutine waits for the input context to be done and
|
} else if ctx.Done() != nil {
|
||||||
// closes the connection if this happens before the function returns.
|
// Start the "interrupter" goroutine, if this context might be canceled.
|
||||||
if ctx.Done() != nil {
|
// (The background context cannot).
|
||||||
|
//
|
||||||
|
// The interrupter goroutine waits for the input context to be done and
|
||||||
|
// closes the connection if this happens before the function returns.
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
interruptRes := make(chan error, 1)
|
interruptRes := make(chan error, 1)
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -1510,6 +1559,30 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
|
||||||
panic("tls: internal error: handshake returned an error but is marked successful")
|
panic("tls: internal error: handshake returned an error but is marked successful")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.quic != nil {
|
||||||
|
if c.handshakeErr == nil {
|
||||||
|
c.quicHandshakeComplete()
|
||||||
|
// Provide the 1-RTT read secret now that the handshake is complete.
|
||||||
|
// The QUIC layer MUST NOT decrypt 1-RTT packets prior to completing
|
||||||
|
// the handshake (RFC 9001, Section 5.7).
|
||||||
|
c.quicSetReadSecret(QUICEncryptionLevelApplication, c.cipherSuite, c.in.trafficSecret)
|
||||||
|
} else {
|
||||||
|
var a alert
|
||||||
|
c.out.Lock()
|
||||||
|
if !errors.As(c.out.err, &a) {
|
||||||
|
a = alertInternalError
|
||||||
|
}
|
||||||
|
c.out.Unlock()
|
||||||
|
// Return an error which wraps both the handshake error and
|
||||||
|
// any alert error we may have sent, or alertInternalError
|
||||||
|
// if we didn't send an alert.
|
||||||
|
// Truncate the text of the alert to 0 characters.
|
||||||
|
c.handshakeErr = fmt.Errorf("%w%.0w", c.handshakeErr, AlertError(a))
|
||||||
|
}
|
||||||
|
close(c.quic.blockedc)
|
||||||
|
close(c.quic.signalc)
|
||||||
|
}
|
||||||
|
|
||||||
return c.handshakeErr
|
return c.handshakeErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -71,7 +71,6 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) {
|
||||||
vers: clientHelloVersion,
|
vers: clientHelloVersion,
|
||||||
compressionMethods: []uint8{compressionNone},
|
compressionMethods: []uint8{compressionNone},
|
||||||
random: make([]byte, 32),
|
random: make([]byte, 32),
|
||||||
sessionId: make([]byte, 32),
|
|
||||||
ocspStapling: true,
|
ocspStapling: true,
|
||||||
scts: true,
|
scts: true,
|
||||||
serverName: hostnameInSNI(config.ServerName),
|
serverName: hostnameInSNI(config.ServerName),
|
||||||
|
@ -114,8 +113,13 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) {
|
||||||
// A random session ID is used to detect when the server accepted a ticket
|
// A random session ID is used to detect when the server accepted a ticket
|
||||||
// and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as
|
// and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as
|
||||||
// a compatibility measure (see RFC 8446, Section 4.1.2).
|
// a compatibility measure (see RFC 8446, Section 4.1.2).
|
||||||
if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil {
|
//
|
||||||
return nil, nil, errors.New("tls: short read from Rand: " + err.Error())
|
// The session ID is not set for QUIC connections (see RFC 9001, Section 8.4).
|
||||||
|
if c.quic == nil {
|
||||||
|
hello.sessionId = make([]byte, 32)
|
||||||
|
if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil {
|
||||||
|
return nil, nil, errors.New("tls: short read from Rand: " + err.Error())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if hello.vers >= VersionTLS12 {
|
if hello.vers >= VersionTLS12 {
|
||||||
|
@ -144,6 +148,17 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) {
|
||||||
hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
|
hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.quic != nil {
|
||||||
|
p, err := c.quicGetTransportParameters()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if p == nil {
|
||||||
|
p = []byte{}
|
||||||
|
}
|
||||||
|
hello.quicTransportParameters = p
|
||||||
|
}
|
||||||
|
|
||||||
return hello, key, nil
|
return hello, key, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -271,7 +286,10 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to resume a previously negotiated TLS session, if available.
|
// Try to resume a previously negotiated TLS session, if available.
|
||||||
cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
|
cacheKey = c.clientSessionCacheKey()
|
||||||
|
if cacheKey == "" {
|
||||||
|
return "", nil, nil, nil, nil
|
||||||
|
}
|
||||||
session, ok := c.config.ClientSessionCache.Get(cacheKey)
|
session, ok := c.config.ClientSessionCache.Get(cacheKey)
|
||||||
if !ok || session == nil {
|
if !ok || session == nil {
|
||||||
return cacheKey, nil, nil, nil, nil
|
return cacheKey, nil, nil, nil, nil
|
||||||
|
@ -722,7 +740,7 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := checkALPN(hs.hello.alpnProtocols, hs.serverHello.alpnProtocol); err != nil {
|
if err := checkALPN(hs.hello.alpnProtocols, hs.serverHello.alpnProtocol, false); err != nil {
|
||||||
c.sendAlert(alertUnsupportedExtension)
|
c.sendAlert(alertUnsupportedExtension)
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
@ -760,8 +778,12 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) {
|
||||||
|
|
||||||
// checkALPN ensure that the server's choice of ALPN protocol is compatible with
|
// checkALPN ensure that the server's choice of ALPN protocol is compatible with
|
||||||
// the protocols that we advertised in the Client Hello.
|
// the protocols that we advertised in the Client Hello.
|
||||||
func checkALPN(clientProtos []string, serverProto string) error {
|
func checkALPN(clientProtos []string, serverProto string, quic bool) error {
|
||||||
if serverProto == "" {
|
if serverProto == "" {
|
||||||
|
if quic && len(clientProtos) > 0 {
|
||||||
|
// RFC 9001, Section 8.1
|
||||||
|
return errors.New("tls: server did not select an ALPN protocol")
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if len(clientProtos) == 0 {
|
if len(clientProtos) == 0 {
|
||||||
|
@ -1003,11 +1025,14 @@ func (c *Conn) getClientCertificate(cri *CertificateRequestInfo) (*Certificate,
|
||||||
|
|
||||||
// clientSessionCacheKey returns a key used to cache sessionTickets that could
|
// clientSessionCacheKey returns a key used to cache sessionTickets that could
|
||||||
// be used to resume previously negotiated TLS sessions with a server.
|
// be used to resume previously negotiated TLS sessions with a server.
|
||||||
func clientSessionCacheKey(serverAddr net.Addr, config *Config) string {
|
func (c *Conn) clientSessionCacheKey() string {
|
||||||
if len(config.ServerName) > 0 {
|
if len(c.config.ServerName) > 0 {
|
||||||
return config.ServerName
|
return c.config.ServerName
|
||||||
}
|
}
|
||||||
return serverAddr.String()
|
if c.conn != nil {
|
||||||
|
return c.conn.RemoteAddr().String()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// hostnameInSNI converts name into an appropriate hostname for SNI.
|
// hostnameInSNI converts name into an appropriate hostname for SNI.
|
||||||
|
|
|
@ -172,6 +172,9 @@ func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error {
|
||||||
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
|
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
|
||||||
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
|
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
|
||||||
func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
|
func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
|
||||||
|
if hs.c.quic != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
if hs.sentDummyCCS {
|
if hs.sentDummyCCS {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -383,10 +386,18 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
|
||||||
|
|
||||||
clientSecret := hs.suite.deriveSecret(handshakeSecret,
|
clientSecret := hs.suite.deriveSecret(handshakeSecret,
|
||||||
clientHandshakeTrafficLabel, hs.transcript)
|
clientHandshakeTrafficLabel, hs.transcript)
|
||||||
c.out.setTrafficSecret(hs.suite, clientSecret)
|
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret)
|
||||||
serverSecret := hs.suite.deriveSecret(handshakeSecret,
|
serverSecret := hs.suite.deriveSecret(handshakeSecret,
|
||||||
serverHandshakeTrafficLabel, hs.transcript)
|
serverHandshakeTrafficLabel, hs.transcript)
|
||||||
c.in.setTrafficSecret(hs.suite, serverSecret)
|
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret)
|
||||||
|
|
||||||
|
if c.quic != nil {
|
||||||
|
if c.hand.Len() != 0 {
|
||||||
|
c.sendAlert(alertUnexpectedMessage)
|
||||||
|
}
|
||||||
|
c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret)
|
||||||
|
c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret)
|
||||||
|
}
|
||||||
|
|
||||||
err = c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret)
|
err = c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -419,12 +430,30 @@ func (hs *clientHandshakeStateTLS13) readServerParameters() error {
|
||||||
return unexpectedMessageError(encryptedExtensions, msg)
|
return unexpectedMessageError(encryptedExtensions, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol); err != nil {
|
if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol, c.quic != nil); err != nil {
|
||||||
c.sendAlert(alertUnsupportedExtension)
|
// RFC 8446 specifies that no_application_protocol is sent by servers, but
|
||||||
|
// does not specify how clients handle the selection of an incompatible protocol.
|
||||||
|
// RFC 9001 Section 8.1 specifies that QUIC clients send no_application_protocol
|
||||||
|
// in this case. Always sending no_application_protocol seems reasonable.
|
||||||
|
c.sendAlert(alertNoApplicationProtocol)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.clientProtocol = encryptedExtensions.alpnProtocol
|
c.clientProtocol = encryptedExtensions.alpnProtocol
|
||||||
|
|
||||||
|
if c.quic != nil {
|
||||||
|
if encryptedExtensions.quicTransportParameters == nil {
|
||||||
|
// RFC 9001 Section 8.2.
|
||||||
|
c.sendAlert(alertMissingExtension)
|
||||||
|
return errors.New("tls: server did not send a quic_transport_parameters extension")
|
||||||
|
}
|
||||||
|
c.quicSetTransportParameters(encryptedExtensions.quicTransportParameters)
|
||||||
|
} else {
|
||||||
|
if encryptedExtensions.quicTransportParameters != nil {
|
||||||
|
c.sendAlert(alertUnsupportedExtension)
|
||||||
|
return errors.New("tls: server sent an unexpected quic_transport_parameters extension")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -552,7 +581,7 @@ func (hs *clientHandshakeStateTLS13) readServerFinished() error {
|
||||||
clientApplicationTrafficLabel, hs.transcript)
|
clientApplicationTrafficLabel, hs.transcript)
|
||||||
serverSecret := hs.suite.deriveSecret(hs.masterSecret,
|
serverSecret := hs.suite.deriveSecret(hs.masterSecret,
|
||||||
serverApplicationTrafficLabel, hs.transcript)
|
serverApplicationTrafficLabel, hs.transcript)
|
||||||
c.in.setTrafficSecret(hs.suite, serverSecret)
|
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret)
|
||||||
|
|
||||||
err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret)
|
err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -648,13 +677,20 @@ func (hs *clientHandshakeStateTLS13) sendClientFinished() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.out.setTrafficSecret(hs.suite, hs.trafficSecret)
|
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret)
|
||||||
|
|
||||||
if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil {
|
if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil {
|
||||||
c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret,
|
c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret,
|
||||||
resumptionLabel, hs.transcript)
|
resumptionLabel, hs.transcript)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.quic != nil {
|
||||||
|
if c.hand.Len() != 0 {
|
||||||
|
c.sendAlert(alertUnexpectedMessage)
|
||||||
|
}
|
||||||
|
c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, hs.trafficSecret)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -702,8 +738,10 @@ func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error {
|
||||||
scts: c.scts,
|
scts: c.scts,
|
||||||
}
|
}
|
||||||
|
|
||||||
cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
|
cacheKey := c.clientSessionCacheKey()
|
||||||
c.config.ClientSessionCache.Put(cacheKey, session)
|
if cacheKey != "" {
|
||||||
|
c.config.ClientSessionCache.Put(cacheKey, session)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -93,6 +93,7 @@ type clientHelloMsg struct {
|
||||||
pskModes []uint8
|
pskModes []uint8
|
||||||
pskIdentities []pskIdentity
|
pskIdentities []pskIdentity
|
||||||
pskBinders [][]byte
|
pskBinders [][]byte
|
||||||
|
quicTransportParameters []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *clientHelloMsg) marshal() ([]byte, error) {
|
func (m *clientHelloMsg) marshal() ([]byte, error) {
|
||||||
|
@ -246,6 +247,13 @@ func (m *clientHelloMsg) marshal() ([]byte, error) {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
if m.quicTransportParameters != nil { // marshal zero-length parameters when present
|
||||||
|
// RFC 9001, Section 8.2
|
||||||
|
exts.AddUint16(extensionQUICTransportParameters)
|
||||||
|
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
|
||||||
|
exts.AddBytes(m.quicTransportParameters)
|
||||||
|
})
|
||||||
|
}
|
||||||
if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension
|
if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension
|
||||||
// RFC 8446, Section 4.2.11
|
// RFC 8446, Section 4.2.11
|
||||||
exts.AddUint16(extensionPreSharedKey)
|
exts.AddUint16(extensionPreSharedKey)
|
||||||
|
@ -560,6 +568,11 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
|
||||||
if !readUint8LengthPrefixed(&extData, &m.pskModes) {
|
if !readUint8LengthPrefixed(&extData, &m.pskModes) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
case extensionQUICTransportParameters:
|
||||||
|
m.quicTransportParameters = make([]byte, len(extData))
|
||||||
|
if !extData.CopyBytes(m.quicTransportParameters) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
case extensionPreSharedKey:
|
case extensionPreSharedKey:
|
||||||
// RFC 8446, Section 4.2.11
|
// RFC 8446, Section 4.2.11
|
||||||
if !extensions.Empty() {
|
if !extensions.Empty() {
|
||||||
|
@ -860,8 +873,9 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
type encryptedExtensionsMsg struct {
|
type encryptedExtensionsMsg struct {
|
||||||
raw []byte
|
raw []byte
|
||||||
alpnProtocol string
|
alpnProtocol string
|
||||||
|
quicTransportParameters []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
|
func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
|
||||||
|
@ -883,6 +897,13 @@ func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
if m.quicTransportParameters != nil { // marshal zero-length parameters when present
|
||||||
|
// draft-ietf-quic-tls-32, Section 8.2
|
||||||
|
b.AddUint16(extensionQUICTransportParameters)
|
||||||
|
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||||
|
b.AddBytes(m.quicTransportParameters)
|
||||||
|
})
|
||||||
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -921,6 +942,11 @@ func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
m.alpnProtocol = string(proto)
|
m.alpnProtocol = string(proto)
|
||||||
|
case extensionQUICTransportParameters:
|
||||||
|
m.quicTransportParameters = make([]byte, len(extData))
|
||||||
|
if !extData.CopyBytes(m.quicTransportParameters) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
// Ignore unknown extensions.
|
// Ignore unknown extensions.
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -197,6 +197,9 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||||
m.pskIdentities = append(m.pskIdentities, psk)
|
m.pskIdentities = append(m.pskIdentities, psk)
|
||||||
m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
|
m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
|
||||||
}
|
}
|
||||||
|
if rand.Intn(10) > 5 {
|
||||||
|
m.quicTransportParameters = randomBytes(rand.Intn(500), rand)
|
||||||
|
}
|
||||||
if rand.Intn(10) > 5 {
|
if rand.Intn(10) > 5 {
|
||||||
m.earlyData = true
|
m.earlyData = true
|
||||||
}
|
}
|
||||||
|
|
|
@ -218,7 +218,7 @@ func (hs *serverHandshakeState) processClientHello() error {
|
||||||
c.serverName = hs.clientHello.serverName
|
c.serverName = hs.clientHello.serverName
|
||||||
}
|
}
|
||||||
|
|
||||||
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols)
|
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.sendAlert(alertNoApplicationProtocol)
|
c.sendAlert(alertNoApplicationProtocol)
|
||||||
return err
|
return err
|
||||||
|
@ -279,8 +279,12 @@ func (hs *serverHandshakeState) processClientHello() error {
|
||||||
// negotiateALPN picks a shared ALPN protocol that both sides support in server
|
// negotiateALPN picks a shared ALPN protocol that both sides support in server
|
||||||
// preference order. If ALPN is not configured or the peer doesn't support it,
|
// preference order. If ALPN is not configured or the peer doesn't support it,
|
||||||
// it returns "" and no error.
|
// it returns "" and no error.
|
||||||
func negotiateALPN(serverProtos, clientProtos []string) (string, error) {
|
func negotiateALPN(serverProtos, clientProtos []string, quic bool) (string, error) {
|
||||||
if len(serverProtos) == 0 || len(clientProtos) == 0 {
|
if len(serverProtos) == 0 || len(clientProtos) == 0 {
|
||||||
|
if quic && len(serverProtos) != 0 {
|
||||||
|
// RFC 9001, Section 8.1
|
||||||
|
return "", fmt.Errorf("tls: client did not request an application protocol")
|
||||||
|
}
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
var http11fallback bool
|
var http11fallback bool
|
||||||
|
|
|
@ -226,6 +226,20 @@ GroupSelection:
|
||||||
return errors.New("tls: invalid client key share")
|
return errors.New("tls: invalid client key share")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.quic != nil {
|
||||||
|
if hs.clientHello.quicTransportParameters == nil {
|
||||||
|
// RFC 9001 Section 8.2.
|
||||||
|
c.sendAlert(alertMissingExtension)
|
||||||
|
return errors.New("tls: client did not send a quic_transport_parameters extension")
|
||||||
|
}
|
||||||
|
c.quicSetTransportParameters(hs.clientHello.quicTransportParameters)
|
||||||
|
} else {
|
||||||
|
if hs.clientHello.quicTransportParameters != nil {
|
||||||
|
c.sendAlert(alertUnsupportedExtension)
|
||||||
|
return errors.New("tls: client sent an unexpected quic_transport_parameters extension")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
c.serverName = hs.clientHello.serverName
|
c.serverName = hs.clientHello.serverName
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -397,6 +411,9 @@ func (hs *serverHandshakeStateTLS13) pickCertificate() error {
|
||||||
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
|
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
|
||||||
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
|
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
|
||||||
func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
|
func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
|
||||||
|
if hs.c.quic != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
if hs.sentDummyCCS {
|
if hs.sentDummyCCS {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -548,10 +565,18 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
|
||||||
|
|
||||||
clientSecret := hs.suite.deriveSecret(hs.handshakeSecret,
|
clientSecret := hs.suite.deriveSecret(hs.handshakeSecret,
|
||||||
clientHandshakeTrafficLabel, hs.transcript)
|
clientHandshakeTrafficLabel, hs.transcript)
|
||||||
c.in.setTrafficSecret(hs.suite, clientSecret)
|
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret)
|
||||||
serverSecret := hs.suite.deriveSecret(hs.handshakeSecret,
|
serverSecret := hs.suite.deriveSecret(hs.handshakeSecret,
|
||||||
serverHandshakeTrafficLabel, hs.transcript)
|
serverHandshakeTrafficLabel, hs.transcript)
|
||||||
c.out.setTrafficSecret(hs.suite, serverSecret)
|
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret)
|
||||||
|
|
||||||
|
if c.quic != nil {
|
||||||
|
if c.hand.Len() != 0 {
|
||||||
|
c.sendAlert(alertUnexpectedMessage)
|
||||||
|
}
|
||||||
|
c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret)
|
||||||
|
c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret)
|
||||||
|
}
|
||||||
|
|
||||||
err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret)
|
err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -566,7 +591,7 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
|
||||||
|
|
||||||
encryptedExtensions := new(encryptedExtensionsMsg)
|
encryptedExtensions := new(encryptedExtensionsMsg)
|
||||||
|
|
||||||
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols)
|
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, c.quic != nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.sendAlert(alertNoApplicationProtocol)
|
c.sendAlert(alertNoApplicationProtocol)
|
||||||
return err
|
return err
|
||||||
|
@ -574,6 +599,14 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
|
||||||
encryptedExtensions.alpnProtocol = selectedProto
|
encryptedExtensions.alpnProtocol = selectedProto
|
||||||
c.clientProtocol = selectedProto
|
c.clientProtocol = selectedProto
|
||||||
|
|
||||||
|
if c.quic != nil {
|
||||||
|
p, err := c.quicGetTransportParameters()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
encryptedExtensions.quicTransportParameters = p
|
||||||
|
}
|
||||||
|
|
||||||
if _, err := hs.c.writeHandshakeRecord(encryptedExtensions, hs.transcript); err != nil {
|
if _, err := hs.c.writeHandshakeRecord(encryptedExtensions, hs.transcript); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -672,7 +705,15 @@ func (hs *serverHandshakeStateTLS13) sendServerFinished() error {
|
||||||
clientApplicationTrafficLabel, hs.transcript)
|
clientApplicationTrafficLabel, hs.transcript)
|
||||||
serverSecret := hs.suite.deriveSecret(hs.masterSecret,
|
serverSecret := hs.suite.deriveSecret(hs.masterSecret,
|
||||||
serverApplicationTrafficLabel, hs.transcript)
|
serverApplicationTrafficLabel, hs.transcript)
|
||||||
c.out.setTrafficSecret(hs.suite, serverSecret)
|
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret)
|
||||||
|
|
||||||
|
if c.quic != nil {
|
||||||
|
if c.hand.Len() != 0 {
|
||||||
|
// TODO: Handle this in setTrafficSecret?
|
||||||
|
c.sendAlert(alertUnexpectedMessage)
|
||||||
|
}
|
||||||
|
c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, serverSecret)
|
||||||
|
}
|
||||||
|
|
||||||
err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret)
|
err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -887,7 +928,7 @@ func (hs *serverHandshakeStateTLS13) readClientFinished() error {
|
||||||
return errors.New("tls: invalid client finished hash")
|
return errors.New("tls: invalid client finished hash")
|
||||||
}
|
}
|
||||||
|
|
||||||
c.in.setTrafficSecret(hs.suite, hs.trafficSecret)
|
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
376
quic.go
Normal file
376
quic.go
Normal file
|
@ -0,0 +1,376 @@
|
||||||
|
// Copyright 2023 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package tls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// QUICEncryptionLevel represents a QUIC encryption level used to transmit
|
||||||
|
// handshake messages.
|
||||||
|
type QUICEncryptionLevel int
|
||||||
|
|
||||||
|
const (
|
||||||
|
QUICEncryptionLevelInitial = QUICEncryptionLevel(iota)
|
||||||
|
QUICEncryptionLevelHandshake
|
||||||
|
QUICEncryptionLevelApplication
|
||||||
|
)
|
||||||
|
|
||||||
|
func (l QUICEncryptionLevel) String() string {
|
||||||
|
switch l {
|
||||||
|
case QUICEncryptionLevelInitial:
|
||||||
|
return "Initial"
|
||||||
|
case QUICEncryptionLevelHandshake:
|
||||||
|
return "Handshake"
|
||||||
|
case QUICEncryptionLevelApplication:
|
||||||
|
return "Application"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("QUICEncryptionLevel(%v)", int(l))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// A QUICConn represents a connection which uses a QUIC implementation as the underlying
|
||||||
|
// transport as described in RFC 9001.
|
||||||
|
//
|
||||||
|
// Methods of QUICConn are not safe for concurrent use.
|
||||||
|
type QUICConn struct {
|
||||||
|
conn *Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// A QUICConfig configures a QUICConn.
|
||||||
|
type QUICConfig struct {
|
||||||
|
TLSConfig *Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// A QUICEventKind is a type of operation on a QUIC connection.
|
||||||
|
type QUICEventKind int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// QUICNoEvent indicates that there are no events available.
|
||||||
|
QUICNoEvent QUICEventKind = iota
|
||||||
|
|
||||||
|
// QUICSetReadSecret and QUICSetWriteSecret provide the read and write
|
||||||
|
// secrets for a given encryption level.
|
||||||
|
// QUICEvent.Level, QUICEvent.Data, and QUICEvent.Suite are set.
|
||||||
|
//
|
||||||
|
// Secrets for the Initial encryption level are derived from the initial
|
||||||
|
// destination connection ID, and are not provided by the QUICConn.
|
||||||
|
QUICSetReadSecret
|
||||||
|
QUICSetWriteSecret
|
||||||
|
|
||||||
|
// QUICWriteData provides data to send to the peer in CRYPTO frames.
|
||||||
|
// QUICEvent.Data is set.
|
||||||
|
QUICWriteData
|
||||||
|
|
||||||
|
// QUICTransportParameters provides the peer's QUIC transport parameters.
|
||||||
|
// QUICEvent.Data is set.
|
||||||
|
QUICTransportParameters
|
||||||
|
|
||||||
|
// QUICTransportParametersRequired indicates that the caller must provide
|
||||||
|
// QUIC transport parameters to send to the peer. The caller should set
|
||||||
|
// the transport parameters with QUICConn.SetTransportParameters and call
|
||||||
|
// QUICConn.NextEvent again.
|
||||||
|
//
|
||||||
|
// If transport parameters are set before calling QUICConn.Start, the
|
||||||
|
// connection will never generate a QUICTransportParametersRequired event.
|
||||||
|
QUICTransportParametersRequired
|
||||||
|
|
||||||
|
// QUICHandshakeDone indicates that the TLS handshake has completed.
|
||||||
|
QUICHandshakeDone
|
||||||
|
)
|
||||||
|
|
||||||
|
// A QUICEvent is an event occurring on a QUIC connection.
|
||||||
|
//
|
||||||
|
// The type of event is specified by the Kind field.
|
||||||
|
// The contents of the other fields are kind-specific.
|
||||||
|
type QUICEvent struct {
|
||||||
|
Kind QUICEventKind
|
||||||
|
|
||||||
|
// Set for QUICSetReadSecret, QUICSetWriteSecret, and QUICWriteData.
|
||||||
|
Level QUICEncryptionLevel
|
||||||
|
|
||||||
|
// Set for QUICTransportParameters, QUICSetReadSecret, QUICSetWriteSecret, and QUICWriteData.
|
||||||
|
// The contents are owned by crypto/tls, and are valid until the next NextEvent call.
|
||||||
|
Data []byte
|
||||||
|
|
||||||
|
// Set for QUICSetReadSecret and QUICSetWriteSecret.
|
||||||
|
Suite uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
type quicState struct {
|
||||||
|
events []QUICEvent
|
||||||
|
nextEvent int
|
||||||
|
|
||||||
|
// eventArr is a statically allocated event array, large enough to handle
|
||||||
|
// the usual maximum number of events resulting from a single call:
|
||||||
|
// transport parameters, Initial data, Handshake write and read secrets,
|
||||||
|
// Handshake data, Application write secret, Application data.
|
||||||
|
eventArr [7]QUICEvent
|
||||||
|
|
||||||
|
started bool
|
||||||
|
signalc chan struct{} // handshake data is available to be read
|
||||||
|
blockedc chan struct{} // handshake is waiting for data, closed when done
|
||||||
|
cancelc <-chan struct{} // handshake has been canceled
|
||||||
|
cancel context.CancelFunc
|
||||||
|
|
||||||
|
// readbuf is shared between HandleData and the handshake goroutine.
|
||||||
|
// HandshakeCryptoData passes ownership to the handshake goroutine by
|
||||||
|
// reading from signalc, and reclaims ownership by reading from blockedc.
|
||||||
|
readbuf []byte
|
||||||
|
|
||||||
|
transportParams []byte // to send to the peer
|
||||||
|
}
|
||||||
|
|
||||||
|
// QUICClient returns a new TLS client side connection using QUICTransport as the
|
||||||
|
// underlying transport. The config cannot be nil.
|
||||||
|
//
|
||||||
|
// The config's MinVersion must be at least TLS 1.3.
|
||||||
|
func QUICClient(config *QUICConfig) *QUICConn {
|
||||||
|
return newQUICConn(Client(nil, config.TLSConfig))
|
||||||
|
}
|
||||||
|
|
||||||
|
// QUICServer returns a new TLS server side connection using QUICTransport as the
|
||||||
|
// underlying transport. The config cannot be nil.
|
||||||
|
//
|
||||||
|
// The config's MinVersion must be at least TLS 1.3.
|
||||||
|
func QUICServer(config *QUICConfig) *QUICConn {
|
||||||
|
return newQUICConn(Server(nil, config.TLSConfig))
|
||||||
|
}
|
||||||
|
|
||||||
|
func newQUICConn(conn *Conn) *QUICConn {
|
||||||
|
conn.quic = &quicState{
|
||||||
|
signalc: make(chan struct{}),
|
||||||
|
blockedc: make(chan struct{}),
|
||||||
|
}
|
||||||
|
conn.quic.events = conn.quic.eventArr[:0]
|
||||||
|
return &QUICConn{
|
||||||
|
conn: conn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the client or server handshake protocol.
|
||||||
|
// It may produce connection events, which may be read with NextEvent.
|
||||||
|
//
|
||||||
|
// Start must be called at most once.
|
||||||
|
func (q *QUICConn) Start(ctx context.Context) error {
|
||||||
|
if q.conn.quic.started {
|
||||||
|
return quicError(errors.New("tls: Start called more than once"))
|
||||||
|
}
|
||||||
|
q.conn.quic.started = true
|
||||||
|
if q.conn.config.MinVersion < VersionTLS13 {
|
||||||
|
return quicError(errors.New("tls: Config MinVersion must be at least TLS 1.13"))
|
||||||
|
}
|
||||||
|
go q.conn.HandshakeContext(ctx)
|
||||||
|
if _, ok := <-q.conn.quic.blockedc; !ok {
|
||||||
|
return q.conn.handshakeErr
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NextEvent returns the next event occurring on the connection.
|
||||||
|
// It returns an event with a Kind of QUICNoEvent when no events are available.
|
||||||
|
func (q *QUICConn) NextEvent() QUICEvent {
|
||||||
|
qs := q.conn.quic
|
||||||
|
if last := qs.nextEvent - 1; last >= 0 && len(qs.events[last].Data) > 0 {
|
||||||
|
// Write over some of the previous event's data,
|
||||||
|
// to catch callers erroniously retaining it.
|
||||||
|
qs.events[last].Data[0] = 0
|
||||||
|
}
|
||||||
|
if qs.nextEvent >= len(qs.events) {
|
||||||
|
qs.events = qs.events[:0]
|
||||||
|
qs.nextEvent = 0
|
||||||
|
return QUICEvent{Kind: QUICNoEvent}
|
||||||
|
}
|
||||||
|
e := qs.events[qs.nextEvent]
|
||||||
|
qs.events[qs.nextEvent] = QUICEvent{} // zero out references to data
|
||||||
|
qs.nextEvent++
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the connection and stops any in-progress handshake.
|
||||||
|
func (q *QUICConn) Close() error {
|
||||||
|
if q.conn.quic.cancel == nil {
|
||||||
|
return nil // never started
|
||||||
|
}
|
||||||
|
q.conn.quic.cancel()
|
||||||
|
for range q.conn.quic.blockedc {
|
||||||
|
// Wait for the handshake goroutine to return.
|
||||||
|
}
|
||||||
|
return q.conn.handshakeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleData handles handshake bytes received from the peer.
|
||||||
|
// It may produce connection events, which may be read with NextEvent.
|
||||||
|
func (q *QUICConn) HandleData(level QUICEncryptionLevel, data []byte) error {
|
||||||
|
c := q.conn
|
||||||
|
if c.in.level != level {
|
||||||
|
return quicError(c.in.setErrorLocked(errors.New("tls: handshake data received at wrong level")))
|
||||||
|
}
|
||||||
|
c.quic.readbuf = data
|
||||||
|
<-c.quic.signalc
|
||||||
|
_, ok := <-c.quic.blockedc
|
||||||
|
if ok {
|
||||||
|
// The handshake goroutine is waiting for more data.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// The handshake goroutine has exited.
|
||||||
|
c.hand.Write(c.quic.readbuf)
|
||||||
|
c.quic.readbuf = nil
|
||||||
|
for q.conn.hand.Len() >= 4 && q.conn.handshakeErr == nil {
|
||||||
|
b := q.conn.hand.Bytes()
|
||||||
|
n := int(b[1])<<16 | int(b[2])<<8 | int(b[3])
|
||||||
|
if 4+n < len(b) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := q.conn.handlePostHandshakeMessage(); err != nil {
|
||||||
|
return quicError(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if q.conn.handshakeErr != nil {
|
||||||
|
return quicError(q.conn.handshakeErr)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionState returns basic TLS details about the connection.
|
||||||
|
func (q *QUICConn) ConnectionState() ConnectionState {
|
||||||
|
return q.conn.ConnectionState()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTransportParameters sets the transport parameters to send to the peer.
|
||||||
|
//
|
||||||
|
// Server connections may delay setting the transport parameters until after
|
||||||
|
// receiving the client's transport parameters. See QUICTransportParametersRequired.
|
||||||
|
func (q *QUICConn) SetTransportParameters(params []byte) {
|
||||||
|
if params == nil {
|
||||||
|
params = []byte{}
|
||||||
|
}
|
||||||
|
q.conn.quic.transportParams = params
|
||||||
|
if q.conn.quic.started {
|
||||||
|
<-q.conn.quic.signalc
|
||||||
|
<-q.conn.quic.blockedc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// quicError ensures err is an AlertError.
|
||||||
|
// If err is not already, quicError wraps it with alertInternalError.
|
||||||
|
func quicError(err error) error {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var ae AlertError
|
||||||
|
if errors.As(err, &ae) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var a alert
|
||||||
|
if !errors.As(err, &a) {
|
||||||
|
a = alertInternalError
|
||||||
|
}
|
||||||
|
// Return an error wrapping the original error and an AlertError.
|
||||||
|
// Truncate the text of the alert to 0 characters.
|
||||||
|
return fmt.Errorf("%w%.0w", err, AlertError(a))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) quicReadHandshakeBytes(n int) error {
|
||||||
|
for c.hand.Len() < n {
|
||||||
|
if err := c.quicWaitForSignal(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) quicSetReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
|
||||||
|
c.quic.events = append(c.quic.events, QUICEvent{
|
||||||
|
Kind: QUICSetReadSecret,
|
||||||
|
Level: level,
|
||||||
|
Suite: suite,
|
||||||
|
Data: secret,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) quicSetWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
|
||||||
|
c.quic.events = append(c.quic.events, QUICEvent{
|
||||||
|
Kind: QUICSetWriteSecret,
|
||||||
|
Level: level,
|
||||||
|
Suite: suite,
|
||||||
|
Data: secret,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) quicWriteCryptoData(level QUICEncryptionLevel, data []byte) {
|
||||||
|
var last *QUICEvent
|
||||||
|
if len(c.quic.events) > 0 {
|
||||||
|
last = &c.quic.events[len(c.quic.events)-1]
|
||||||
|
}
|
||||||
|
if last == nil || last.Kind != QUICWriteData || last.Level != level {
|
||||||
|
c.quic.events = append(c.quic.events, QUICEvent{
|
||||||
|
Kind: QUICWriteData,
|
||||||
|
Level: level,
|
||||||
|
})
|
||||||
|
last = &c.quic.events[len(c.quic.events)-1]
|
||||||
|
}
|
||||||
|
last.Data = append(last.Data, data...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) quicSetTransportParameters(params []byte) {
|
||||||
|
c.quic.events = append(c.quic.events, QUICEvent{
|
||||||
|
Kind: QUICTransportParameters,
|
||||||
|
Data: params,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) quicGetTransportParameters() ([]byte, error) {
|
||||||
|
if c.quic.transportParams == nil {
|
||||||
|
c.quic.events = append(c.quic.events, QUICEvent{
|
||||||
|
Kind: QUICTransportParametersRequired,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for c.quic.transportParams == nil {
|
||||||
|
if err := c.quicWaitForSignal(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c.quic.transportParams, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) quicHandshakeComplete() {
|
||||||
|
c.quic.events = append(c.quic.events, QUICEvent{
|
||||||
|
Kind: QUICHandshakeDone,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// quicWaitForSignal notifies the QUICConn that handshake progress is blocked,
|
||||||
|
// and waits for a signal that the handshake should proceed.
|
||||||
|
//
|
||||||
|
// The handshake may become blocked waiting for handshake bytes
|
||||||
|
// or for the user to provide transport parameters.
|
||||||
|
func (c *Conn) quicWaitForSignal() error {
|
||||||
|
// Drop the handshake mutex while blocked to allow the user
|
||||||
|
// to call ConnectionState before the handshake completes.
|
||||||
|
c.handshakeMutex.Unlock()
|
||||||
|
defer c.handshakeMutex.Lock()
|
||||||
|
// Send on blockedc to notify the QUICConn that the handshake is blocked.
|
||||||
|
// Exported methods of QUICConn wait for the handshake to become blocked
|
||||||
|
// before returning to the user.
|
||||||
|
select {
|
||||||
|
case c.quic.blockedc <- struct{}{}:
|
||||||
|
case <-c.quic.cancelc:
|
||||||
|
return c.sendAlertLocked(alertCloseNotify)
|
||||||
|
}
|
||||||
|
// The QUICConn reads from signalc to notify us that the handshake may
|
||||||
|
// be able to proceed. (The QUICConn reads, because we close signalc to
|
||||||
|
// indicate that the handshake has completed.)
|
||||||
|
select {
|
||||||
|
case c.quic.signalc <- struct{}{}:
|
||||||
|
c.hand.Write(c.quic.readbuf)
|
||||||
|
c.quic.readbuf = nil
|
||||||
|
case <-c.quic.cancelc:
|
||||||
|
return c.sendAlertLocked(alertCloseNotify)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
430
quic_test.go
Normal file
430
quic_test.go
Normal file
|
@ -0,0 +1,430 @@
|
||||||
|
// Copyright 2023 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package tls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testQUICConn struct {
|
||||||
|
t *testing.T
|
||||||
|
conn *QUICConn
|
||||||
|
readSecret map[QUICEncryptionLevel]suiteSecret
|
||||||
|
writeSecret map[QUICEncryptionLevel]suiteSecret
|
||||||
|
gotParams []byte
|
||||||
|
complete bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestQUICClient(t *testing.T, config *Config) *testQUICConn {
|
||||||
|
q := &testQUICConn{t: t}
|
||||||
|
q.conn = QUICClient(&QUICConfig{
|
||||||
|
TLSConfig: config,
|
||||||
|
})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
q.conn.Close()
|
||||||
|
})
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestQUICServer(t *testing.T, config *Config) *testQUICConn {
|
||||||
|
q := &testQUICConn{t: t}
|
||||||
|
q.conn = QUICServer(&QUICConfig{
|
||||||
|
TLSConfig: config,
|
||||||
|
})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
q.conn.Close()
|
||||||
|
})
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
type suiteSecret struct {
|
||||||
|
suite uint16
|
||||||
|
secret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *testQUICConn) setReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
|
||||||
|
if _, ok := q.writeSecret[level]; !ok {
|
||||||
|
q.t.Errorf("SetReadSecret for level %v called before SetWriteSecret", level)
|
||||||
|
}
|
||||||
|
if level == QUICEncryptionLevelApplication && !q.complete {
|
||||||
|
q.t.Errorf("SetReadSecret for level %v called before HandshakeComplete", level)
|
||||||
|
}
|
||||||
|
if _, ok := q.readSecret[level]; ok {
|
||||||
|
q.t.Errorf("SetReadSecret for level %v called twice", level)
|
||||||
|
}
|
||||||
|
if q.readSecret == nil {
|
||||||
|
q.readSecret = map[QUICEncryptionLevel]suiteSecret{}
|
||||||
|
}
|
||||||
|
switch level {
|
||||||
|
case QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication:
|
||||||
|
q.readSecret[level] = suiteSecret{suite, secret}
|
||||||
|
default:
|
||||||
|
q.t.Errorf("SetReadSecret for unexpected level %v", level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *testQUICConn) setWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
|
||||||
|
if _, ok := q.writeSecret[level]; ok {
|
||||||
|
q.t.Errorf("SetWriteSecret for level %v called twice", level)
|
||||||
|
}
|
||||||
|
if q.writeSecret == nil {
|
||||||
|
q.writeSecret = map[QUICEncryptionLevel]suiteSecret{}
|
||||||
|
}
|
||||||
|
switch level {
|
||||||
|
case QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication:
|
||||||
|
q.writeSecret[level] = suiteSecret{suite, secret}
|
||||||
|
default:
|
||||||
|
q.t.Errorf("SetWriteSecret for unexpected level %v", level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var errTransportParametersRequired = errors.New("transport parameters required")
|
||||||
|
|
||||||
|
func runTestQUICConnection(ctx context.Context, a, b *testQUICConn, onHandleCryptoData func()) error {
|
||||||
|
for _, c := range []*testQUICConn{a, b} {
|
||||||
|
if !c.conn.conn.quic.started {
|
||||||
|
if err := c.conn.Start(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
idleCount := 0
|
||||||
|
for {
|
||||||
|
e := a.conn.NextEvent()
|
||||||
|
switch e.Kind {
|
||||||
|
case QUICNoEvent:
|
||||||
|
idleCount++
|
||||||
|
if idleCount == 2 {
|
||||||
|
if !a.complete || !b.complete {
|
||||||
|
return errors.New("handshake incomplete")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
a, b = b, a
|
||||||
|
case QUICSetReadSecret:
|
||||||
|
a.setReadSecret(e.Level, e.Suite, e.Data)
|
||||||
|
case QUICSetWriteSecret:
|
||||||
|
a.setWriteSecret(e.Level, e.Suite, e.Data)
|
||||||
|
case QUICWriteData:
|
||||||
|
if err := b.conn.HandleData(e.Level, e.Data); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case QUICTransportParameters:
|
||||||
|
a.gotParams = e.Data
|
||||||
|
if a.gotParams == nil {
|
||||||
|
a.gotParams = []byte{}
|
||||||
|
}
|
||||||
|
case QUICTransportParametersRequired:
|
||||||
|
return errTransportParametersRequired
|
||||||
|
case QUICHandshakeDone:
|
||||||
|
a.complete = true
|
||||||
|
}
|
||||||
|
if e.Kind != QUICNoEvent {
|
||||||
|
idleCount = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQUICConnection(t *testing.T) {
|
||||||
|
config := testConfig.Clone()
|
||||||
|
config.MinVersion = VersionTLS13
|
||||||
|
|
||||||
|
cli := newTestQUICClient(t, config)
|
||||||
|
cli.conn.SetTransportParameters(nil)
|
||||||
|
|
||||||
|
srv := newTestQUICServer(t, config)
|
||||||
|
srv.conn.SetTransportParameters(nil)
|
||||||
|
|
||||||
|
if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
|
||||||
|
t.Fatalf("error during connection handshake: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := cli.readSecret[QUICEncryptionLevelHandshake]; !ok {
|
||||||
|
t.Errorf("client has no Handshake secret")
|
||||||
|
}
|
||||||
|
if _, ok := cli.readSecret[QUICEncryptionLevelApplication]; !ok {
|
||||||
|
t.Errorf("client has no Application secret")
|
||||||
|
}
|
||||||
|
if _, ok := srv.readSecret[QUICEncryptionLevelHandshake]; !ok {
|
||||||
|
t.Errorf("server has no Handshake secret")
|
||||||
|
}
|
||||||
|
if _, ok := srv.readSecret[QUICEncryptionLevelApplication]; !ok {
|
||||||
|
t.Errorf("server has no Application secret")
|
||||||
|
}
|
||||||
|
for _, level := range []QUICEncryptionLevel{QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication} {
|
||||||
|
if _, ok := cli.readSecret[level]; !ok {
|
||||||
|
t.Errorf("client has no %v read secret", level)
|
||||||
|
}
|
||||||
|
if _, ok := srv.readSecret[level]; !ok {
|
||||||
|
t.Errorf("server has no %v read secret", level)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(cli.readSecret[level], srv.writeSecret[level]) {
|
||||||
|
t.Errorf("client read secret does not match server write secret for level %v", level)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(cli.writeSecret[level], srv.readSecret[level]) {
|
||||||
|
t.Errorf("client write secret does not match server read secret for level %v", level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQUICSessionResumption(t *testing.T) {
|
||||||
|
clientConfig := testConfig.Clone()
|
||||||
|
clientConfig.MinVersion = VersionTLS13
|
||||||
|
clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
|
||||||
|
clientConfig.ServerName = "example.go.dev"
|
||||||
|
|
||||||
|
serverConfig := testConfig.Clone()
|
||||||
|
serverConfig.MinVersion = VersionTLS13
|
||||||
|
|
||||||
|
cli := newTestQUICClient(t, clientConfig)
|
||||||
|
cli.conn.SetTransportParameters(nil)
|
||||||
|
srv := newTestQUICServer(t, serverConfig)
|
||||||
|
srv.conn.SetTransportParameters(nil)
|
||||||
|
if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
|
||||||
|
t.Fatalf("error during first connection handshake: %v", err)
|
||||||
|
}
|
||||||
|
if cli.conn.ConnectionState().DidResume {
|
||||||
|
t.Errorf("first connection unexpectedly used session resumption")
|
||||||
|
}
|
||||||
|
|
||||||
|
cli2 := newTestQUICClient(t, clientConfig)
|
||||||
|
cli2.conn.SetTransportParameters(nil)
|
||||||
|
srv2 := newTestQUICServer(t, serverConfig)
|
||||||
|
srv2.conn.SetTransportParameters(nil)
|
||||||
|
if err := runTestQUICConnection(context.Background(), cli2, srv2, nil); err != nil {
|
||||||
|
t.Fatalf("error during second connection handshake: %v", err)
|
||||||
|
}
|
||||||
|
if !cli2.conn.ConnectionState().DidResume {
|
||||||
|
t.Errorf("second connection did not use session resumption")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQUICPostHandshakeClientAuthentication(t *testing.T) {
|
||||||
|
// RFC 9001, Section 4.4.
|
||||||
|
config := testConfig.Clone()
|
||||||
|
config.MinVersion = VersionTLS13
|
||||||
|
cli := newTestQUICClient(t, config)
|
||||||
|
cli.conn.SetTransportParameters(nil)
|
||||||
|
srv := newTestQUICServer(t, config)
|
||||||
|
srv.conn.SetTransportParameters(nil)
|
||||||
|
if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
|
||||||
|
t.Fatalf("error during connection handshake: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
certReq := new(certificateRequestMsgTLS13)
|
||||||
|
certReq.ocspStapling = true
|
||||||
|
certReq.scts = true
|
||||||
|
certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
|
||||||
|
certReqBytes, err := certReq.marshal()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := cli.conn.HandleData(QUICEncryptionLevelApplication, append([]byte{
|
||||||
|
byte(typeCertificateRequest),
|
||||||
|
byte(0), byte(0), byte(len(certReqBytes)),
|
||||||
|
}, certReqBytes...)); err == nil {
|
||||||
|
t.Fatalf("post-handshake authentication request: got no error, want one")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQUICPostHandshakeKeyUpdate(t *testing.T) {
|
||||||
|
// RFC 9001, Section 6.
|
||||||
|
config := testConfig.Clone()
|
||||||
|
config.MinVersion = VersionTLS13
|
||||||
|
cli := newTestQUICClient(t, config)
|
||||||
|
cli.conn.SetTransportParameters(nil)
|
||||||
|
srv := newTestQUICServer(t, config)
|
||||||
|
srv.conn.SetTransportParameters(nil)
|
||||||
|
if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
|
||||||
|
t.Fatalf("error during connection handshake: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
keyUpdate := new(keyUpdateMsg)
|
||||||
|
keyUpdateBytes, err := keyUpdate.marshal()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := cli.conn.HandleData(QUICEncryptionLevelApplication, append([]byte{
|
||||||
|
byte(typeKeyUpdate),
|
||||||
|
byte(0), byte(0), byte(len(keyUpdateBytes)),
|
||||||
|
}, keyUpdateBytes...)); !errors.Is(err, alertUnexpectedMessage) {
|
||||||
|
t.Fatalf("key update request: got error %v, want alertUnexpectedMessage", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQUICHandshakeError(t *testing.T) {
|
||||||
|
clientConfig := testConfig.Clone()
|
||||||
|
clientConfig.MinVersion = VersionTLS13
|
||||||
|
clientConfig.InsecureSkipVerify = false
|
||||||
|
clientConfig.ServerName = "name"
|
||||||
|
|
||||||
|
serverConfig := testConfig.Clone()
|
||||||
|
serverConfig.MinVersion = VersionTLS13
|
||||||
|
|
||||||
|
cli := newTestQUICClient(t, clientConfig)
|
||||||
|
cli.conn.SetTransportParameters(nil)
|
||||||
|
srv := newTestQUICServer(t, serverConfig)
|
||||||
|
srv.conn.SetTransportParameters(nil)
|
||||||
|
err := runTestQUICConnection(context.Background(), cli, srv, nil)
|
||||||
|
if !errors.Is(err, AlertError(alertBadCertificate)) {
|
||||||
|
t.Errorf("connection handshake terminated with error %q, want alertBadCertificate", err)
|
||||||
|
}
|
||||||
|
var e *CertificateVerificationError
|
||||||
|
if !errors.As(err, &e) {
|
||||||
|
t.Errorf("connection handshake terminated with error %q, want CertificateVerificationError", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that QUICConn.ConnectionState can be used during the handshake,
|
||||||
|
// and that it reports the application protocol as soon as it has been
|
||||||
|
// negotiated.
|
||||||
|
func TestQUICConnectionState(t *testing.T) {
|
||||||
|
config := testConfig.Clone()
|
||||||
|
config.MinVersion = VersionTLS13
|
||||||
|
config.NextProtos = []string{"h3"}
|
||||||
|
cli := newTestQUICClient(t, config)
|
||||||
|
cli.conn.SetTransportParameters(nil)
|
||||||
|
srv := newTestQUICServer(t, config)
|
||||||
|
srv.conn.SetTransportParameters(nil)
|
||||||
|
onHandleCryptoData := func() {
|
||||||
|
cliCS := cli.conn.ConnectionState()
|
||||||
|
cliWantALPN := ""
|
||||||
|
if _, ok := cli.readSecret[QUICEncryptionLevelApplication]; ok {
|
||||||
|
cliWantALPN = "h3"
|
||||||
|
}
|
||||||
|
if want, got := cliCS.NegotiatedProtocol, cliWantALPN; want != got {
|
||||||
|
t.Errorf("cli.ConnectionState().NegotiatedProtocol = %q, want %q", want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
srvCS := srv.conn.ConnectionState()
|
||||||
|
srvWantALPN := ""
|
||||||
|
if _, ok := srv.readSecret[QUICEncryptionLevelHandshake]; ok {
|
||||||
|
srvWantALPN = "h3"
|
||||||
|
}
|
||||||
|
if want, got := srvCS.NegotiatedProtocol, srvWantALPN; want != got {
|
||||||
|
t.Errorf("srv.ConnectionState().NegotiatedProtocol = %q, want %q", want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := runTestQUICConnection(context.Background(), cli, srv, onHandleCryptoData); err != nil {
|
||||||
|
t.Fatalf("error during connection handshake: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQUICStartContextPropagation(t *testing.T) {
|
||||||
|
const key = "key"
|
||||||
|
const value = "value"
|
||||||
|
ctx := context.WithValue(context.Background(), key, value)
|
||||||
|
config := testConfig.Clone()
|
||||||
|
config.MinVersion = VersionTLS13
|
||||||
|
calls := 0
|
||||||
|
config.GetConfigForClient = func(info *ClientHelloInfo) (*Config, error) {
|
||||||
|
calls++
|
||||||
|
got, _ := info.Context().Value(key).(string)
|
||||||
|
if got != value {
|
||||||
|
t.Errorf("GetConfigForClient context key %q has value %q, want %q", key, got, value)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
cli := newTestQUICClient(t, config)
|
||||||
|
cli.conn.SetTransportParameters(nil)
|
||||||
|
srv := newTestQUICServer(t, config)
|
||||||
|
srv.conn.SetTransportParameters(nil)
|
||||||
|
if err := runTestQUICConnection(ctx, cli, srv, nil); err != nil {
|
||||||
|
t.Fatalf("error during connection handshake: %v", err)
|
||||||
|
}
|
||||||
|
if calls != 1 {
|
||||||
|
t.Errorf("GetConfigForClient called %v times, want 1", calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQUICDelayedTransportParameters(t *testing.T) {
|
||||||
|
clientConfig := testConfig.Clone()
|
||||||
|
clientConfig.MinVersion = VersionTLS13
|
||||||
|
clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
|
||||||
|
clientConfig.ServerName = "example.go.dev"
|
||||||
|
|
||||||
|
serverConfig := testConfig.Clone()
|
||||||
|
serverConfig.MinVersion = VersionTLS13
|
||||||
|
|
||||||
|
cliParams := "client params"
|
||||||
|
srvParams := "server params"
|
||||||
|
|
||||||
|
cli := newTestQUICClient(t, clientConfig)
|
||||||
|
srv := newTestQUICServer(t, serverConfig)
|
||||||
|
if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != errTransportParametersRequired {
|
||||||
|
t.Fatalf("handshake with no client parameters: %v; want errTransportParametersRequired", err)
|
||||||
|
}
|
||||||
|
cli.conn.SetTransportParameters([]byte(cliParams))
|
||||||
|
if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != errTransportParametersRequired {
|
||||||
|
t.Fatalf("handshake with no server parameters: %v; want errTransportParametersRequired", err)
|
||||||
|
}
|
||||||
|
srv.conn.SetTransportParameters([]byte(srvParams))
|
||||||
|
if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
|
||||||
|
t.Fatalf("error during connection handshake: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := string(cli.gotParams), srvParams; got != want {
|
||||||
|
t.Errorf("client got transport params: %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := string(srv.gotParams), cliParams; got != want {
|
||||||
|
t.Errorf("server got transport params: %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQUICEmptyTransportParameters(t *testing.T) {
|
||||||
|
config := testConfig.Clone()
|
||||||
|
config.MinVersion = VersionTLS13
|
||||||
|
|
||||||
|
cli := newTestQUICClient(t, config)
|
||||||
|
cli.conn.SetTransportParameters(nil)
|
||||||
|
srv := newTestQUICServer(t, config)
|
||||||
|
srv.conn.SetTransportParameters(nil)
|
||||||
|
if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
|
||||||
|
t.Fatalf("error during connection handshake: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cli.gotParams == nil {
|
||||||
|
t.Errorf("client did not get transport params")
|
||||||
|
}
|
||||||
|
if srv.gotParams == nil {
|
||||||
|
t.Errorf("server did not get transport params")
|
||||||
|
}
|
||||||
|
if len(cli.gotParams) != 0 {
|
||||||
|
t.Errorf("client got transport params: %v, want empty", cli.gotParams)
|
||||||
|
}
|
||||||
|
if len(srv.gotParams) != 0 {
|
||||||
|
t.Errorf("server got transport params: %v, want empty", srv.gotParams)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQUICCanceledWaitingForData(t *testing.T) {
|
||||||
|
config := testConfig.Clone()
|
||||||
|
config.MinVersion = VersionTLS13
|
||||||
|
cli := newTestQUICClient(t, config)
|
||||||
|
cli.conn.SetTransportParameters(nil)
|
||||||
|
cli.conn.Start(context.Background())
|
||||||
|
for cli.conn.NextEvent().Kind != QUICNoEvent {
|
||||||
|
}
|
||||||
|
err := cli.conn.Close()
|
||||||
|
if !errors.Is(err, alertCloseNotify) {
|
||||||
|
t.Errorf("conn.Close() = %v, want alertCloseNotify", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQUICCanceledWaitingForTransportParams(t *testing.T) {
|
||||||
|
config := testConfig.Clone()
|
||||||
|
config.MinVersion = VersionTLS13
|
||||||
|
cli := newTestQUICClient(t, config)
|
||||||
|
cli.conn.Start(context.Background())
|
||||||
|
for cli.conn.NextEvent().Kind != QUICTransportParametersRequired {
|
||||||
|
}
|
||||||
|
err := cli.conn.Close()
|
||||||
|
if !errors.Is(err, alertCloseNotify) {
|
||||||
|
t.Errorf("conn.Close() = %v, want alertCloseNotify", err)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue