use a prefix to distinguish IPs and net.Addrs in source address tokens

This commit is contained in:
Marten Seemann 2017-05-24 01:04:19 +08:00
parent 87df63dd5f
commit afc9b11715
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
2 changed files with 13 additions and 13 deletions

View file

@ -173,8 +173,6 @@ var _ = Describe("Server Crypto Setup", func() {
BeforeEach(func() {
var err error
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
validSTK, err = mockStkSource{}.NewToken(remoteAddr.IP)
Expect(err).NotTo(HaveOccurred())
expectedInitialNonceLen = 32
expectedFSNonceLen = 64
aeadChanged = make(chan protocol.EncryptionLevel, 2)
@ -206,6 +204,8 @@ var _ = Describe("Server Crypto Setup", func() {
Expect(err).NotTo(HaveOccurred())
cs = csInt.(*cryptoSetupServer)
cs.stkGenerator.stkSource = &mockStkSource{}
validSTK, err = cs.stkGenerator.NewToken(remoteAddr)
Expect(err).NotTo(HaveOccurred())
sourceAddrValid = true
cs.acceptSTKCallback = func(_ net.Addr, _ *STK) bool { return sourceAddrValid }
cs.keyDerivation = mockKeyDerivation

View file

@ -1,13 +1,17 @@
package handshake
import (
"bytes"
"net"
"time"
"github.com/lucas-clemente/quic-go/crypto"
)
const (
stkPrefixIP byte = iota
stkPrefixString
)
// An STK is a source address token
type STK struct {
RemoteAddr string
@ -52,21 +56,17 @@ func (g *STKGenerator) DecodeToken(data []byte) (*STK, error) {
}
// encodeRemoteAddr encodes a remote address such that it can be saved in the STK
// it ensures that we're binary compatible with Google's implementation of STKs
func encodeRemoteAddr(remoteAddr net.Addr) []byte {
// if the address is a UDP address, just use the byte representation of the IP address
// the length of an IP address is 4 bytes (for IPv4) or 16 bytes (for IPv6)
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
return udpAddr.IP
return append([]byte{stkPrefixIP}, udpAddr.IP...)
}
// if the address is not a UDP address, prepend 16 bytes
// that way it can be distinguished from an IP address
return append(bytes.Repeat([]byte{0}, 16), []byte(remoteAddr.String())...)
return append([]byte{stkPrefixString}, []byte(remoteAddr.String())...)
}
// decodeRemoteAddr decodes the remote address saved in the STK
func decodeRemoteAddr(data []byte) string {
if len(data) <= 16 {
return net.IP(data).String()
if data[0] == stkPrefixIP {
return net.IP(data[1:]).String()
}
return string(data[16:])
return string(data[1:])
}