mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 13:17:36 +03:00
add a quic.Config option to verify source address tokes
This commit is contained in:
parent
eb72b494b2
commit
87df63dd5f
9 changed files with 245 additions and 82 deletions
|
@ -8,7 +8,6 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/crypto"
|
"github.com/lucas-clemente/quic-go/crypto"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
"github.com/lucas-clemente/quic-go/protocol"
|
||||||
|
@ -33,6 +32,8 @@ type cryptoSetupServer struct {
|
||||||
version protocol.VersionNumber
|
version protocol.VersionNumber
|
||||||
supportedVersions []protocol.VersionNumber
|
supportedVersions []protocol.VersionNumber
|
||||||
|
|
||||||
|
acceptSTKCallback func(net.Addr, *STK) bool
|
||||||
|
|
||||||
nullAEAD crypto.AEAD
|
nullAEAD crypto.AEAD
|
||||||
secureAEAD crypto.AEAD
|
secureAEAD crypto.AEAD
|
||||||
forwardSecureAEAD crypto.AEAD
|
forwardSecureAEAD crypto.AEAD
|
||||||
|
@ -67,6 +68,7 @@ func NewCryptoSetup(
|
||||||
cryptoStream io.ReadWriter,
|
cryptoStream io.ReadWriter,
|
||||||
connectionParametersManager ConnectionParametersManager,
|
connectionParametersManager ConnectionParametersManager,
|
||||||
supportedVersions []protocol.VersionNumber,
|
supportedVersions []protocol.VersionNumber,
|
||||||
|
acceptSTK func(net.Addr, *STK) bool,
|
||||||
aeadChanged chan<- protocol.EncryptionLevel,
|
aeadChanged chan<- protocol.EncryptionLevel,
|
||||||
) (CryptoSetup, error) {
|
) (CryptoSetup, error) {
|
||||||
stkGenerator, err := NewSTKGenerator()
|
stkGenerator, err := NewSTKGenerator()
|
||||||
|
@ -86,6 +88,7 @@ func NewCryptoSetup(
|
||||||
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version),
|
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version),
|
||||||
cryptoStream: cryptoStream,
|
cryptoStream: cryptoStream,
|
||||||
connectionParameters: connectionParametersManager,
|
connectionParameters: connectionParametersManager,
|
||||||
|
acceptSTKCallback: acceptSTK,
|
||||||
aeadChanged: aeadChanged,
|
aeadChanged: aeadChanged,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
@ -272,19 +275,16 @@ func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byt
|
||||||
if crypto.HashCert(cert) != xlct {
|
if crypto.HashCert(cert) != xlct {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return !h.verifySTK(cryptoData[TagSTK])
|
return !h.acceptSTK(cryptoData[TagSTK])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupServer) verifySTK(stk []byte) bool {
|
func (h *cryptoSetupServer) acceptSTK(token []byte) bool {
|
||||||
stkTime, err := h.stkGenerator.VerifyToken(h.remoteAddr, stk)
|
stk, err := h.stkGenerator.DecodeToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.Debugf("STK invalid: %s", err.Error())
|
utils.Debugf("STK invalid: %s", err.Error())
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if time.Now().After(stkTime.Add(protocol.STKExpiryTime)) {
|
return h.acceptSTKCallback(h.remoteAddr, stk)
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) {
|
func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) {
|
||||||
|
@ -303,7 +303,7 @@ func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoDa
|
||||||
TagSVID: []byte("quic-go"),
|
TagSVID: []byte("quic-go"),
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.verifySTK(cryptoData[TagSTK]) {
|
if h.acceptSTK(cryptoData[TagSTK]) {
|
||||||
proof, err := h.scfg.Sign(sni, chlo)
|
proof, err := h.scfg.Sign(sni, chlo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -167,6 +167,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||||
kexs []byte
|
kexs []byte
|
||||||
version protocol.VersionNumber
|
version protocol.VersionNumber
|
||||||
supportedVersions []protocol.VersionNumber
|
supportedVersions []protocol.VersionNumber
|
||||||
|
sourceAddrValid bool
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
@ -199,11 +200,14 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||||
stream,
|
stream,
|
||||||
cpm,
|
cpm,
|
||||||
supportedVersions,
|
supportedVersions,
|
||||||
|
nil,
|
||||||
aeadChanged,
|
aeadChanged,
|
||||||
)
|
)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
cs = csInt.(*cryptoSetupServer)
|
cs = csInt.(*cryptoSetupServer)
|
||||||
cs.stkGenerator.stkSource = &mockStkSource{}
|
cs.stkGenerator.stkSource = &mockStkSource{}
|
||||||
|
sourceAddrValid = true
|
||||||
|
cs.acceptSTKCallback = func(_ net.Addr, _ *STK) bool { return sourceAddrValid }
|
||||||
cs.keyDerivation = mockKeyDerivation
|
cs.keyDerivation = mockKeyDerivation
|
||||||
cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} }
|
cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} }
|
||||||
})
|
})
|
||||||
|
@ -264,14 +268,18 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("generates REJ messages", func() {
|
It("generates REJ messages", func() {
|
||||||
|
sourceAddrValid = false
|
||||||
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), nil)
|
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(response).To(HavePrefix("REJ"))
|
Expect(response).To(HavePrefix("REJ"))
|
||||||
Expect(response).To(ContainSubstring("initial public"))
|
Expect(response).To(ContainSubstring("initial public"))
|
||||||
|
Expect(response).ToNot(ContainSubstring("certcompressed"))
|
||||||
|
Expect(response).ToNot(ContainSubstring("proof"))
|
||||||
Expect(signer.gotCHLO).To(BeFalse())
|
Expect(signer.gotCHLO).To(BeFalse())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("REJ messages don't include cert or proof without STK", func() {
|
It("REJ messages don't include cert or proof without STK", func() {
|
||||||
|
sourceAddrValid = false
|
||||||
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), nil)
|
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(response).To(HavePrefix("REJ"))
|
Expect(response).To(HavePrefix("REJ"))
|
||||||
|
@ -281,6 +289,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("REJ messages include cert and proof with valid STK", func() {
|
It("REJ messages include cert and proof with valid STK", func() {
|
||||||
|
sourceAddrValid = true
|
||||||
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), map[Tag][]byte{
|
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), map[Tag][]byte{
|
||||||
TagSTK: validSTK,
|
TagSTK: validSTK,
|
||||||
TagSNI: []byte("foo"),
|
TagSNI: []byte("foo"),
|
||||||
|
@ -400,11 +409,6 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("REJ messages that have an expired STK", func() {
|
|
||||||
cs.stkGenerator.stkSource.(*mockStkSource).stkTime = time.Now().Add(-protocol.STKExpiryTime).Add(-time.Second)
|
|
||||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("recognizes proper CHLOs", func() {
|
It("recognizes proper CHLOs", func() {
|
||||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeFalse())
|
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeFalse())
|
||||||
})
|
})
|
||||||
|
@ -690,6 +694,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||||
|
|
||||||
Context("STK verification and creation", func() {
|
Context("STK verification and creation", func() {
|
||||||
It("requires STK", func() {
|
It("requires STK", func() {
|
||||||
|
sourceAddrValid = false
|
||||||
done, err := cs.handleMessage(
|
done, err := cs.handleMessage(
|
||||||
bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize),
|
bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize),
|
||||||
map[Tag][]byte{
|
map[Tag][]byte{
|
||||||
|
@ -703,10 +708,10 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("works with proper STK", func() {
|
It("works with proper STK", func() {
|
||||||
|
sourceAddrValid = true
|
||||||
done, err := cs.handleMessage(
|
done, err := cs.handleMessage(
|
||||||
bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize),
|
bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize),
|
||||||
map[Tag][]byte{
|
map[Tag][]byte{
|
||||||
TagSTK: validSTK,
|
|
||||||
TagSNI: []byte("foo"),
|
TagSNI: []byte("foo"),
|
||||||
TagVER: versionTag,
|
TagVER: versionTag,
|
||||||
},
|
},
|
||||||
|
@ -714,19 +719,5 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(done).To(BeFalse())
|
Expect(done).To(BeFalse())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("errors if IP does not match", func() {
|
|
||||||
done, err := cs.handleMessage(
|
|
||||||
bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize),
|
|
||||||
map[Tag][]byte{
|
|
||||||
TagSNI: []byte("foo"),
|
|
||||||
TagSTK: []byte("token \x04\x03\x03\x01"),
|
|
||||||
TagVER: versionTag,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(done).To(BeFalse())
|
|
||||||
Expect(stream.dataWritten.Bytes()).To(ContainSubstring(string(validSTK)))
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
|
@ -2,14 +2,18 @@ package handshake
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/subtle"
|
|
||||||
"errors"
|
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/crypto"
|
"github.com/lucas-clemente/quic-go/crypto"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// An STK is a source address token
|
||||||
|
type STK struct {
|
||||||
|
RemoteAddr string
|
||||||
|
SentTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
// An STKGenerator generates STKs
|
// An STKGenerator generates STKs
|
||||||
type STKGenerator struct {
|
type STKGenerator struct {
|
||||||
stkSource crypto.StkSource
|
stkSource crypto.StkSource
|
||||||
|
@ -31,16 +35,20 @@ func (g *STKGenerator) NewToken(raddr net.Addr) ([]byte, error) {
|
||||||
return g.stkSource.NewToken(encodeRemoteAddr(raddr))
|
return g.stkSource.NewToken(encodeRemoteAddr(raddr))
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyToken verifies an STK token
|
// DecodeToken decodes an STK token
|
||||||
func (g *STKGenerator) VerifyToken(raddr net.Addr, data []byte) (time.Time, error) {
|
func (g *STKGenerator) DecodeToken(data []byte) (*STK, error) {
|
||||||
data, timestamp, err := g.stkSource.DecodeToken(data)
|
// if the client didn't send any STK, DecodeToken will be called with a nil-slice
|
||||||
|
if len(data) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
remote, timestamp, err := g.stkSource.DecodeToken(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return time.Time{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if subtle.ConstantTimeCompare(encodeRemoteAddr(raddr), data) != 1 {
|
return &STK{
|
||||||
return time.Time{}, errors.New("invalid source address in STK")
|
RemoteAddr: decodeRemoteAddr(remote),
|
||||||
}
|
SentTime: timestamp,
|
||||||
return timestamp, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// encodeRemoteAddr encodes a remote address such that it can be saved in the STK
|
// encodeRemoteAddr encodes a remote address such that it can be saved in the STK
|
||||||
|
@ -55,3 +63,10 @@ func encodeRemoteAddr(remoteAddr net.Addr) []byte {
|
||||||
// that way it can be distinguished from an IP address
|
// that way it can be distinguished from an IP address
|
||||||
return append(bytes.Repeat([]byte{0}, 16), []byte(remoteAddr.String())...)
|
return append(bytes.Repeat([]byte{0}, 16), []byte(remoteAddr.String())...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func decodeRemoteAddr(data []byte) string {
|
||||||
|
if len(data) <= 16 {
|
||||||
|
return net.IP(data).String()
|
||||||
|
}
|
||||||
|
return string(data[16:])
|
||||||
|
}
|
||||||
|
|
|
@ -24,58 +24,49 @@ var _ = Describe("STK Generator", func() {
|
||||||
Expect(token).ToNot(BeEmpty())
|
Expect(token).ToNot(BeEmpty())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("accepts a valid STK", func() {
|
It("works with nil tokens", func() {
|
||||||
raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
stk, err := stkGen.DecodeToken(nil)
|
||||||
token, err := stkGen.NewToken(raddr)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
t, err := stkGen.VerifyToken(raddr, token)
|
Expect(stk).To(BeNil())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(t).To(BeTemporally("~", time.Now(), time.Second))
|
|
||||||
})
|
})
|
||||||
|
|
||||||
It("works with an IPv6 address", func() {
|
It("accepts a valid STK", func() {
|
||||||
ip := net.ParseIP("2001:db8::68")
|
ip := net.IPv4(192, 168, 0, 1)
|
||||||
|
token, err := stkGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
stk, err := stkGen.DecodeToken(token)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(stk.RemoteAddr).To(Equal("192.168.0.1"))
|
||||||
|
Expect(stk.SentTime).To(BeTemporally("~", time.Now(), time.Second))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("works with an IPv6 addresses ", func() {
|
||||||
|
addresses := []string{
|
||||||
|
"2001:db8::68",
|
||||||
|
"2001:0000:4136:e378:8000:63bf:3fff:fdd2",
|
||||||
|
"2001::1",
|
||||||
|
"ff01:0:0:0:0:0:0:2",
|
||||||
|
}
|
||||||
|
for _, addr := range addresses {
|
||||||
|
ip := net.ParseIP(addr)
|
||||||
Expect(ip).ToNot(BeNil())
|
Expect(ip).ToNot(BeNil())
|
||||||
raddr := &net.UDPAddr{IP: ip, Port: 1337}
|
raddr := &net.UDPAddr{IP: ip, Port: 1337}
|
||||||
token, err := stkGen.NewToken(raddr)
|
token, err := stkGen.NewToken(raddr)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
t, err := stkGen.VerifyToken(raddr, token)
|
stk, err := stkGen.DecodeToken(token)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(t).To(BeTemporally("~", time.Now(), time.Second))
|
Expect(stk.RemoteAddr).To(Equal(ip.String()))
|
||||||
|
Expect(stk.SentTime).To(BeTemporally("~", time.Now(), time.Second))
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
It("does not care about the port", func() {
|
It("uses the string representation an address that is not a UDP address", func() {
|
||||||
ip := net.IPv4(192, 168, 0, 1)
|
raddr := &net.TCPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337}
|
||||||
token, err := stkGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = stkGen.VerifyToken(&net.UDPAddr{IP: ip, Port: 7331}, token)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects an STK for the wrong address", func() {
|
|
||||||
ip := net.ParseIP("1.2.3.4")
|
|
||||||
otherIP := net.ParseIP("4.3.2.1")
|
|
||||||
token, err := stkGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337})
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
_, err = stkGen.VerifyToken(&net.UDPAddr{IP: otherIP, Port: 1337}, token)
|
|
||||||
Expect(err).To(MatchError("invalid source address in STK"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("works with an address that is not a UDP address", func() {
|
|
||||||
raddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
|
||||||
token, err := stkGen.NewToken(raddr)
|
token, err := stkGen.NewToken(raddr)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
t, err := stkGen.VerifyToken(raddr, token)
|
stk, err := stkGen.DecodeToken(token)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(t).To(BeTemporally("~", time.Now(), time.Second))
|
Expect(stk.RemoteAddr).To(Equal("192.168.13.37:1337"))
|
||||||
})
|
Expect(stk.SentTime).To(BeTemporally("~", time.Now(), time.Second))
|
||||||
|
|
||||||
It("uses the string representation of an address that is not a UDP address", func() {
|
|
||||||
// when using the string representation, the port matters
|
|
||||||
ip := net.IPv4(192, 168, 0, 1)
|
|
||||||
token, err := stkGen.NewToken(&net.TCPAddr{IP: ip, Port: 1337})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = stkGen.VerifyToken(&net.TCPAddr{IP: ip, Port: 7331}, token)
|
|
||||||
Expect(err).To(MatchError("invalid source address in STK"))
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
19
interface.go
19
interface.go
|
@ -4,6 +4,7 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
"github.com/lucas-clemente/quic-go/protocol"
|
||||||
)
|
)
|
||||||
|
@ -45,6 +46,18 @@ type NonFWSession interface {
|
||||||
WaitUntilHandshakeComplete() error
|
WaitUntilHandshakeComplete() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// An STK is a Source Address token.
|
||||||
|
// It is issued by the server and sent to the client. For the client, it is an opaque blob.
|
||||||
|
// The client can send the STK in subsequent handshakes to prove ownership of its IP address.
|
||||||
|
type STK struct {
|
||||||
|
// The remote address this token was issued for.
|
||||||
|
// If the server is run on a net.UDPConn, this is the string representation of the IP address (net.IP.String())
|
||||||
|
// Otherwise, this is the string representation of the net.Addr (net.Addr.String())
|
||||||
|
remoteAddr string
|
||||||
|
// The time that the STK was issued (resolution 1 second)
|
||||||
|
sentTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
// Config contains all configuration data needed for a QUIC server or client.
|
// Config contains all configuration data needed for a QUIC server or client.
|
||||||
// More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441.
|
// More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
@ -54,9 +67,15 @@ type Config struct {
|
||||||
// Warning: This API should not be considered stable and will change soon.
|
// Warning: This API should not be considered stable and will change soon.
|
||||||
Versions []protocol.VersionNumber
|
Versions []protocol.VersionNumber
|
||||||
// Ask the server to truncate the connection ID sent in the Public Header.
|
// Ask the server to truncate the connection ID sent in the Public Header.
|
||||||
|
// If not set, the default checks if
|
||||||
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
|
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
|
||||||
// Currently only valid for the client.
|
// Currently only valid for the client.
|
||||||
RequestConnectionIDTruncation bool
|
RequestConnectionIDTruncation bool
|
||||||
|
// AcceptSTK determines if an STK is accepted.
|
||||||
|
// It is called with stk = nil if the client didn't send an STK.
|
||||||
|
// If not set, it verifies that the address matches, and that the STK was issued within the last 24 hours
|
||||||
|
// This option is only valid for the server.
|
||||||
|
AcceptSTK func(clientAddr net.Addr, stk *STK) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// A Listener for incoming QUIC connections
|
// A Listener for incoming QUIC connections
|
||||||
|
|
23
server.go
23
server.go
|
@ -85,15 +85,36 @@ func Listen(conn net.PacketConn, config *Config) (Listener, error) {
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var defaultAcceptSTK = func(clientAddr net.Addr, stk *STK) bool {
|
||||||
|
if stk == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if time.Now().After(stk.sentTime.Add(protocol.STKExpiryTime)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var sourceAddr string
|
||||||
|
if udpAddr, ok := clientAddr.(*net.UDPAddr); ok {
|
||||||
|
sourceAddr = udpAddr.IP.String()
|
||||||
|
} else {
|
||||||
|
sourceAddr = clientAddr.String()
|
||||||
|
}
|
||||||
|
return sourceAddr == stk.remoteAddr
|
||||||
|
}
|
||||||
|
|
||||||
func populateServerConfig(config *Config) *Config {
|
func populateServerConfig(config *Config) *Config {
|
||||||
versions := config.Versions
|
versions := config.Versions
|
||||||
if len(versions) == 0 {
|
if len(versions) == 0 {
|
||||||
versions = protocol.SupportedVersions
|
versions = protocol.SupportedVersions
|
||||||
}
|
}
|
||||||
|
vsa := defaultAcceptSTK
|
||||||
|
if config.AcceptSTK != nil {
|
||||||
|
vsa = config.AcceptSTK
|
||||||
|
}
|
||||||
|
|
||||||
return &Config{
|
return &Config{
|
||||||
TLSConfig: config.TLSConfig,
|
TLSConfig: config.TLSConfig,
|
||||||
Versions: versions,
|
Versions: versions,
|
||||||
|
AcceptSTK: vsa,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -116,7 +137,7 @@ func (s *server) serve() {
|
||||||
utils.Errorf("error handling packet: %s", err.Error())
|
utils.Errorf("error handling packet: %s", err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accept returns newly openend sessions
|
// Accept returns newly openend sessions
|
||||||
func (s *server) Accept() (Session, error) {
|
func (s *server) Accept() (Session, error) {
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
"reflect"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/crypto"
|
"github.com/lucas-clemente/quic-go/crypto"
|
||||||
|
@ -342,17 +343,20 @@ var _ = Describe("Server", func() {
|
||||||
|
|
||||||
It("setups with the right values", func() {
|
It("setups with the right values", func() {
|
||||||
supportedVersions := []protocol.VersionNumber{1, 3, 5}
|
supportedVersions := []protocol.VersionNumber{1, 3, 5}
|
||||||
|
acceptSTK := func(_ net.Addr, _ *STK) bool { return true }
|
||||||
config := Config{
|
config := Config{
|
||||||
TLSConfig: &tls.Config{},
|
TLSConfig: &tls.Config{},
|
||||||
Versions: supportedVersions,
|
Versions: supportedVersions,
|
||||||
|
AcceptSTK: acceptSTK,
|
||||||
}
|
}
|
||||||
ln, err := Listen(conn, &config)
|
ln, err := Listen(conn, &config)
|
||||||
server := ln.(*server)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
server := ln.(*server)
|
||||||
Expect(server.deleteClosedSessionsAfter).To(Equal(protocol.ClosedSessionDeleteTimeout))
|
Expect(server.deleteClosedSessionsAfter).To(Equal(protocol.ClosedSessionDeleteTimeout))
|
||||||
Expect(server.sessions).ToNot(BeNil())
|
Expect(server.sessions).ToNot(BeNil())
|
||||||
Expect(server.scfg).ToNot(BeNil())
|
Expect(server.scfg).ToNot(BeNil())
|
||||||
Expect(server.config.Versions).To(Equal(supportedVersions))
|
Expect(server.config.Versions).To(Equal(supportedVersions))
|
||||||
|
Expect(reflect.ValueOf(server.config.AcceptSTK)).To(Equal(reflect.ValueOf(acceptSTK)))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("fills in default values if options are not set in the Config", func() {
|
It("fills in default values if options are not set in the Config", func() {
|
||||||
|
@ -361,6 +365,7 @@ var _ = Describe("Server", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
server := ln.(*server)
|
server := ln.(*server)
|
||||||
Expect(server.config.Versions).To(Equal(protocol.SupportedVersions))
|
Expect(server.config.Versions).To(Equal(protocol.SupportedVersions))
|
||||||
|
Expect(reflect.ValueOf(server.config.AcceptSTK)).To(Equal(reflect.ValueOf(defaultAcceptSTK)))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("listens on a given address", func() {
|
It("listens on a given address", func() {
|
||||||
|
@ -434,3 +439,55 @@ var _ = Describe("Server", func() {
|
||||||
Expect(ln.(*server).sessions).To(BeEmpty())
|
Expect(ln.(*server).sessions).To(BeEmpty())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
var _ = Describe("default source address verification", func() {
|
||||||
|
It("accepts a token", func() {
|
||||||
|
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
|
||||||
|
stk := &STK{
|
||||||
|
remoteAddr: "192.168.0.1",
|
||||||
|
sentTime: time.Now().Add(-protocol.STKExpiryTime).Add(time.Second), // will expire in 1 second
|
||||||
|
}
|
||||||
|
Expect(defaultAcceptSTK(remoteAddr, stk)).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("requests verification if no token is provided", func() {
|
||||||
|
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
|
||||||
|
Expect(defaultAcceptSTK(remoteAddr, nil)).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects a token if the address doesn't match", func() {
|
||||||
|
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
|
||||||
|
stk := &STK{
|
||||||
|
remoteAddr: "127.0.0.1",
|
||||||
|
sentTime: time.Now(),
|
||||||
|
}
|
||||||
|
Expect(defaultAcceptSTK(remoteAddr, stk)).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("accepts a token for a remote address is not a UDP address", func() {
|
||||||
|
remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||||
|
stk := &STK{
|
||||||
|
remoteAddr: "192.168.0.1:1337",
|
||||||
|
sentTime: time.Now(),
|
||||||
|
}
|
||||||
|
Expect(defaultAcceptSTK(remoteAddr, stk)).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects an invalid token for a remote address is not a UDP address", func() {
|
||||||
|
remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||||
|
stk := &STK{
|
||||||
|
remoteAddr: "192.168.0.1:7331", // mismatching port
|
||||||
|
sentTime: time.Now(),
|
||||||
|
}
|
||||||
|
Expect(defaultAcceptSTK(remoteAddr, stk)).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects an expired token", func() {
|
||||||
|
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
|
||||||
|
stk := &STK{
|
||||||
|
remoteAddr: "192.168.0.1",
|
||||||
|
sentTime: time.Now().Add(-protocol.STKExpiryTime).Add(-time.Second), // expired 1 second ago
|
||||||
|
}
|
||||||
|
Expect(defaultAcceptSTK(remoteAddr, stk)).To(BeFalse())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
10
session.go
10
session.go
|
@ -141,6 +141,15 @@ func newSession(
|
||||||
s.aeadChanged = aeadChanged
|
s.aeadChanged = aeadChanged
|
||||||
handshakeChan := make(chan handshakeEvent, 3)
|
handshakeChan := make(chan handshakeEvent, 3)
|
||||||
s.handshakeChan = handshakeChan
|
s.handshakeChan = handshakeChan
|
||||||
|
verifySourceAddr := func(clientAddr net.Addr, hstk *handshake.STK) bool {
|
||||||
|
if hstk == nil {
|
||||||
|
return config.AcceptSTK(clientAddr, nil)
|
||||||
|
}
|
||||||
|
return config.AcceptSTK(
|
||||||
|
clientAddr,
|
||||||
|
&STK{remoteAddr: hstk.RemoteAddr, sentTime: hstk.SentTime},
|
||||||
|
)
|
||||||
|
}
|
||||||
var err error
|
var err error
|
||||||
s.cryptoSetup, err = newCryptoSetup(
|
s.cryptoSetup, err = newCryptoSetup(
|
||||||
connectionID,
|
connectionID,
|
||||||
|
@ -150,6 +159,7 @@ func newSession(
|
||||||
cryptoStream,
|
cryptoStream,
|
||||||
s.connectionParameters,
|
s.connectionParameters,
|
||||||
config.Versions,
|
config.Versions,
|
||||||
|
verifySourceAddr,
|
||||||
aeadChanged,
|
aeadChanged,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -145,6 +145,7 @@ var _ = Describe("Session", func() {
|
||||||
_ io.ReadWriter,
|
_ io.ReadWriter,
|
||||||
_ handshake.ConnectionParametersManager,
|
_ handshake.ConnectionParametersManager,
|
||||||
_ []protocol.VersionNumber,
|
_ []protocol.VersionNumber,
|
||||||
|
_ func(net.Addr, *handshake.STK) bool,
|
||||||
aeadChangedP chan<- protocol.EncryptionLevel,
|
aeadChangedP chan<- protocol.EncryptionLevel,
|
||||||
) (handshake.CryptoSetup, error) {
|
) (handshake.CryptoSetup, error) {
|
||||||
aeadChanged = aeadChangedP
|
aeadChanged = aeadChangedP
|
||||||
|
@ -180,6 +181,64 @@ var _ = Describe("Session", func() {
|
||||||
Eventually(areSessionsRunning).Should(BeFalse())
|
Eventually(areSessionsRunning).Should(BeFalse())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Context("source address validation", func() {
|
||||||
|
var (
|
||||||
|
stkVerify func(net.Addr, *handshake.STK) bool
|
||||||
|
paramClientAddr net.Addr
|
||||||
|
paramSTK *STK
|
||||||
|
)
|
||||||
|
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1000}
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
newCryptoSetup = func(
|
||||||
|
_ protocol.ConnectionID,
|
||||||
|
_ net.Addr,
|
||||||
|
_ protocol.VersionNumber,
|
||||||
|
_ *handshake.ServerConfig,
|
||||||
|
_ io.ReadWriter,
|
||||||
|
_ handshake.ConnectionParametersManager,
|
||||||
|
_ []protocol.VersionNumber,
|
||||||
|
stkFunc func(net.Addr, *handshake.STK) bool,
|
||||||
|
_ chan<- protocol.EncryptionLevel,
|
||||||
|
) (handshake.CryptoSetup, error) {
|
||||||
|
stkVerify = stkFunc
|
||||||
|
return cryptoSetup, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
conf := populateServerConfig(&Config{})
|
||||||
|
conf.AcceptSTK = func(clientAddr net.Addr, stk *STK) bool {
|
||||||
|
paramClientAddr = clientAddr
|
||||||
|
paramSTK = stk
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
pSess, _, err := newSession(
|
||||||
|
mconn,
|
||||||
|
protocol.Version35,
|
||||||
|
0,
|
||||||
|
scfg,
|
||||||
|
conf,
|
||||||
|
)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
sess = pSess.(*session)
|
||||||
|
})
|
||||||
|
|
||||||
|
It("calls the callback with the right parameters when the client didn't send an STK", func() {
|
||||||
|
stkVerify(remoteAddr, nil)
|
||||||
|
Expect(paramClientAddr).To(Equal(remoteAddr))
|
||||||
|
Expect(paramSTK).To(BeNil())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("calls the callback with the STK when the client sent an STK", func() {
|
||||||
|
stkAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||||
|
sentTime := time.Now().Add(-time.Hour)
|
||||||
|
stkVerify(remoteAddr, &handshake.STK{SentTime: sentTime, RemoteAddr: stkAddr.String()})
|
||||||
|
Expect(paramClientAddr).To(Equal(remoteAddr))
|
||||||
|
Expect(paramSTK).ToNot(BeNil())
|
||||||
|
Expect(paramSTK.remoteAddr).To(Equal(stkAddr.String()))
|
||||||
|
Expect(paramSTK.sentTime).To(Equal(sentTime))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
Context("when handling stream frames", func() {
|
Context("when handling stream frames", func() {
|
||||||
It("makes new streams", func() {
|
It("makes new streams", func() {
|
||||||
sess.handleStreamFrame(&frames.StreamFrame{
|
sess.handleStreamFrame(&frames.StreamFrame{
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue