crypto/tls: drop macFunction abstraction

Since we dropped SSLv3, there is only one MAC scheme, and it doesn't
need any state beyond a keyed HMAC, so we can replace the macFunction
with the hash.Hash it wraps.

Pointed out by mtp@.

Change-Id: I5545be0e6ccb34a3055fad7f6cb5f628ff748e9f
Reviewed-on: https://go-review.googlesource.com/c/go/+/251859
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Trust: Roland Shoemaker <roland@golang.org>
Trust: Filippo Valsorda <filippo@golang.org>
This commit is contained in:
Filippo Valsorda 2020-08-31 17:09:57 -04:00
parent 3b66a0b37e
commit 74ff83e750
4 changed files with 45 additions and 66 deletions

View file

@ -160,7 +160,7 @@ type cipherSuite struct {
// flags is a bitmask of the suite* values, above. // flags is a bitmask of the suite* values, above.
flags int flags int
cipher func(key, iv []byte, isRead bool) interface{} cipher func(key, iv []byte, isRead bool) interface{}
mac func(version uint16, macKey []byte) macFunction mac func(key []byte) hash.Hash
aead func(key, fixedNonce []byte) aead aead func(key, fixedNonce []byte) aead
} }
@ -247,24 +247,15 @@ func cipherAES(key, iv []byte, isRead bool) interface{} {
return cipher.NewCBCEncrypter(block, iv) return cipher.NewCBCEncrypter(block, iv)
} }
// macSHA1 returns a macFunction for the given protocol version. // macSHA1 returns a SHA-1 based constant time MAC.
func macSHA1(version uint16, key []byte) macFunction { func macSHA1(key []byte) hash.Hash {
return tls10MAC{h: hmac.New(newConstantTimeHash(sha1.New), key)} return hmac.New(newConstantTimeHash(sha1.New), key)
} }
// macSHA256 returns a SHA-256 based MAC. These are only supported in TLS 1.2 // macSHA256 returns a SHA-256 based MAC. This is only supported in TLS 1.2 and
// so the given version is ignored. // is currently only used in disabled-by-default cipher suites.
func macSHA256(version uint16, key []byte) macFunction { func macSHA256(key []byte) hash.Hash {
return tls10MAC{h: hmac.New(sha256.New, key)} return hmac.New(sha256.New, key)
}
type macFunction interface {
// Size returns the length of the MAC.
Size() int
// MAC appends the MAC of (seq, header, data) to out. The extra data is fed
// into the MAC after obtaining the result to normalize timing. The result
// is only valid until the next invocation of MAC as the buffer is reused.
MAC(seq, header, data, extra []byte) []byte
} }
type aead interface { type aead interface {
@ -412,26 +403,14 @@ func newConstantTimeHash(h func() hash.Hash) func() hash.Hash {
} }
// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3. // tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3.
type tls10MAC struct { func tls10MAC(h hash.Hash, out, seq, header, data, extra []byte) []byte {
h hash.Hash h.Reset()
buf []byte h.Write(seq)
} h.Write(header)
h.Write(data)
func (s tls10MAC) Size() int { res := h.Sum(out)
return s.h.Size()
}
// MAC is guaranteed to take constant time, as long as
// len(seq)+len(header)+len(data)+len(extra) is constant. extra is not fed into
// the MAC, but is only provided to make the timing profile constant.
func (s tls10MAC) MAC(seq, header, data, extra []byte) []byte {
s.h.Reset()
s.h.Write(seq)
s.h.Write(header)
s.h.Write(data)
res := s.h.Sum(s.buf[:0])
if extra != nil { if extra != nil {
s.h.Write(extra) h.Write(extra)
} }
return res return res
} }

46
conn.go
View file

@ -14,6 +14,7 @@ import (
"crypto/x509" "crypto/x509"
"errors" "errors"
"fmt" "fmt"
"hash"
"io" "io"
"net" "net"
"sync" "sync"
@ -155,15 +156,16 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
type halfConn struct { type halfConn struct {
sync.Mutex sync.Mutex
err error // first permanent error err error // first permanent error
version uint16 // protocol version version uint16 // protocol version
cipher interface{} // cipher algorithm cipher interface{} // cipher algorithm
mac macFunction mac hash.Hash
seq [8]byte // 64-bit sequence number seq [8]byte // 64-bit sequence number
additionalData [13]byte // to avoid allocs; interface method args escape
scratchBuf [13]byte // to avoid allocs; interface method args escape
nextCipher interface{} // next encryption state nextCipher interface{} // next encryption state
nextMac macFunction // next MAC algorithm nextMac hash.Hash // next MAC algorithm
trafficSecret []byte // current TLS 1.3 traffic secret trafficSecret []byte // current TLS 1.3 traffic secret
} }
@ -188,7 +190,7 @@ func (hc *halfConn) setErrorLocked(err error) error {
// prepareCipherSpec sets the encryption and MAC states // prepareCipherSpec sets the encryption and MAC states
// that a subsequent changeCipherSpec will use. // that a subsequent changeCipherSpec will use.
func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac macFunction) { func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac hash.Hash) {
hc.version = version hc.version = version
hc.nextCipher = cipher hc.nextCipher = cipher
hc.nextMac = mac hc.nextMac = mac
@ -350,15 +352,14 @@ func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
} }
payload = payload[explicitNonceLen:] payload = payload[explicitNonceLen:]
additionalData := hc.additionalData[:] var additionalData []byte
if hc.version == VersionTLS13 { if hc.version == VersionTLS13 {
additionalData = record[:recordHeaderLen] additionalData = record[:recordHeaderLen]
} else { } else {
copy(additionalData, hc.seq[:]) additionalData = append(hc.scratchBuf[:0], hc.seq[:]...)
copy(additionalData[8:], record[:3]) additionalData = append(additionalData, record[:3]...)
n := len(payload) - c.Overhead() n := len(payload) - c.Overhead()
additionalData[11] = byte(n >> 8) additionalData = append(additionalData, byte(n>>8), byte(n))
additionalData[12] = byte(n)
} }
var err error var err error
@ -424,7 +425,7 @@ func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
record[3] = byte(n >> 8) record[3] = byte(n >> 8)
record[4] = byte(n) record[4] = byte(n)
remoteMAC := payload[n : n+macSize] remoteMAC := payload[n : n+macSize]
localMAC := hc.mac.MAC(hc.seq[0:], record[:recordHeaderLen], payload[:n], payload[n+macSize:]) localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])
// This is equivalent to checking the MACs and paddingGood // This is equivalent to checking the MACs and paddingGood
// separately, but in constant-time to prevent distinguishing // separately, but in constant-time to prevent distinguishing
@ -460,7 +461,7 @@ func sliceForAppend(in []byte, n int) (head, tail []byte) {
} }
// encrypt encrypts payload, adding the appropriate nonce and/or MAC, and // encrypt encrypts payload, adding the appropriate nonce and/or MAC, and
// appends it to record, which contains the record header. // appends it to record, which must already contain the record header.
func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) { func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
if hc.cipher == nil { if hc.cipher == nil {
return append(record, payload...), nil return append(record, payload...), nil
@ -477,7 +478,7 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err
// an 8 bytes nonce but its nonces must be unpredictable (see RFC // an 8 bytes nonce but its nonces must be unpredictable (see RFC
// 5246, Appendix F.3), forcing us to use randomness. That's not // 5246, Appendix F.3), forcing us to use randomness. That's not
// 3DES' biggest problem anyway because the birthday bound on block // 3DES' biggest problem anyway because the birthday bound on block
// collision is reached first due to its simlarly small block size // collision is reached first due to its similarly small block size
// (see the Sweet32 attack). // (see the Sweet32 attack).
copy(explicitNonce, hc.seq[:]) copy(explicitNonce, hc.seq[:])
} else { } else {
@ -487,14 +488,10 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err
} }
} }
var mac []byte
if hc.mac != nil {
mac = hc.mac.MAC(hc.seq[:], record[:recordHeaderLen], payload, nil)
}
var dst []byte var dst []byte
switch c := hc.cipher.(type) { switch c := hc.cipher.(type) {
case cipher.Stream: case cipher.Stream:
mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
record, dst = sliceForAppend(record, len(payload)+len(mac)) record, dst = sliceForAppend(record, len(payload)+len(mac))
c.XORKeyStream(dst[:len(payload)], payload) c.XORKeyStream(dst[:len(payload)], payload)
c.XORKeyStream(dst[len(payload):], mac) c.XORKeyStream(dst[len(payload):], mac)
@ -518,11 +515,12 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err
record = c.Seal(record[:recordHeaderLen], record = c.Seal(record[:recordHeaderLen],
nonce, record[recordHeaderLen:], record[:recordHeaderLen]) nonce, record[recordHeaderLen:], record[:recordHeaderLen])
} else { } else {
copy(hc.additionalData[:], hc.seq[:]) additionalData := append(hc.scratchBuf[:0], hc.seq[:]...)
copy(hc.additionalData[8:], record) additionalData = append(additionalData, record[:recordHeaderLen]...)
record = c.Seal(record, nonce, payload, hc.additionalData[:]) record = c.Seal(record, nonce, payload, additionalData)
} }
case cbcMode: case cbcMode:
mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
blockSize := c.BlockSize() blockSize := c.BlockSize()
plaintextLen := len(payload) + len(mac) plaintextLen := len(payload) + len(mac)
paddingLen := blockSize - plaintextLen%blockSize paddingLen := blockSize - plaintextLen%blockSize

View file

@ -15,6 +15,7 @@ import (
"crypto/x509" "crypto/x509"
"errors" "errors"
"fmt" "fmt"
"hash"
"io" "io"
"net" "net"
"strings" "strings"
@ -651,12 +652,12 @@ func (hs *clientHandshakeState) establishKeys() error {
clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen)
var clientCipher, serverCipher interface{} var clientCipher, serverCipher interface{}
var clientHash, serverHash macFunction var clientHash, serverHash hash.Hash
if hs.suite.cipher != nil { if hs.suite.cipher != nil {
clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */) clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */)
clientHash = hs.suite.mac(c.vers, clientMAC) clientHash = hs.suite.mac(clientMAC)
serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */) serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */)
serverHash = hs.suite.mac(c.vers, serverMAC) serverHash = hs.suite.mac(serverMAC)
} else { } else {
clientCipher = hs.suite.aead(clientKey, clientIV) clientCipher = hs.suite.aead(clientKey, clientIV)
serverCipher = hs.suite.aead(serverKey, serverIV) serverCipher = hs.suite.aead(serverKey, serverIV)

View file

@ -14,6 +14,7 @@ import (
"crypto/x509" "crypto/x509"
"errors" "errors"
"fmt" "fmt"
"hash"
"io" "io"
"sync/atomic" "sync/atomic"
"time" "time"
@ -645,13 +646,13 @@ func (hs *serverHandshakeState) establishKeys() error {
keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen)
var clientCipher, serverCipher interface{} var clientCipher, serverCipher interface{}
var clientHash, serverHash macFunction var clientHash, serverHash hash.Hash
if hs.suite.aead == nil { if hs.suite.aead == nil {
clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */) clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */)
clientHash = hs.suite.mac(c.vers, clientMAC) clientHash = hs.suite.mac(clientMAC)
serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */) serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */)
serverHash = hs.suite.mac(c.vers, serverMAC) serverHash = hs.suite.mac(serverMAC)
} else { } else {
clientCipher = hs.suite.aead(clientKey, clientIV) clientCipher = hs.suite.aead(clientKey, clientIV)
serverCipher = hs.suite.aead(serverKey, serverIV) serverCipher = hs.suite.aead(serverKey, serverIV)