add a quic.Config option to verify source address tokes

This commit is contained in:
Marten Seemann 2017-05-21 00:28:31 +08:00
parent eb72b494b2
commit 87df63dd5f
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
9 changed files with 245 additions and 82 deletions

View file

@ -8,7 +8,6 @@ import (
"io" "io"
"net" "net"
"sync" "sync"
"time"
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
@ -33,6 +32,8 @@ type cryptoSetupServer struct {
version protocol.VersionNumber version protocol.VersionNumber
supportedVersions []protocol.VersionNumber supportedVersions []protocol.VersionNumber
acceptSTKCallback func(net.Addr, *STK) bool
nullAEAD crypto.AEAD nullAEAD crypto.AEAD
secureAEAD crypto.AEAD secureAEAD crypto.AEAD
forwardSecureAEAD crypto.AEAD forwardSecureAEAD crypto.AEAD
@ -67,6 +68,7 @@ func NewCryptoSetup(
cryptoStream io.ReadWriter, cryptoStream io.ReadWriter,
connectionParametersManager ConnectionParametersManager, connectionParametersManager ConnectionParametersManager,
supportedVersions []protocol.VersionNumber, supportedVersions []protocol.VersionNumber,
acceptSTK func(net.Addr, *STK) bool,
aeadChanged chan<- protocol.EncryptionLevel, aeadChanged chan<- protocol.EncryptionLevel,
) (CryptoSetup, error) { ) (CryptoSetup, error) {
stkGenerator, err := NewSTKGenerator() stkGenerator, err := NewSTKGenerator()
@ -86,6 +88,7 @@ func NewCryptoSetup(
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version), nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version),
cryptoStream: cryptoStream, cryptoStream: cryptoStream,
connectionParameters: connectionParametersManager, connectionParameters: connectionParametersManager,
acceptSTKCallback: acceptSTK,
aeadChanged: aeadChanged, aeadChanged: aeadChanged,
}, nil }, nil
} }
@ -272,19 +275,16 @@ func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byt
if crypto.HashCert(cert) != xlct { if crypto.HashCert(cert) != xlct {
return true return true
} }
return !h.verifySTK(cryptoData[TagSTK]) return !h.acceptSTK(cryptoData[TagSTK])
} }
func (h *cryptoSetupServer) verifySTK(stk []byte) bool { func (h *cryptoSetupServer) acceptSTK(token []byte) bool {
stkTime, err := h.stkGenerator.VerifyToken(h.remoteAddr, stk) stk, err := h.stkGenerator.DecodeToken(token)
if err != nil { if err != nil {
utils.Debugf("STK invalid: %s", err.Error()) utils.Debugf("STK invalid: %s", err.Error())
return false return false
} }
if time.Now().After(stkTime.Add(protocol.STKExpiryTime)) { return h.acceptSTKCallback(h.remoteAddr, stk)
return false
}
return true
} }
func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) { func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) {
@ -303,7 +303,7 @@ func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoDa
TagSVID: []byte("quic-go"), TagSVID: []byte("quic-go"),
} }
if h.verifySTK(cryptoData[TagSTK]) { if h.acceptSTK(cryptoData[TagSTK]) {
proof, err := h.scfg.Sign(sni, chlo) proof, err := h.scfg.Sign(sni, chlo)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -167,6 +167,7 @@ var _ = Describe("Server Crypto Setup", func() {
kexs []byte kexs []byte
version protocol.VersionNumber version protocol.VersionNumber
supportedVersions []protocol.VersionNumber supportedVersions []protocol.VersionNumber
sourceAddrValid bool
) )
BeforeEach(func() { BeforeEach(func() {
@ -199,11 +200,14 @@ var _ = Describe("Server Crypto Setup", func() {
stream, stream,
cpm, cpm,
supportedVersions, supportedVersions,
nil,
aeadChanged, aeadChanged,
) )
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cs = csInt.(*cryptoSetupServer) cs = csInt.(*cryptoSetupServer)
cs.stkGenerator.stkSource = &mockStkSource{} cs.stkGenerator.stkSource = &mockStkSource{}
sourceAddrValid = true
cs.acceptSTKCallback = func(_ net.Addr, _ *STK) bool { return sourceAddrValid }
cs.keyDerivation = mockKeyDerivation cs.keyDerivation = mockKeyDerivation
cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} } cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} }
}) })
@ -264,14 +268,18 @@ var _ = Describe("Server Crypto Setup", func() {
}) })
It("generates REJ messages", func() { It("generates REJ messages", func() {
sourceAddrValid = false
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), nil) response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(response).To(HavePrefix("REJ")) Expect(response).To(HavePrefix("REJ"))
Expect(response).To(ContainSubstring("initial public")) Expect(response).To(ContainSubstring("initial public"))
Expect(response).ToNot(ContainSubstring("certcompressed"))
Expect(response).ToNot(ContainSubstring("proof"))
Expect(signer.gotCHLO).To(BeFalse()) Expect(signer.gotCHLO).To(BeFalse())
}) })
It("REJ messages don't include cert or proof without STK", func() { It("REJ messages don't include cert or proof without STK", func() {
sourceAddrValid = false
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), nil) response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(response).To(HavePrefix("REJ")) Expect(response).To(HavePrefix("REJ"))
@ -281,6 +289,7 @@ var _ = Describe("Server Crypto Setup", func() {
}) })
It("REJ messages include cert and proof with valid STK", func() { It("REJ messages include cert and proof with valid STK", func() {
sourceAddrValid = true
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), map[Tag][]byte{ response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), map[Tag][]byte{
TagSTK: validSTK, TagSTK: validSTK,
TagSNI: []byte("foo"), TagSNI: []byte("foo"),
@ -400,11 +409,6 @@ var _ = Describe("Server Crypto Setup", func() {
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue()) Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
}) })
It("REJ messages that have an expired STK", func() {
cs.stkGenerator.stkSource.(*mockStkSource).stkTime = time.Now().Add(-protocol.STKExpiryTime).Add(-time.Second)
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
})
It("recognizes proper CHLOs", func() { It("recognizes proper CHLOs", func() {
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeFalse()) Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeFalse())
}) })
@ -690,6 +694,7 @@ var _ = Describe("Server Crypto Setup", func() {
Context("STK verification and creation", func() { Context("STK verification and creation", func() {
It("requires STK", func() { It("requires STK", func() {
sourceAddrValid = false
done, err := cs.handleMessage( done, err := cs.handleMessage(
bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize),
map[Tag][]byte{ map[Tag][]byte{
@ -703,10 +708,10 @@ var _ = Describe("Server Crypto Setup", func() {
}) })
It("works with proper STK", func() { It("works with proper STK", func() {
sourceAddrValid = true
done, err := cs.handleMessage( done, err := cs.handleMessage(
bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize),
map[Tag][]byte{ map[Tag][]byte{
TagSTK: validSTK,
TagSNI: []byte("foo"), TagSNI: []byte("foo"),
TagVER: versionTag, TagVER: versionTag,
}, },
@ -714,19 +719,5 @@ var _ = Describe("Server Crypto Setup", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(done).To(BeFalse()) Expect(done).To(BeFalse())
}) })
It("errors if IP does not match", func() {
done, err := cs.handleMessage(
bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize),
map[Tag][]byte{
TagSNI: []byte("foo"),
TagSTK: []byte("token \x04\x03\x03\x01"),
TagVER: versionTag,
},
)
Expect(err).ToNot(HaveOccurred())
Expect(done).To(BeFalse())
Expect(stream.dataWritten.Bytes()).To(ContainSubstring(string(validSTK)))
})
}) })
}) })

View file

@ -2,14 +2,18 @@ package handshake
import ( import (
"bytes" "bytes"
"crypto/subtle"
"errors"
"net" "net"
"time" "time"
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
) )
// An STK is a source address token
type STK struct {
RemoteAddr string
SentTime time.Time
}
// An STKGenerator generates STKs // An STKGenerator generates STKs
type STKGenerator struct { type STKGenerator struct {
stkSource crypto.StkSource stkSource crypto.StkSource
@ -31,16 +35,20 @@ func (g *STKGenerator) NewToken(raddr net.Addr) ([]byte, error) {
return g.stkSource.NewToken(encodeRemoteAddr(raddr)) return g.stkSource.NewToken(encodeRemoteAddr(raddr))
} }
// VerifyToken verifies an STK token // DecodeToken decodes an STK token
func (g *STKGenerator) VerifyToken(raddr net.Addr, data []byte) (time.Time, error) { func (g *STKGenerator) DecodeToken(data []byte) (*STK, error) {
data, timestamp, err := g.stkSource.DecodeToken(data) // if the client didn't send any STK, DecodeToken will be called with a nil-slice
if len(data) == 0 {
return nil, nil
}
remote, timestamp, err := g.stkSource.DecodeToken(data)
if err != nil { if err != nil {
return time.Time{}, err return nil, err
} }
if subtle.ConstantTimeCompare(encodeRemoteAddr(raddr), data) != 1 { return &STK{
return time.Time{}, errors.New("invalid source address in STK") RemoteAddr: decodeRemoteAddr(remote),
} SentTime: timestamp,
return timestamp, nil }, nil
} }
// encodeRemoteAddr encodes a remote address such that it can be saved in the STK // encodeRemoteAddr encodes a remote address such that it can be saved in the STK
@ -55,3 +63,10 @@ func encodeRemoteAddr(remoteAddr net.Addr) []byte {
// that way it can be distinguished from an IP address // that way it can be distinguished from an IP address
return append(bytes.Repeat([]byte{0}, 16), []byte(remoteAddr.String())...) return append(bytes.Repeat([]byte{0}, 16), []byte(remoteAddr.String())...)
} }
func decodeRemoteAddr(data []byte) string {
if len(data) <= 16 {
return net.IP(data).String()
}
return string(data[16:])
}

View file

@ -24,58 +24,49 @@ var _ = Describe("STK Generator", func() {
Expect(token).ToNot(BeEmpty()) Expect(token).ToNot(BeEmpty())
}) })
It("works with nil tokens", func() {
stk, err := stkGen.DecodeToken(nil)
Expect(err).ToNot(HaveOccurred())
Expect(stk).To(BeNil())
})
It("accepts a valid STK", func() { It("accepts a valid STK", func() {
raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
token, err := stkGen.NewToken(raddr)
Expect(err).ToNot(HaveOccurred())
t, err := stkGen.VerifyToken(raddr, token)
Expect(err).ToNot(HaveOccurred())
Expect(t).To(BeTemporally("~", time.Now(), time.Second))
})
It("works with an IPv6 address", func() {
ip := net.ParseIP("2001:db8::68")
Expect(ip).ToNot(BeNil())
raddr := &net.UDPAddr{IP: ip, Port: 1337}
token, err := stkGen.NewToken(raddr)
Expect(err).ToNot(HaveOccurred())
t, err := stkGen.VerifyToken(raddr, token)
Expect(err).ToNot(HaveOccurred())
Expect(t).To(BeTemporally("~", time.Now(), time.Second))
})
It("does not care about the port", func() {
ip := net.IPv4(192, 168, 0, 1) ip := net.IPv4(192, 168, 0, 1)
token, err := stkGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337}) token, err := stkGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, err = stkGen.VerifyToken(&net.UDPAddr{IP: ip, Port: 7331}, token) stk, err := stkGen.DecodeToken(token)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(stk.RemoteAddr).To(Equal("192.168.0.1"))
Expect(stk.SentTime).To(BeTemporally("~", time.Now(), time.Second))
}) })
It("rejects an STK for the wrong address", func() { It("works with an IPv6 addresses ", func() {
ip := net.ParseIP("1.2.3.4") addresses := []string{
otherIP := net.ParseIP("4.3.2.1") "2001:db8::68",
token, err := stkGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337}) "2001:0000:4136:e378:8000:63bf:3fff:fdd2",
Expect(err).NotTo(HaveOccurred()) "2001::1",
_, err = stkGen.VerifyToken(&net.UDPAddr{IP: otherIP, Port: 1337}, token) "ff01:0:0:0:0:0:0:2",
Expect(err).To(MatchError("invalid source address in STK")) }
for _, addr := range addresses {
ip := net.ParseIP(addr)
Expect(ip).ToNot(BeNil())
raddr := &net.UDPAddr{IP: ip, Port: 1337}
token, err := stkGen.NewToken(raddr)
Expect(err).ToNot(HaveOccurred())
stk, err := stkGen.DecodeToken(token)
Expect(err).ToNot(HaveOccurred())
Expect(stk.RemoteAddr).To(Equal(ip.String()))
Expect(stk.SentTime).To(BeTemporally("~", time.Now(), time.Second))
}
}) })
It("works with an address that is not a UDP address", func() { It("uses the string representation an address that is not a UDP address", func() {
raddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} raddr := &net.TCPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337}
token, err := stkGen.NewToken(raddr) token, err := stkGen.NewToken(raddr)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
t, err := stkGen.VerifyToken(raddr, token) stk, err := stkGen.DecodeToken(token)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(t).To(BeTemporally("~", time.Now(), time.Second)) Expect(stk.RemoteAddr).To(Equal("192.168.13.37:1337"))
}) Expect(stk.SentTime).To(BeTemporally("~", time.Now(), time.Second))
It("uses the string representation of an address that is not a UDP address", func() {
// when using the string representation, the port matters
ip := net.IPv4(192, 168, 0, 1)
token, err := stkGen.NewToken(&net.TCPAddr{IP: ip, Port: 1337})
Expect(err).ToNot(HaveOccurred())
_, err = stkGen.VerifyToken(&net.TCPAddr{IP: ip, Port: 7331}, token)
Expect(err).To(MatchError("invalid source address in STK"))
}) })
}) })

View file

@ -4,6 +4,7 @@ import (
"crypto/tls" "crypto/tls"
"io" "io"
"net" "net"
"time"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
) )
@ -45,6 +46,18 @@ type NonFWSession interface {
WaitUntilHandshakeComplete() error WaitUntilHandshakeComplete() error
} }
// An STK is a Source Address token.
// It is issued by the server and sent to the client. For the client, it is an opaque blob.
// The client can send the STK in subsequent handshakes to prove ownership of its IP address.
type STK struct {
// The remote address this token was issued for.
// If the server is run on a net.UDPConn, this is the string representation of the IP address (net.IP.String())
// Otherwise, this is the string representation of the net.Addr (net.Addr.String())
remoteAddr string
// The time that the STK was issued (resolution 1 second)
sentTime time.Time
}
// Config contains all configuration data needed for a QUIC server or client. // Config contains all configuration data needed for a QUIC server or client.
// More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441. // More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441.
type Config struct { type Config struct {
@ -54,9 +67,15 @@ type Config struct {
// Warning: This API should not be considered stable and will change soon. // Warning: This API should not be considered stable and will change soon.
Versions []protocol.VersionNumber Versions []protocol.VersionNumber
// Ask the server to truncate the connection ID sent in the Public Header. // Ask the server to truncate the connection ID sent in the Public Header.
// If not set, the default checks if
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated. // This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
// Currently only valid for the client. // Currently only valid for the client.
RequestConnectionIDTruncation bool RequestConnectionIDTruncation bool
// AcceptSTK determines if an STK is accepted.
// It is called with stk = nil if the client didn't send an STK.
// If not set, it verifies that the address matches, and that the STK was issued within the last 24 hours
// This option is only valid for the server.
AcceptSTK func(clientAddr net.Addr, stk *STK) bool
} }
// A Listener for incoming QUIC connections // A Listener for incoming QUIC connections

View file

@ -85,15 +85,36 @@ func Listen(conn net.PacketConn, config *Config) (Listener, error) {
return s, nil return s, nil
} }
var defaultAcceptSTK = func(clientAddr net.Addr, stk *STK) bool {
if stk == nil {
return false
}
if time.Now().After(stk.sentTime.Add(protocol.STKExpiryTime)) {
return false
}
var sourceAddr string
if udpAddr, ok := clientAddr.(*net.UDPAddr); ok {
sourceAddr = udpAddr.IP.String()
} else {
sourceAddr = clientAddr.String()
}
return sourceAddr == stk.remoteAddr
}
func populateServerConfig(config *Config) *Config { func populateServerConfig(config *Config) *Config {
versions := config.Versions versions := config.Versions
if len(versions) == 0 { if len(versions) == 0 {
versions = protocol.SupportedVersions versions = protocol.SupportedVersions
} }
vsa := defaultAcceptSTK
if config.AcceptSTK != nil {
vsa = config.AcceptSTK
}
return &Config{ return &Config{
TLSConfig: config.TLSConfig, TLSConfig: config.TLSConfig,
Versions: versions, Versions: versions,
AcceptSTK: vsa,
} }
} }
@ -116,7 +137,7 @@ func (s *server) serve() {
utils.Errorf("error handling packet: %s", err.Error()) utils.Errorf("error handling packet: %s", err.Error())
} }
} }
} }
// Accept returns newly openend sessions // Accept returns newly openend sessions
func (s *server) Accept() (Session, error) { func (s *server) Accept() (Session, error) {

View file

@ -5,6 +5,7 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"net" "net"
"reflect"
"time" "time"
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
@ -342,17 +343,20 @@ var _ = Describe("Server", func() {
It("setups with the right values", func() { It("setups with the right values", func() {
supportedVersions := []protocol.VersionNumber{1, 3, 5} supportedVersions := []protocol.VersionNumber{1, 3, 5}
acceptSTK := func(_ net.Addr, _ *STK) bool { return true }
config := Config{ config := Config{
TLSConfig: &tls.Config{}, TLSConfig: &tls.Config{},
Versions: supportedVersions, Versions: supportedVersions,
AcceptSTK: acceptSTK,
} }
ln, err := Listen(conn, &config) ln, err := Listen(conn, &config)
server := ln.(*server)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
server := ln.(*server)
Expect(server.deleteClosedSessionsAfter).To(Equal(protocol.ClosedSessionDeleteTimeout)) Expect(server.deleteClosedSessionsAfter).To(Equal(protocol.ClosedSessionDeleteTimeout))
Expect(server.sessions).ToNot(BeNil()) Expect(server.sessions).ToNot(BeNil())
Expect(server.scfg).ToNot(BeNil()) Expect(server.scfg).ToNot(BeNil())
Expect(server.config.Versions).To(Equal(supportedVersions)) Expect(server.config.Versions).To(Equal(supportedVersions))
Expect(reflect.ValueOf(server.config.AcceptSTK)).To(Equal(reflect.ValueOf(acceptSTK)))
}) })
It("fills in default values if options are not set in the Config", func() { It("fills in default values if options are not set in the Config", func() {
@ -361,6 +365,7 @@ var _ = Describe("Server", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
server := ln.(*server) server := ln.(*server)
Expect(server.config.Versions).To(Equal(protocol.SupportedVersions)) Expect(server.config.Versions).To(Equal(protocol.SupportedVersions))
Expect(reflect.ValueOf(server.config.AcceptSTK)).To(Equal(reflect.ValueOf(defaultAcceptSTK)))
}) })
It("listens on a given address", func() { It("listens on a given address", func() {
@ -434,3 +439,55 @@ var _ = Describe("Server", func() {
Expect(ln.(*server).sessions).To(BeEmpty()) Expect(ln.(*server).sessions).To(BeEmpty())
}) })
}) })
var _ = Describe("default source address verification", func() {
It("accepts a token", func() {
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
stk := &STK{
remoteAddr: "192.168.0.1",
sentTime: time.Now().Add(-protocol.STKExpiryTime).Add(time.Second), // will expire in 1 second
}
Expect(defaultAcceptSTK(remoteAddr, stk)).To(BeTrue())
})
It("requests verification if no token is provided", func() {
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
Expect(defaultAcceptSTK(remoteAddr, nil)).To(BeFalse())
})
It("rejects a token if the address doesn't match", func() {
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
stk := &STK{
remoteAddr: "127.0.0.1",
sentTime: time.Now(),
}
Expect(defaultAcceptSTK(remoteAddr, stk)).To(BeFalse())
})
It("accepts a token for a remote address is not a UDP address", func() {
remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
stk := &STK{
remoteAddr: "192.168.0.1:1337",
sentTime: time.Now(),
}
Expect(defaultAcceptSTK(remoteAddr, stk)).To(BeTrue())
})
It("rejects an invalid token for a remote address is not a UDP address", func() {
remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
stk := &STK{
remoteAddr: "192.168.0.1:7331", // mismatching port
sentTime: time.Now(),
}
Expect(defaultAcceptSTK(remoteAddr, stk)).To(BeFalse())
})
It("rejects an expired token", func() {
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
stk := &STK{
remoteAddr: "192.168.0.1",
sentTime: time.Now().Add(-protocol.STKExpiryTime).Add(-time.Second), // expired 1 second ago
}
Expect(defaultAcceptSTK(remoteAddr, stk)).To(BeFalse())
})
})

View file

@ -141,6 +141,15 @@ func newSession(
s.aeadChanged = aeadChanged s.aeadChanged = aeadChanged
handshakeChan := make(chan handshakeEvent, 3) handshakeChan := make(chan handshakeEvent, 3)
s.handshakeChan = handshakeChan s.handshakeChan = handshakeChan
verifySourceAddr := func(clientAddr net.Addr, hstk *handshake.STK) bool {
if hstk == nil {
return config.AcceptSTK(clientAddr, nil)
}
return config.AcceptSTK(
clientAddr,
&STK{remoteAddr: hstk.RemoteAddr, sentTime: hstk.SentTime},
)
}
var err error var err error
s.cryptoSetup, err = newCryptoSetup( s.cryptoSetup, err = newCryptoSetup(
connectionID, connectionID,
@ -150,6 +159,7 @@ func newSession(
cryptoStream, cryptoStream,
s.connectionParameters, s.connectionParameters,
config.Versions, config.Versions,
verifySourceAddr,
aeadChanged, aeadChanged,
) )
if err != nil { if err != nil {

View file

@ -145,6 +145,7 @@ var _ = Describe("Session", func() {
_ io.ReadWriter, _ io.ReadWriter,
_ handshake.ConnectionParametersManager, _ handshake.ConnectionParametersManager,
_ []protocol.VersionNumber, _ []protocol.VersionNumber,
_ func(net.Addr, *handshake.STK) bool,
aeadChangedP chan<- protocol.EncryptionLevel, aeadChangedP chan<- protocol.EncryptionLevel,
) (handshake.CryptoSetup, error) { ) (handshake.CryptoSetup, error) {
aeadChanged = aeadChangedP aeadChanged = aeadChangedP
@ -180,6 +181,64 @@ var _ = Describe("Session", func() {
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areSessionsRunning).Should(BeFalse())
}) })
Context("source address validation", func() {
var (
stkVerify func(net.Addr, *handshake.STK) bool
paramClientAddr net.Addr
paramSTK *STK
)
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1000}
BeforeEach(func() {
newCryptoSetup = func(
_ protocol.ConnectionID,
_ net.Addr,
_ protocol.VersionNumber,
_ *handshake.ServerConfig,
_ io.ReadWriter,
_ handshake.ConnectionParametersManager,
_ []protocol.VersionNumber,
stkFunc func(net.Addr, *handshake.STK) bool,
_ chan<- protocol.EncryptionLevel,
) (handshake.CryptoSetup, error) {
stkVerify = stkFunc
return cryptoSetup, nil
}
conf := populateServerConfig(&Config{})
conf.AcceptSTK = func(clientAddr net.Addr, stk *STK) bool {
paramClientAddr = clientAddr
paramSTK = stk
return false
}
pSess, _, err := newSession(
mconn,
protocol.Version35,
0,
scfg,
conf,
)
Expect(err).NotTo(HaveOccurred())
sess = pSess.(*session)
})
It("calls the callback with the right parameters when the client didn't send an STK", func() {
stkVerify(remoteAddr, nil)
Expect(paramClientAddr).To(Equal(remoteAddr))
Expect(paramSTK).To(BeNil())
})
It("calls the callback with the STK when the client sent an STK", func() {
stkAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
sentTime := time.Now().Add(-time.Hour)
stkVerify(remoteAddr, &handshake.STK{SentTime: sentTime, RemoteAddr: stkAddr.String()})
Expect(paramClientAddr).To(Equal(remoteAddr))
Expect(paramSTK).ToNot(BeNil())
Expect(paramSTK.remoteAddr).To(Equal(stkAddr.String()))
Expect(paramSTK.sentTime).To(Equal(sentTime))
})
})
Context("when handling stream frames", func() { Context("when handling stream frames", func() {
It("makes new streams", func() { It("makes new streams", func() {
sess.handleStreamFrame(&frames.StreamFrame{ sess.handleStreamFrame(&frames.StreamFrame{