diff --git a/common.go b/common.go index 1370d26..98b31b0 100644 --- a/common.go +++ b/common.go @@ -229,9 +229,6 @@ type ConnectionState struct { CipherSuite uint16 // NegotiatedProtocol is the application protocol negotiated with ALPN. - // - // Note that on the client side, this is currently not guaranteed to be from - // Config.NextProtos. NegotiatedProtocol string // NegotiatedProtocolIsMutual used to indicate a mutual NPN negotiation. diff --git a/conn.go b/conn.go index 2788c3c..969f357 100644 --- a/conn.go +++ b/conn.go @@ -88,8 +88,8 @@ type Conn struct { clientFinished [12]byte serverFinished [12]byte - clientProtocol string - clientProtocolFallback bool + // clientProtocol is the negotiated ALPN protocol. + clientProtocol string // input/output in, out halfConn @@ -1471,7 +1471,7 @@ func (c *Conn) connectionStateLocked() ConnectionState { state.Version = c.vers state.NegotiatedProtocol = c.clientProtocol state.DidResume = c.didResume - state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback + state.NegotiatedProtocolIsMutual = true state.ServerName = c.serverName state.CipherSuite = c.cipherSuite state.PeerCertificates = c.peerCertificates diff --git a/handshake_client.go b/handshake_client.go index 123df7b..92e33e7 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -705,18 +705,18 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) { } } - clientDidALPN := len(hs.hello.alpnProtocols) > 0 - serverHasALPN := len(hs.serverHello.alpnProtocol) > 0 - - if !clientDidALPN && serverHasALPN { - c.sendAlert(alertHandshakeFailure) - return false, errors.New("tls: server advertised unrequested ALPN extension") - } - - if serverHasALPN { + if hs.serverHello.alpnProtocol != "" { + if len(hs.hello.alpnProtocols) == 0 { + c.sendAlert(alertUnsupportedExtension) + return false, errors.New("tls: server advertised unrequested ALPN extension") + } + if mutualProtocol([]string{hs.serverHello.alpnProtocol}, hs.hello.alpnProtocols) == "" { + c.sendAlert(alertUnsupportedExtension) + return false, errors.New("tls: server selected unadvertised ALPN protocol") + } c.clientProtocol = hs.serverHello.alpnProtocol - c.clientProtocolFallback = false } + c.scts = hs.serverHello.scts if !hs.serverResumedSession() { @@ -973,20 +973,17 @@ func clientSessionCacheKey(serverAddr net.Addr, config *Config) string { return serverAddr.String() } -// mutualProtocol finds the mutual Next Protocol Negotiation or ALPN protocol -// given list of possible protocols and a list of the preference order. The -// first list must not be empty. It returns the resulting protocol and flag -// indicating if the fallback case was reached. -func mutualProtocol(protos, preferenceProtos []string) (string, bool) { +// mutualProtocol finds the mutual ALPN protocol given list of possible +// protocols and a list of the preference order. +func mutualProtocol(protos, preferenceProtos []string) string { for _, s := range preferenceProtos { for _, c := range protos { if s == c { - return s, false + return s } } } - - return protos[0], true + return "" } // hostnameInSNI converts name into an appropriate hostname for SNI. diff --git a/handshake_client_tls13.go b/handshake_client_tls13.go index 0e4b380..be37c68 100644 --- a/handshake_client_tls13.go +++ b/handshake_client_tls13.go @@ -396,11 +396,17 @@ func (hs *clientHandshakeStateTLS13) readServerParameters() error { } hs.transcript.Write(encryptedExtensions.marshal()) - if len(encryptedExtensions.alpnProtocol) != 0 && len(hs.hello.alpnProtocols) == 0 { - c.sendAlert(alertUnsupportedExtension) - return errors.New("tls: server advertised unrequested ALPN extension") + if encryptedExtensions.alpnProtocol != "" { + if len(hs.hello.alpnProtocols) == 0 { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: server advertised unrequested ALPN extension") + } + if mutualProtocol([]string{encryptedExtensions.alpnProtocol}, hs.hello.alpnProtocols) == "" { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: server selected unadvertised ALPN protocol") + } + c.clientProtocol = encryptedExtensions.alpnProtocol } - c.clientProtocol = encryptedExtensions.alpnProtocol return nil } diff --git a/handshake_server.go b/handshake_server.go index 73df19d..a7d4414 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -218,7 +218,7 @@ func (hs *serverHandshakeState) processClientHello() error { } if len(hs.clientHello.alpnProtocols) > 0 { - if selectedProto, fallback := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); !fallback { + if selectedProto := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); selectedProto != "" { hs.hello.alpnProtocol = selectedProto c.clientProtocol = selectedProto } diff --git a/handshake_server_tls13.go b/handshake_server_tls13.go index 25c37b9..41f7ac2 100644 --- a/handshake_server_tls13.go +++ b/handshake_server_tls13.go @@ -555,7 +555,7 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error { encryptedExtensions := new(encryptedExtensionsMsg) if len(hs.clientHello.alpnProtocols) > 0 { - if selectedProto, fallback := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); !fallback { + if selectedProto := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); selectedProto != "" { encryptedExtensions.alpnProtocol = selectedProto c.clientProtocol = selectedProto }