crypto/tls: ensure the server picked an advertised ALPN protocol

This is a SHALL in RFC 7301, Section 3.2.

Also some more cleanup after NPN, which worked the other way around
(with the possibility that the client could pick a protocol the server
did not suggest).

Change-Id: I83cc43ca1b3c686dfece8315436441c077065d82
Reviewed-on: https://go-review.googlesource.com/c/go/+/239748
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Filippo Valsorda <filippo@golang.org>
Trust: Roland Shoemaker <roland@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
This commit is contained in:
Filippo Valsorda 2020-06-24 17:01:00 -04:00
parent 74ff83e750
commit 3e0f07eb2d
6 changed files with 30 additions and 30 deletions

View file

@ -229,9 +229,6 @@ type ConnectionState struct {
CipherSuite uint16 CipherSuite uint16
// NegotiatedProtocol is the application protocol negotiated with ALPN. // NegotiatedProtocol is the application protocol negotiated with ALPN.
//
// Note that on the client side, this is currently not guaranteed to be from
// Config.NextProtos.
NegotiatedProtocol string NegotiatedProtocol string
// NegotiatedProtocolIsMutual used to indicate a mutual NPN negotiation. // NegotiatedProtocolIsMutual used to indicate a mutual NPN negotiation.

View file

@ -88,8 +88,8 @@ type Conn struct {
clientFinished [12]byte clientFinished [12]byte
serverFinished [12]byte serverFinished [12]byte
clientProtocol string // clientProtocol is the negotiated ALPN protocol.
clientProtocolFallback bool clientProtocol string
// input/output // input/output
in, out halfConn in, out halfConn
@ -1471,7 +1471,7 @@ func (c *Conn) connectionStateLocked() ConnectionState {
state.Version = c.vers state.Version = c.vers
state.NegotiatedProtocol = c.clientProtocol state.NegotiatedProtocol = c.clientProtocol
state.DidResume = c.didResume state.DidResume = c.didResume
state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback state.NegotiatedProtocolIsMutual = true
state.ServerName = c.serverName state.ServerName = c.serverName
state.CipherSuite = c.cipherSuite state.CipherSuite = c.cipherSuite
state.PeerCertificates = c.peerCertificates state.PeerCertificates = c.peerCertificates

View file

@ -705,18 +705,18 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) {
} }
} }
clientDidALPN := len(hs.hello.alpnProtocols) > 0 if hs.serverHello.alpnProtocol != "" {
serverHasALPN := len(hs.serverHello.alpnProtocol) > 0 if len(hs.hello.alpnProtocols) == 0 {
c.sendAlert(alertUnsupportedExtension)
if !clientDidALPN && serverHasALPN { return false, errors.New("tls: server advertised unrequested ALPN extension")
c.sendAlert(alertHandshakeFailure) }
return false, errors.New("tls: server advertised unrequested ALPN extension") if mutualProtocol([]string{hs.serverHello.alpnProtocol}, hs.hello.alpnProtocols) == "" {
} c.sendAlert(alertUnsupportedExtension)
return false, errors.New("tls: server selected unadvertised ALPN protocol")
if serverHasALPN { }
c.clientProtocol = hs.serverHello.alpnProtocol c.clientProtocol = hs.serverHello.alpnProtocol
c.clientProtocolFallback = false
} }
c.scts = hs.serverHello.scts c.scts = hs.serverHello.scts
if !hs.serverResumedSession() { if !hs.serverResumedSession() {
@ -973,20 +973,17 @@ func clientSessionCacheKey(serverAddr net.Addr, config *Config) string {
return serverAddr.String() return serverAddr.String()
} }
// mutualProtocol finds the mutual Next Protocol Negotiation or ALPN protocol // mutualProtocol finds the mutual ALPN protocol given list of possible
// given list of possible protocols and a list of the preference order. The // protocols and a list of the preference order.
// first list must not be empty. It returns the resulting protocol and flag func mutualProtocol(protos, preferenceProtos []string) string {
// indicating if the fallback case was reached.
func mutualProtocol(protos, preferenceProtos []string) (string, bool) {
for _, s := range preferenceProtos { for _, s := range preferenceProtos {
for _, c := range protos { for _, c := range protos {
if s == c { if s == c {
return s, false return s
} }
} }
} }
return ""
return protos[0], true
} }
// hostnameInSNI converts name into an appropriate hostname for SNI. // hostnameInSNI converts name into an appropriate hostname for SNI.

View file

@ -396,11 +396,17 @@ func (hs *clientHandshakeStateTLS13) readServerParameters() error {
} }
hs.transcript.Write(encryptedExtensions.marshal()) hs.transcript.Write(encryptedExtensions.marshal())
if len(encryptedExtensions.alpnProtocol) != 0 && len(hs.hello.alpnProtocols) == 0 { if encryptedExtensions.alpnProtocol != "" {
c.sendAlert(alertUnsupportedExtension) if len(hs.hello.alpnProtocols) == 0 {
return errors.New("tls: server advertised unrequested ALPN extension") c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: server advertised unrequested ALPN extension")
}
if mutualProtocol([]string{encryptedExtensions.alpnProtocol}, hs.hello.alpnProtocols) == "" {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: server selected unadvertised ALPN protocol")
}
c.clientProtocol = encryptedExtensions.alpnProtocol
} }
c.clientProtocol = encryptedExtensions.alpnProtocol
return nil return nil
} }

View file

@ -218,7 +218,7 @@ func (hs *serverHandshakeState) processClientHello() error {
} }
if len(hs.clientHello.alpnProtocols) > 0 { if len(hs.clientHello.alpnProtocols) > 0 {
if selectedProto, fallback := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); !fallback { if selectedProto := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); selectedProto != "" {
hs.hello.alpnProtocol = selectedProto hs.hello.alpnProtocol = selectedProto
c.clientProtocol = selectedProto c.clientProtocol = selectedProto
} }

View file

@ -555,7 +555,7 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
encryptedExtensions := new(encryptedExtensionsMsg) encryptedExtensions := new(encryptedExtensionsMsg)
if len(hs.clientHello.alpnProtocols) > 0 { if len(hs.clientHello.alpnProtocols) > 0 {
if selectedProto, fallback := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); !fallback { if selectedProto := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); selectedProto != "" {
encryptedExtensions.alpnProtocol = selectedProto encryptedExtensions.alpnProtocol = selectedProto
c.clientProtocol = selectedProto c.clientProtocol = selectedProto
} }