From afc9b11715df1c57f61d75f6f6147148e492f9e6 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 24 May 2017 01:04:19 +0800 Subject: [PATCH] use a prefix to distinguish IPs and net.Addrs in source address tokens --- handshake/crypto_setup_server_test.go | 4 ++-- handshake/stk_generator.go | 22 +++++++++++----------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/handshake/crypto_setup_server_test.go b/handshake/crypto_setup_server_test.go index bd65e732..764f262b 100644 --- a/handshake/crypto_setup_server_test.go +++ b/handshake/crypto_setup_server_test.go @@ -173,8 +173,6 @@ var _ = Describe("Server Crypto Setup", func() { BeforeEach(func() { var err error remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234} - validSTK, err = mockStkSource{}.NewToken(remoteAddr.IP) - Expect(err).NotTo(HaveOccurred()) expectedInitialNonceLen = 32 expectedFSNonceLen = 64 aeadChanged = make(chan protocol.EncryptionLevel, 2) @@ -206,6 +204,8 @@ var _ = Describe("Server Crypto Setup", func() { Expect(err).NotTo(HaveOccurred()) cs = csInt.(*cryptoSetupServer) cs.stkGenerator.stkSource = &mockStkSource{} + validSTK, err = cs.stkGenerator.NewToken(remoteAddr) + Expect(err).NotTo(HaveOccurred()) sourceAddrValid = true cs.acceptSTKCallback = func(_ net.Addr, _ *STK) bool { return sourceAddrValid } cs.keyDerivation = mockKeyDerivation diff --git a/handshake/stk_generator.go b/handshake/stk_generator.go index ea71b12e..cf803082 100644 --- a/handshake/stk_generator.go +++ b/handshake/stk_generator.go @@ -1,13 +1,17 @@ package handshake import ( - "bytes" "net" "time" "github.com/lucas-clemente/quic-go/crypto" ) +const ( + stkPrefixIP byte = iota + stkPrefixString +) + // An STK is a source address token type STK struct { RemoteAddr string @@ -52,21 +56,17 @@ func (g *STKGenerator) DecodeToken(data []byte) (*STK, error) { } // encodeRemoteAddr encodes a remote address such that it can be saved in the STK -// it ensures that we're binary compatible with Google's implementation of STKs func encodeRemoteAddr(remoteAddr net.Addr) []byte { - // if the address is a UDP address, just use the byte representation of the IP address - // the length of an IP address is 4 bytes (for IPv4) or 16 bytes (for IPv6) if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok { - return udpAddr.IP + return append([]byte{stkPrefixIP}, udpAddr.IP...) } - // if the address is not a UDP address, prepend 16 bytes - // that way it can be distinguished from an IP address - return append(bytes.Repeat([]byte{0}, 16), []byte(remoteAddr.String())...) + return append([]byte{stkPrefixString}, []byte(remoteAddr.String())...) } +// decodeRemoteAddr decodes the remote address saved in the STK func decodeRemoteAddr(data []byte) string { - if len(data) <= 16 { - return net.IP(data).String() + if data[0] == stkPrefixIP { + return net.IP(data[1:]).String() } - return string(data[16:]) + return string(data[1:]) }