rename the STKGenerator to CookieGenerator

This commit is contained in:
Marten Seemann 2017-09-11 18:10:58 +02:00
parent c78a4b2b73
commit 14fae7b6d3
7 changed files with 150 additions and 155 deletions

View file

@ -17,7 +17,7 @@ type StreamID = protocol.StreamID
type VersionNumber = protocol.VersionNumber
// An STK can be used to verify the ownership of the client address.
type STK = handshake.STK
type STK = handshake.Cookie
// Stream is the interface implemented by QUIC streams
type Stream interface {

View file

@ -0,0 +1,101 @@
package handshake
import (
"encoding/asn1"
"fmt"
"net"
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
)
const (
cookiePrefixIP byte = iota
cookiePrefixString
)
// A Cookie is derived from the client address and can be used to verify the ownership of this address.
type Cookie struct {
RemoteAddr string
// The time that the STK was issued (resolution 1 second)
SentTime time.Time
}
// token is the struct that is used for ASN1 serialization and deserialization
type token struct {
Data []byte
Timestamp int64
}
// A CookieGenerator generates Cookies
type CookieGenerator struct {
cookieSource crypto.StkSource
}
// NewCookieGenerator initializes a new CookieGenerator
func NewCookieGenerator() (*CookieGenerator, error) {
stkSource, err := crypto.NewStkSource()
if err != nil {
return nil, err
}
return &CookieGenerator{
cookieSource: stkSource,
}, nil
}
// NewToken generates a new Cookie for a given source address
func (g *CookieGenerator) NewToken(raddr net.Addr) ([]byte, error) {
data, err := asn1.Marshal(token{
Data: encodeRemoteAddr(raddr),
Timestamp: time.Now().Unix(),
})
if err != nil {
return nil, err
}
return g.cookieSource.NewToken(data)
}
// DecodeToken decodes a Cookie
func (g *CookieGenerator) DecodeToken(encrypted []byte) (*Cookie, error) {
// if the client didn't send any Cookie, DecodeToken will be called with a nil-slice
if len(encrypted) == 0 {
return nil, nil
}
data, err := g.cookieSource.DecodeToken(encrypted)
if err != nil {
return nil, err
}
t := &token{}
rest, err := asn1.Unmarshal(data, t)
if err != nil {
return nil, err
}
if len(rest) != 0 {
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
}
return &Cookie{
RemoteAddr: decodeRemoteAddr(t.Data),
SentTime: time.Unix(t.Timestamp, 0),
}, nil
}
// encodeRemoteAddr encodes a remote address such that it can be saved in the Cookie
func encodeRemoteAddr(remoteAddr net.Addr) []byte {
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
return append([]byte{cookiePrefixIP}, udpAddr.IP...)
}
return append([]byte{cookiePrefixString}, []byte(remoteAddr.String())...)
}
// decodeRemoteAddr decodes the remote address saved in the Cookie
func decodeRemoteAddr(data []byte) string {
// data will never be empty for a Cookie that we generated. Check it to be on the safe side
if len(data) == 0 {
return ""
}
if data[0] == cookiePrefixIP {
return net.IP(data[1:]).String()
}
return string(data[1:])
}

View file

@ -9,49 +9,49 @@ import (
. "github.com/onsi/gomega"
)
var _ = Describe("STK Generator", func() {
var stkGen *STKGenerator
var _ = Describe("Cookie Generator", func() {
var cookieGen *CookieGenerator
BeforeEach(func() {
var err error
stkGen, err = NewSTKGenerator()
cookieGen, err = NewCookieGenerator()
Expect(err).ToNot(HaveOccurred())
})
It("generates an STK", func() {
It("generates a Cookie", func() {
ip := net.IPv4(127, 0, 0, 1)
token, err := stkGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337})
token, err := cookieGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337})
Expect(err).ToNot(HaveOccurred())
Expect(token).ToNot(BeEmpty())
})
It("works with nil tokens", func() {
stk, err := stkGen.DecodeToken(nil)
cookie, err := cookieGen.DecodeToken(nil)
Expect(err).ToNot(HaveOccurred())
Expect(stk).To(BeNil())
Expect(cookie).To(BeNil())
})
It("accepts a valid STK", func() {
It("accepts a valid cookie", func() {
ip := net.IPv4(192, 168, 0, 1)
token, err := stkGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337})
token, err := cookieGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337})
Expect(err).ToNot(HaveOccurred())
stk, err := stkGen.DecodeToken(token)
cookie, err := cookieGen.DecodeToken(token)
Expect(err).ToNot(HaveOccurred())
Expect(stk.RemoteAddr).To(Equal("192.168.0.1"))
// the time resolution of the STK is just 1 second
// if STK generation and this check happen in "different seconds", the difference will be between 1 and 2 seconds
Expect(stk.SentTime).To(BeTemporally("~", time.Now(), 2*time.Second))
Expect(cookie.RemoteAddr).To(Equal("192.168.0.1"))
// the time resolution of the Cookie is just 1 second
// if Cookie generation and this check happen in "different seconds", the difference will be between 1 and 2 seconds
Expect(cookie.SentTime).To(BeTemporally("~", time.Now(), 2*time.Second))
})
It("rejects invalid tokens", func() {
_, err := stkGen.DecodeToken([]byte("invalid token"))
_, err := cookieGen.DecodeToken([]byte("invalid token"))
Expect(err).To(HaveOccurred())
})
It("rejects tokens that cannot be decoded", func() {
token, err := stkGen.stkSource.NewToken([]byte("foobar"))
token, err := cookieGen.cookieSource.NewToken([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
_, err = stkGen.DecodeToken(token)
_, err = cookieGen.DecodeToken(token)
Expect(err).To(HaveOccurred())
})
@ -59,9 +59,9 @@ var _ = Describe("STK Generator", func() {
t, err := asn1.Marshal(token{Data: []byte("foobar")})
Expect(err).ToNot(HaveOccurred())
t = append(t, []byte("rest")...)
enc, err := stkGen.stkSource.NewToken(t)
enc, err := cookieGen.cookieSource.NewToken(t)
Expect(err).ToNot(HaveOccurred())
_, err = stkGen.DecodeToken(enc)
_, err = cookieGen.DecodeToken(enc)
Expect(err).To(MatchError("rest when unpacking token: 4"))
})
@ -69,9 +69,9 @@ var _ = Describe("STK Generator", func() {
It("doesn't panic if a tokens has no data", func() {
t, err := asn1.Marshal(token{Data: []byte("")})
Expect(err).ToNot(HaveOccurred())
enc, err := stkGen.stkSource.NewToken(t)
enc, err := cookieGen.cookieSource.NewToken(t)
Expect(err).ToNot(HaveOccurred())
_, err = stkGen.DecodeToken(enc)
_, err = cookieGen.DecodeToken(enc)
Expect(err).ToNot(HaveOccurred())
})
@ -86,26 +86,26 @@ var _ = Describe("STK Generator", func() {
ip := net.ParseIP(addr)
Expect(ip).ToNot(BeNil())
raddr := &net.UDPAddr{IP: ip, Port: 1337}
token, err := stkGen.NewToken(raddr)
token, err := cookieGen.NewToken(raddr)
Expect(err).ToNot(HaveOccurred())
stk, err := stkGen.DecodeToken(token)
cookie, err := cookieGen.DecodeToken(token)
Expect(err).ToNot(HaveOccurred())
Expect(stk.RemoteAddr).To(Equal(ip.String()))
// the time resolution of the STK is just 1 second
// if STK generation and this check happen in "different seconds", the difference will be between 1 and 2 seconds
Expect(stk.SentTime).To(BeTemporally("~", time.Now(), 2*time.Second))
Expect(cookie.RemoteAddr).To(Equal(ip.String()))
// the time resolution of the Cookie is just 1 second
// if Cookie generation and this check happen in "different seconds", the difference will be between 1 and 2 seconds
Expect(cookie.SentTime).To(BeTemporally("~", time.Now(), 2*time.Second))
}
})
It("uses the string representation an address that is not a UDP address", func() {
raddr := &net.TCPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337}
token, err := stkGen.NewToken(raddr)
token, err := cookieGen.NewToken(raddr)
Expect(err).ToNot(HaveOccurred())
stk, err := stkGen.DecodeToken(token)
cookie, err := cookieGen.DecodeToken(token)
Expect(err).ToNot(HaveOccurred())
Expect(stk.RemoteAddr).To(Equal("192.168.13.37:1337"))
// the time resolution of the STK is just 1 second
// if STK generation and this check happen in "different seconds", the difference will be between 1 and 2 seconds
Expect(stk.SentTime).To(BeTemporally("~", time.Now(), 2*time.Second))
Expect(cookie.RemoteAddr).To(Equal("192.168.13.37:1337"))
// the time resolution of the Cookie is just 1 second
// if Cookie generation and this check happen in "different seconds", the difference will be between 1 and 2 seconds
Expect(cookie.SentTime).To(BeTemporally("~", time.Now(), 2*time.Second))
})
})

View file

@ -26,13 +26,13 @@ type cryptoSetupServer struct {
connID protocol.ConnectionID
remoteAddr net.Addr
scfg *ServerConfig
stkGenerator *STKGenerator
stkGenerator *CookieGenerator
diversificationNonce []byte
version protocol.VersionNumber
supportedVersions []protocol.VersionNumber
acceptSTKCallback func(net.Addr, *STK) bool
acceptSTKCallback func(net.Addr, *Cookie) bool
nullAEAD crypto.AEAD
secureAEAD crypto.AEAD
@ -72,10 +72,10 @@ func NewCryptoSetup(
cryptoStream io.ReadWriter,
connectionParametersManager ConnectionParametersManager,
supportedVersions []protocol.VersionNumber,
acceptSTK func(net.Addr, *STK) bool,
acceptSTK func(net.Addr, *Cookie) bool,
aeadChanged chan<- protocol.EncryptionLevel,
) (CryptoSetup, error) {
stkGenerator, err := NewSTKGenerator()
stkGenerator, err := NewCookieGenerator()
if err != nil {
return nil, err
}

View file

@ -131,18 +131,18 @@ func (s *mockStream) Reset(error) { panic("not implemente
func (mockStream) CloseRemote(offset protocol.ByteCount) { panic("not implemented") }
func (s mockStream) StreamID() protocol.StreamID { panic("not implemented") }
type mockStkSource struct {
type mockCookieSource struct {
data []byte
decodeErr error
}
var _ crypto.StkSource = &mockStkSource{}
var _ crypto.StkSource = &mockCookieSource{}
func (mockStkSource) NewToken(sourceAddr []byte) ([]byte, error) {
func (mockCookieSource) NewToken(sourceAddr []byte) ([]byte, error) {
return append([]byte("token "), sourceAddr...), nil
}
func (s mockStkSource) DecodeToken(data []byte) ([]byte, error) {
func (s mockCookieSource) DecodeToken(data []byte) ([]byte, error) {
if s.decodeErr != nil {
return nil, s.decodeErr
}
@ -209,11 +209,11 @@ var _ = Describe("Server Crypto Setup", func() {
)
Expect(err).NotTo(HaveOccurred())
cs = csInt.(*cryptoSetupServer)
cs.stkGenerator.stkSource = &mockStkSource{}
cs.stkGenerator.cookieSource = &mockCookieSource{}
validSTK, err = cs.stkGenerator.NewToken(remoteAddr)
Expect(err).NotTo(HaveOccurred())
sourceAddrValid = true
cs.acceptSTKCallback = func(_ net.Addr, _ *STK) bool { return sourceAddrValid }
cs.acceptSTKCallback = func(_ net.Addr, _ *Cookie) bool { return sourceAddrValid }
cs.keyDerivation = mockQuicCryptoKeyDerivation
cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} }
})
@ -422,7 +422,7 @@ var _ = Describe("Server Crypto Setup", func() {
It("recognizes inchoate CHLOs with an invalid STK", func() {
testErr := errors.New("STK invalid")
cs.stkGenerator.stkSource.(*mockStkSource).decodeErr = testErr
cs.stkGenerator.cookieSource.(*mockCookieSource).decodeErr = testErr
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
})

View file

@ -1,106 +0,0 @@
package handshake
import (
"encoding/asn1"
"fmt"
"net"
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
)
const (
stkPrefixIP byte = iota
stkPrefixString
)
// An STK is a Source Address token.
// It is issued by the server and sent to the client. For the client, it is an opaque blob.
// The client can send the STK in subsequent handshakes to prove ownership of its IP address.
type STK struct {
// The remote address this token was issued for.
// If the server is run on a net.UDPConn, this is the string representation of the IP address (net.IP.String())
// Otherwise, this is the string representation of the net.Addr (net.Addr.String())
RemoteAddr string
// The time that the STK was issued (resolution 1 second)
SentTime time.Time
}
// token is the struct that is used for ASN1 serialization and deserialization
type token struct {
Data []byte
Timestamp int64
}
// An STKGenerator generates STKs
type STKGenerator struct {
stkSource crypto.StkSource
}
// NewSTKGenerator initializes a new STKGenerator
func NewSTKGenerator() (*STKGenerator, error) {
stkSource, err := crypto.NewStkSource()
if err != nil {
return nil, err
}
return &STKGenerator{
stkSource: stkSource,
}, nil
}
// NewToken generates a new STK token for a given source address
func (g *STKGenerator) NewToken(raddr net.Addr) ([]byte, error) {
data, err := asn1.Marshal(token{
Data: encodeRemoteAddr(raddr),
Timestamp: time.Now().Unix(),
})
if err != nil {
return nil, err
}
return g.stkSource.NewToken(data)
}
// DecodeToken decodes an STK token
func (g *STKGenerator) DecodeToken(encrypted []byte) (*STK, error) {
// if the client didn't send any STK, DecodeToken will be called with a nil-slice
if len(encrypted) == 0 {
return nil, nil
}
data, err := g.stkSource.DecodeToken(encrypted)
if err != nil {
return nil, err
}
t := &token{}
rest, err := asn1.Unmarshal(data, t)
if err != nil {
return nil, err
}
if len(rest) != 0 {
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
}
return &STK{
RemoteAddr: decodeRemoteAddr(t.Data),
SentTime: time.Unix(t.Timestamp, 0),
}, nil
}
// encodeRemoteAddr encodes a remote address such that it can be saved in the STK
func encodeRemoteAddr(remoteAddr net.Addr) []byte {
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
return append([]byte{stkPrefixIP}, udpAddr.IP...)
}
return append([]byte{stkPrefixString}, []byte(remoteAddr.String())...)
}
// decodeRemoteAddr decodes the remote address saved in the STK
func decodeRemoteAddr(data []byte) string {
// data will never be empty for an STK that we generated. Check it to be on the safe side
if len(data) == 0 {
return ""
}
if data[0] == stkPrefixIP {
return net.IP(data[1:]).String()
}
return string(data[1:])
}

View file

@ -166,7 +166,7 @@ var _ = Describe("Session", func() {
_ io.ReadWriter,
_ handshake.ConnectionParametersManager,
_ []protocol.VersionNumber,
_ func(net.Addr, *handshake.STK) bool,
_ func(net.Addr, *STK) bool,
aeadChangedP chan<- protocol.EncryptionLevel,
) (handshake.CryptoSetup, error) {
aeadChanged = aeadChangedP
@ -204,7 +204,7 @@ var _ = Describe("Session", func() {
Context("source address validation", func() {
var (
stkVerify func(net.Addr, *handshake.STK) bool
stkVerify func(net.Addr, *STK) bool
paramClientAddr net.Addr
paramSTK *STK
)
@ -219,7 +219,7 @@ var _ = Describe("Session", func() {
_ io.ReadWriter,
_ handshake.ConnectionParametersManager,
_ []protocol.VersionNumber,
stkFunc func(net.Addr, *handshake.STK) bool,
stkFunc func(net.Addr, *STK) bool,
_ chan<- protocol.EncryptionLevel,
) (handshake.CryptoSetup, error) {
stkVerify = stkFunc
@ -253,7 +253,7 @@ var _ = Describe("Session", func() {
It("calls the callback with the STK when the client sent an STK", func() {
stkAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
sentTime := time.Now().Add(-time.Hour)
stkVerify(remoteAddr, &handshake.STK{SentTime: sentTime, RemoteAddr: stkAddr.String()})
stkVerify(remoteAddr, &STK{SentTime: sentTime, RemoteAddr: stkAddr.String()})
Expect(paramClientAddr).To(Equal(remoteAddr))
Expect(paramSTK).ToNot(BeNil())
Expect(paramSTK.RemoteAddr).To(Equal(stkAddr.String()))