mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 05:07:36 +03:00
add a quic.Config option to verify source address tokes
This commit is contained in:
parent
eb72b494b2
commit
87df63dd5f
9 changed files with 245 additions and 82 deletions
|
@ -8,7 +8,6 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/crypto"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
|
@ -33,6 +32,8 @@ type cryptoSetupServer struct {
|
|||
version protocol.VersionNumber
|
||||
supportedVersions []protocol.VersionNumber
|
||||
|
||||
acceptSTKCallback func(net.Addr, *STK) bool
|
||||
|
||||
nullAEAD crypto.AEAD
|
||||
secureAEAD crypto.AEAD
|
||||
forwardSecureAEAD crypto.AEAD
|
||||
|
@ -67,6 +68,7 @@ func NewCryptoSetup(
|
|||
cryptoStream io.ReadWriter,
|
||||
connectionParametersManager ConnectionParametersManager,
|
||||
supportedVersions []protocol.VersionNumber,
|
||||
acceptSTK func(net.Addr, *STK) bool,
|
||||
aeadChanged chan<- protocol.EncryptionLevel,
|
||||
) (CryptoSetup, error) {
|
||||
stkGenerator, err := NewSTKGenerator()
|
||||
|
@ -86,6 +88,7 @@ func NewCryptoSetup(
|
|||
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version),
|
||||
cryptoStream: cryptoStream,
|
||||
connectionParameters: connectionParametersManager,
|
||||
acceptSTKCallback: acceptSTK,
|
||||
aeadChanged: aeadChanged,
|
||||
}, nil
|
||||
}
|
||||
|
@ -272,19 +275,16 @@ func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byt
|
|||
if crypto.HashCert(cert) != xlct {
|
||||
return true
|
||||
}
|
||||
return !h.verifySTK(cryptoData[TagSTK])
|
||||
return !h.acceptSTK(cryptoData[TagSTK])
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) verifySTK(stk []byte) bool {
|
||||
stkTime, err := h.stkGenerator.VerifyToken(h.remoteAddr, stk)
|
||||
func (h *cryptoSetupServer) acceptSTK(token []byte) bool {
|
||||
stk, err := h.stkGenerator.DecodeToken(token)
|
||||
if err != nil {
|
||||
utils.Debugf("STK invalid: %s", err.Error())
|
||||
return false
|
||||
}
|
||||
if time.Now().After(stkTime.Add(protocol.STKExpiryTime)) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
return h.acceptSTKCallback(h.remoteAddr, stk)
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) {
|
||||
|
@ -303,7 +303,7 @@ func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoDa
|
|||
TagSVID: []byte("quic-go"),
|
||||
}
|
||||
|
||||
if h.verifySTK(cryptoData[TagSTK]) {
|
||||
if h.acceptSTK(cryptoData[TagSTK]) {
|
||||
proof, err := h.scfg.Sign(sni, chlo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -167,6 +167,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
kexs []byte
|
||||
version protocol.VersionNumber
|
||||
supportedVersions []protocol.VersionNumber
|
||||
sourceAddrValid bool
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
|
@ -199,11 +200,14 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
stream,
|
||||
cpm,
|
||||
supportedVersions,
|
||||
nil,
|
||||
aeadChanged,
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
cs = csInt.(*cryptoSetupServer)
|
||||
cs.stkGenerator.stkSource = &mockStkSource{}
|
||||
sourceAddrValid = true
|
||||
cs.acceptSTKCallback = func(_ net.Addr, _ *STK) bool { return sourceAddrValid }
|
||||
cs.keyDerivation = mockKeyDerivation
|
||||
cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} }
|
||||
})
|
||||
|
@ -264,14 +268,18 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
})
|
||||
|
||||
It("generates REJ messages", func() {
|
||||
sourceAddrValid = false
|
||||
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(response).To(HavePrefix("REJ"))
|
||||
Expect(response).To(ContainSubstring("initial public"))
|
||||
Expect(response).ToNot(ContainSubstring("certcompressed"))
|
||||
Expect(response).ToNot(ContainSubstring("proof"))
|
||||
Expect(signer.gotCHLO).To(BeFalse())
|
||||
})
|
||||
|
||||
It("REJ messages don't include cert or proof without STK", func() {
|
||||
sourceAddrValid = false
|
||||
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(response).To(HavePrefix("REJ"))
|
||||
|
@ -281,6 +289,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
})
|
||||
|
||||
It("REJ messages include cert and proof with valid STK", func() {
|
||||
sourceAddrValid = true
|
||||
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), map[Tag][]byte{
|
||||
TagSTK: validSTK,
|
||||
TagSNI: []byte("foo"),
|
||||
|
@ -400,11 +409,6 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("REJ messages that have an expired STK", func() {
|
||||
cs.stkGenerator.stkSource.(*mockStkSource).stkTime = time.Now().Add(-protocol.STKExpiryTime).Add(-time.Second)
|
||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("recognizes proper CHLOs", func() {
|
||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeFalse())
|
||||
})
|
||||
|
@ -690,6 +694,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
|
||||
Context("STK verification and creation", func() {
|
||||
It("requires STK", func() {
|
||||
sourceAddrValid = false
|
||||
done, err := cs.handleMessage(
|
||||
bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize),
|
||||
map[Tag][]byte{
|
||||
|
@ -703,10 +708,10 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
})
|
||||
|
||||
It("works with proper STK", func() {
|
||||
sourceAddrValid = true
|
||||
done, err := cs.handleMessage(
|
||||
bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize),
|
||||
map[Tag][]byte{
|
||||
TagSTK: validSTK,
|
||||
TagSNI: []byte("foo"),
|
||||
TagVER: versionTag,
|
||||
},
|
||||
|
@ -714,19 +719,5 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(done).To(BeFalse())
|
||||
})
|
||||
|
||||
It("errors if IP does not match", func() {
|
||||
done, err := cs.handleMessage(
|
||||
bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize),
|
||||
map[Tag][]byte{
|
||||
TagSNI: []byte("foo"),
|
||||
TagSTK: []byte("token \x04\x03\x03\x01"),
|
||||
TagVER: versionTag,
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(done).To(BeFalse())
|
||||
Expect(stream.dataWritten.Bytes()).To(ContainSubstring(string(validSTK)))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -2,14 +2,18 @@ package handshake
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/crypto"
|
||||
)
|
||||
|
||||
// An STK is a source address token
|
||||
type STK struct {
|
||||
RemoteAddr string
|
||||
SentTime time.Time
|
||||
}
|
||||
|
||||
// An STKGenerator generates STKs
|
||||
type STKGenerator struct {
|
||||
stkSource crypto.StkSource
|
||||
|
@ -31,16 +35,20 @@ func (g *STKGenerator) NewToken(raddr net.Addr) ([]byte, error) {
|
|||
return g.stkSource.NewToken(encodeRemoteAddr(raddr))
|
||||
}
|
||||
|
||||
// VerifyToken verifies an STK token
|
||||
func (g *STKGenerator) VerifyToken(raddr net.Addr, data []byte) (time.Time, error) {
|
||||
data, timestamp, err := g.stkSource.DecodeToken(data)
|
||||
// DecodeToken decodes an STK token
|
||||
func (g *STKGenerator) DecodeToken(data []byte) (*STK, error) {
|
||||
// if the client didn't send any STK, DecodeToken will be called with a nil-slice
|
||||
if len(data) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
remote, timestamp, err := g.stkSource.DecodeToken(data)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
return nil, err
|
||||
}
|
||||
if subtle.ConstantTimeCompare(encodeRemoteAddr(raddr), data) != 1 {
|
||||
return time.Time{}, errors.New("invalid source address in STK")
|
||||
}
|
||||
return timestamp, nil
|
||||
return &STK{
|
||||
RemoteAddr: decodeRemoteAddr(remote),
|
||||
SentTime: timestamp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// encodeRemoteAddr encodes a remote address such that it can be saved in the STK
|
||||
|
@ -55,3 +63,10 @@ func encodeRemoteAddr(remoteAddr net.Addr) []byte {
|
|||
// that way it can be distinguished from an IP address
|
||||
return append(bytes.Repeat([]byte{0}, 16), []byte(remoteAddr.String())...)
|
||||
}
|
||||
|
||||
func decodeRemoteAddr(data []byte) string {
|
||||
if len(data) <= 16 {
|
||||
return net.IP(data).String()
|
||||
}
|
||||
return string(data[16:])
|
||||
}
|
||||
|
|
|
@ -24,58 +24,49 @@ var _ = Describe("STK Generator", func() {
|
|||
Expect(token).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("accepts a valid STK", func() {
|
||||
raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||
token, err := stkGen.NewToken(raddr)
|
||||
It("works with nil tokens", func() {
|
||||
stk, err := stkGen.DecodeToken(nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
t, err := stkGen.VerifyToken(raddr, token)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(t).To(BeTemporally("~", time.Now(), time.Second))
|
||||
Expect(stk).To(BeNil())
|
||||
})
|
||||
|
||||
It("works with an IPv6 address", func() {
|
||||
ip := net.ParseIP("2001:db8::68")
|
||||
It("accepts a valid STK", func() {
|
||||
ip := net.IPv4(192, 168, 0, 1)
|
||||
token, err := stkGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
stk, err := stkGen.DecodeToken(token)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stk.RemoteAddr).To(Equal("192.168.0.1"))
|
||||
Expect(stk.SentTime).To(BeTemporally("~", time.Now(), time.Second))
|
||||
})
|
||||
|
||||
It("works with an IPv6 addresses ", func() {
|
||||
addresses := []string{
|
||||
"2001:db8::68",
|
||||
"2001:0000:4136:e378:8000:63bf:3fff:fdd2",
|
||||
"2001::1",
|
||||
"ff01:0:0:0:0:0:0:2",
|
||||
}
|
||||
for _, addr := range addresses {
|
||||
ip := net.ParseIP(addr)
|
||||
Expect(ip).ToNot(BeNil())
|
||||
raddr := &net.UDPAddr{IP: ip, Port: 1337}
|
||||
token, err := stkGen.NewToken(raddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
t, err := stkGen.VerifyToken(raddr, token)
|
||||
stk, err := stkGen.DecodeToken(token)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(t).To(BeTemporally("~", time.Now(), time.Second))
|
||||
Expect(stk.RemoteAddr).To(Equal(ip.String()))
|
||||
Expect(stk.SentTime).To(BeTemporally("~", time.Now(), time.Second))
|
||||
}
|
||||
})
|
||||
|
||||
It("does not care about the port", func() {
|
||||
ip := net.IPv4(192, 168, 0, 1)
|
||||
token, err := stkGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = stkGen.VerifyToken(&net.UDPAddr{IP: ip, Port: 7331}, token)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("rejects an STK for the wrong address", func() {
|
||||
ip := net.ParseIP("1.2.3.4")
|
||||
otherIP := net.ParseIP("4.3.2.1")
|
||||
token, err := stkGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
_, err = stkGen.VerifyToken(&net.UDPAddr{IP: otherIP, Port: 1337}, token)
|
||||
Expect(err).To(MatchError("invalid source address in STK"))
|
||||
})
|
||||
|
||||
It("works with an address that is not a UDP address", func() {
|
||||
raddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||
It("uses the string representation an address that is not a UDP address", func() {
|
||||
raddr := &net.TCPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337}
|
||||
token, err := stkGen.NewToken(raddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
t, err := stkGen.VerifyToken(raddr, token)
|
||||
stk, err := stkGen.DecodeToken(token)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(t).To(BeTemporally("~", time.Now(), time.Second))
|
||||
})
|
||||
|
||||
It("uses the string representation of an address that is not a UDP address", func() {
|
||||
// when using the string representation, the port matters
|
||||
ip := net.IPv4(192, 168, 0, 1)
|
||||
token, err := stkGen.NewToken(&net.TCPAddr{IP: ip, Port: 1337})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = stkGen.VerifyToken(&net.TCPAddr{IP: ip, Port: 7331}, token)
|
||||
Expect(err).To(MatchError("invalid source address in STK"))
|
||||
Expect(stk.RemoteAddr).To(Equal("192.168.13.37:1337"))
|
||||
Expect(stk.SentTime).To(BeTemporally("~", time.Now(), time.Second))
|
||||
})
|
||||
})
|
||||
|
|
19
interface.go
19
interface.go
|
@ -4,6 +4,7 @@ import (
|
|||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
@ -45,6 +46,18 @@ type NonFWSession interface {
|
|||
WaitUntilHandshakeComplete() error
|
||||
}
|
||||
|
||||
// An STK is a Source Address token.
|
||||
// It is issued by the server and sent to the client. For the client, it is an opaque blob.
|
||||
// The client can send the STK in subsequent handshakes to prove ownership of its IP address.
|
||||
type STK struct {
|
||||
// The remote address this token was issued for.
|
||||
// If the server is run on a net.UDPConn, this is the string representation of the IP address (net.IP.String())
|
||||
// Otherwise, this is the string representation of the net.Addr (net.Addr.String())
|
||||
remoteAddr string
|
||||
// The time that the STK was issued (resolution 1 second)
|
||||
sentTime time.Time
|
||||
}
|
||||
|
||||
// Config contains all configuration data needed for a QUIC server or client.
|
||||
// More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441.
|
||||
type Config struct {
|
||||
|
@ -54,9 +67,15 @@ type Config struct {
|
|||
// Warning: This API should not be considered stable and will change soon.
|
||||
Versions []protocol.VersionNumber
|
||||
// Ask the server to truncate the connection ID sent in the Public Header.
|
||||
// If not set, the default checks if
|
||||
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
|
||||
// Currently only valid for the client.
|
||||
RequestConnectionIDTruncation bool
|
||||
// AcceptSTK determines if an STK is accepted.
|
||||
// It is called with stk = nil if the client didn't send an STK.
|
||||
// If not set, it verifies that the address matches, and that the STK was issued within the last 24 hours
|
||||
// This option is only valid for the server.
|
||||
AcceptSTK func(clientAddr net.Addr, stk *STK) bool
|
||||
}
|
||||
|
||||
// A Listener for incoming QUIC connections
|
||||
|
|
21
server.go
21
server.go
|
@ -85,15 +85,36 @@ func Listen(conn net.PacketConn, config *Config) (Listener, error) {
|
|||
return s, nil
|
||||
}
|
||||
|
||||
var defaultAcceptSTK = func(clientAddr net.Addr, stk *STK) bool {
|
||||
if stk == nil {
|
||||
return false
|
||||
}
|
||||
if time.Now().After(stk.sentTime.Add(protocol.STKExpiryTime)) {
|
||||
return false
|
||||
}
|
||||
var sourceAddr string
|
||||
if udpAddr, ok := clientAddr.(*net.UDPAddr); ok {
|
||||
sourceAddr = udpAddr.IP.String()
|
||||
} else {
|
||||
sourceAddr = clientAddr.String()
|
||||
}
|
||||
return sourceAddr == stk.remoteAddr
|
||||
}
|
||||
|
||||
func populateServerConfig(config *Config) *Config {
|
||||
versions := config.Versions
|
||||
if len(versions) == 0 {
|
||||
versions = protocol.SupportedVersions
|
||||
}
|
||||
vsa := defaultAcceptSTK
|
||||
if config.AcceptSTK != nil {
|
||||
vsa = config.AcceptSTK
|
||||
}
|
||||
|
||||
return &Config{
|
||||
TLSConfig: config.TLSConfig,
|
||||
Versions: versions,
|
||||
AcceptSTK: vsa,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/crypto"
|
||||
|
@ -342,17 +343,20 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("setups with the right values", func() {
|
||||
supportedVersions := []protocol.VersionNumber{1, 3, 5}
|
||||
acceptSTK := func(_ net.Addr, _ *STK) bool { return true }
|
||||
config := Config{
|
||||
TLSConfig: &tls.Config{},
|
||||
Versions: supportedVersions,
|
||||
AcceptSTK: acceptSTK,
|
||||
}
|
||||
ln, err := Listen(conn, &config)
|
||||
server := ln.(*server)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
server := ln.(*server)
|
||||
Expect(server.deleteClosedSessionsAfter).To(Equal(protocol.ClosedSessionDeleteTimeout))
|
||||
Expect(server.sessions).ToNot(BeNil())
|
||||
Expect(server.scfg).ToNot(BeNil())
|
||||
Expect(server.config.Versions).To(Equal(supportedVersions))
|
||||
Expect(reflect.ValueOf(server.config.AcceptSTK)).To(Equal(reflect.ValueOf(acceptSTK)))
|
||||
})
|
||||
|
||||
It("fills in default values if options are not set in the Config", func() {
|
||||
|
@ -361,6 +365,7 @@ var _ = Describe("Server", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
server := ln.(*server)
|
||||
Expect(server.config.Versions).To(Equal(protocol.SupportedVersions))
|
||||
Expect(reflect.ValueOf(server.config.AcceptSTK)).To(Equal(reflect.ValueOf(defaultAcceptSTK)))
|
||||
})
|
||||
|
||||
It("listens on a given address", func() {
|
||||
|
@ -434,3 +439,55 @@ var _ = Describe("Server", func() {
|
|||
Expect(ln.(*server).sessions).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("default source address verification", func() {
|
||||
It("accepts a token", func() {
|
||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
|
||||
stk := &STK{
|
||||
remoteAddr: "192.168.0.1",
|
||||
sentTime: time.Now().Add(-protocol.STKExpiryTime).Add(time.Second), // will expire in 1 second
|
||||
}
|
||||
Expect(defaultAcceptSTK(remoteAddr, stk)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("requests verification if no token is provided", func() {
|
||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
|
||||
Expect(defaultAcceptSTK(remoteAddr, nil)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("rejects a token if the address doesn't match", func() {
|
||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
|
||||
stk := &STK{
|
||||
remoteAddr: "127.0.0.1",
|
||||
sentTime: time.Now(),
|
||||
}
|
||||
Expect(defaultAcceptSTK(remoteAddr, stk)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("accepts a token for a remote address is not a UDP address", func() {
|
||||
remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||
stk := &STK{
|
||||
remoteAddr: "192.168.0.1:1337",
|
||||
sentTime: time.Now(),
|
||||
}
|
||||
Expect(defaultAcceptSTK(remoteAddr, stk)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("rejects an invalid token for a remote address is not a UDP address", func() {
|
||||
remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||
stk := &STK{
|
||||
remoteAddr: "192.168.0.1:7331", // mismatching port
|
||||
sentTime: time.Now(),
|
||||
}
|
||||
Expect(defaultAcceptSTK(remoteAddr, stk)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("rejects an expired token", func() {
|
||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
|
||||
stk := &STK{
|
||||
remoteAddr: "192.168.0.1",
|
||||
sentTime: time.Now().Add(-protocol.STKExpiryTime).Add(-time.Second), // expired 1 second ago
|
||||
}
|
||||
Expect(defaultAcceptSTK(remoteAddr, stk)).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
|
10
session.go
10
session.go
|
@ -141,6 +141,15 @@ func newSession(
|
|||
s.aeadChanged = aeadChanged
|
||||
handshakeChan := make(chan handshakeEvent, 3)
|
||||
s.handshakeChan = handshakeChan
|
||||
verifySourceAddr := func(clientAddr net.Addr, hstk *handshake.STK) bool {
|
||||
if hstk == nil {
|
||||
return config.AcceptSTK(clientAddr, nil)
|
||||
}
|
||||
return config.AcceptSTK(
|
||||
clientAddr,
|
||||
&STK{remoteAddr: hstk.RemoteAddr, sentTime: hstk.SentTime},
|
||||
)
|
||||
}
|
||||
var err error
|
||||
s.cryptoSetup, err = newCryptoSetup(
|
||||
connectionID,
|
||||
|
@ -150,6 +159,7 @@ func newSession(
|
|||
cryptoStream,
|
||||
s.connectionParameters,
|
||||
config.Versions,
|
||||
verifySourceAddr,
|
||||
aeadChanged,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
|
@ -145,6 +145,7 @@ var _ = Describe("Session", func() {
|
|||
_ io.ReadWriter,
|
||||
_ handshake.ConnectionParametersManager,
|
||||
_ []protocol.VersionNumber,
|
||||
_ func(net.Addr, *handshake.STK) bool,
|
||||
aeadChangedP chan<- protocol.EncryptionLevel,
|
||||
) (handshake.CryptoSetup, error) {
|
||||
aeadChanged = aeadChangedP
|
||||
|
@ -180,6 +181,64 @@ var _ = Describe("Session", func() {
|
|||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
})
|
||||
|
||||
Context("source address validation", func() {
|
||||
var (
|
||||
stkVerify func(net.Addr, *handshake.STK) bool
|
||||
paramClientAddr net.Addr
|
||||
paramSTK *STK
|
||||
)
|
||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1000}
|
||||
|
||||
BeforeEach(func() {
|
||||
newCryptoSetup = func(
|
||||
_ protocol.ConnectionID,
|
||||
_ net.Addr,
|
||||
_ protocol.VersionNumber,
|
||||
_ *handshake.ServerConfig,
|
||||
_ io.ReadWriter,
|
||||
_ handshake.ConnectionParametersManager,
|
||||
_ []protocol.VersionNumber,
|
||||
stkFunc func(net.Addr, *handshake.STK) bool,
|
||||
_ chan<- protocol.EncryptionLevel,
|
||||
) (handshake.CryptoSetup, error) {
|
||||
stkVerify = stkFunc
|
||||
return cryptoSetup, nil
|
||||
}
|
||||
|
||||
conf := populateServerConfig(&Config{})
|
||||
conf.AcceptSTK = func(clientAddr net.Addr, stk *STK) bool {
|
||||
paramClientAddr = clientAddr
|
||||
paramSTK = stk
|
||||
return false
|
||||
}
|
||||
pSess, _, err := newSession(
|
||||
mconn,
|
||||
protocol.Version35,
|
||||
0,
|
||||
scfg,
|
||||
conf,
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
sess = pSess.(*session)
|
||||
})
|
||||
|
||||
It("calls the callback with the right parameters when the client didn't send an STK", func() {
|
||||
stkVerify(remoteAddr, nil)
|
||||
Expect(paramClientAddr).To(Equal(remoteAddr))
|
||||
Expect(paramSTK).To(BeNil())
|
||||
})
|
||||
|
||||
It("calls the callback with the STK when the client sent an STK", func() {
|
||||
stkAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
sentTime := time.Now().Add(-time.Hour)
|
||||
stkVerify(remoteAddr, &handshake.STK{SentTime: sentTime, RemoteAddr: stkAddr.String()})
|
||||
Expect(paramClientAddr).To(Equal(remoteAddr))
|
||||
Expect(paramSTK).ToNot(BeNil())
|
||||
Expect(paramSTK.remoteAddr).To(Equal(stkAddr.String()))
|
||||
Expect(paramSTK.sentTime).To(Equal(sentTime))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when handling stream frames", func() {
|
||||
It("makes new streams", func() {
|
||||
sess.handleStreamFrame(&frames.StreamFrame{
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue