implement a more intuitive address validation API

This commit is contained in:
Marten Seemann 2022-08-04 00:28:13 +02:00
parent 556a6e2f99
commit f2fa98c0dd
14 changed files with 352 additions and 437 deletions

View file

@ -2,11 +2,11 @@ package quic
import (
"errors"
"net"
"time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// Clone clones a Config
@ -39,8 +39,14 @@ func populateServerConfig(config *Config) *Config {
if config.ConnectionIDLength == 0 {
config.ConnectionIDLength = protocol.DefaultConnectionIDLength
}
if config.AcceptToken == nil {
config.AcceptToken = defaultAcceptToken
if config.MaxTokenAge == 0 {
config.MaxTokenAge = protocol.TokenValidity
}
if config.MaxRetryTokenAge == 0 {
config.MaxRetryTokenAge = protocol.RetryTokenValidity
}
if config.RequireAddressValidation == nil {
config.RequireAddressValidation = func(net.Addr) bool { return true }
}
return config
}
@ -104,7 +110,9 @@ func populateConfig(config *Config) *Config {
Versions: versions,
HandshakeIdleTimeout: handshakeIdleTimeout,
MaxIdleTimeout: idleTimeout,
AcceptToken: config.AcceptToken,
MaxTokenAge: config.MaxTokenAge,
MaxRetryTokenAge: config.MaxRetryTokenAge,
RequireAddressValidation: config.RequireAddressValidation,
KeepAlivePeriod: config.KeepAlivePeriod,
InitialStreamReceiveWindow: initialStreamReceiveWindow,
MaxStreamReceiveWindow: maxStreamReceiveWindow,

View file

@ -45,7 +45,7 @@ var _ = Describe("Config", func() {
}
switch fn := typ.Field(i).Name; fn {
case "AcceptToken", "GetLogWriter", "AllowConnectionWindowIncrease":
case "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease":
// Can't compare functions.
case "Versions":
f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3}))
@ -55,6 +55,10 @@ var _ = Describe("Config", func() {
f.Set(reflect.ValueOf(time.Second))
case "MaxIdleTimeout":
f.Set(reflect.ValueOf(time.Hour))
case "MaxTokenAge":
f.Set(reflect.ValueOf(2 * time.Hour))
case "MaxRetryTokenAge":
f.Set(reflect.ValueOf(2 * time.Minute))
case "TokenStore":
f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3)))
case "InitialStreamReceiveWindow":
@ -100,14 +104,14 @@ var _ = Describe("Config", func() {
Context("cloning", func() {
It("clones function fields", func() {
var calledAcceptToken, calledAllowConnectionWindowIncrease bool
var calledAddrValidation, calledAllowConnectionWindowIncrease bool
c1 := &Config{
AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true },
AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true },
RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true },
}
c2 := c1.Clone()
c2.AcceptToken(&net.UDPAddr{}, &Token{})
Expect(calledAcceptToken).To(BeTrue())
c2.RequireAddressValidation(&net.UDPAddr{})
Expect(calledAddrValidation).To(BeTrue())
c2.AllowConnectionWindowIncrease(nil, 1234)
Expect(calledAllowConnectionWindowIncrease).To(BeTrue())
})
@ -120,26 +124,25 @@ var _ = Describe("Config", func() {
It("returns a copy", func() {
c1 := &Config{
MaxIncomingStreams: 100,
AcceptToken: func(_ net.Addr, _ *Token) bool { return true },
RequireAddressValidation: func(net.Addr) bool { return true },
}
c2 := c1.Clone()
c2.MaxIncomingStreams = 200
c2.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
c2.RequireAddressValidation = func(net.Addr) bool { return false }
Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100))
Expect(c1.AcceptToken(&net.UDPAddr{}, nil)).To(BeTrue())
Expect(c1.RequireAddressValidation(&net.UDPAddr{})).To(BeTrue())
})
})
Context("populating", func() {
It("populates function fields", func() {
var calledAcceptToken bool
c1 := &Config{
AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true },
}
var calledAddrValidation bool
c1 := &Config{}
c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true }
c2 := populateConfig(c1)
c2.AcceptToken(&net.UDPAddr{}, &Token{})
Expect(calledAcceptToken).To(BeTrue())
c2.RequireAddressValidation(&net.UDPAddr{})
Expect(calledAddrValidation).To(BeTrue())
})
It("copies non-function fields", func() {
@ -164,7 +167,7 @@ var _ = Describe("Config", func() {
It("populates empty fields with default values, for the server", func() {
c := populateServerConfig(&Config{})
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
Expect(c.AcceptToken).ToNot(BeNil())
Expect(c.RequireAddressValidation).ToNot(BeNil())
})
It("sets a default connection ID length if we didn't create the conn, for the client", func() {

View file

@ -2,7 +2,6 @@ package tokens
import (
"encoding/binary"
"fmt"
"math/rand"
"net"
"time"
@ -77,7 +76,6 @@ func newToken(tg *handshake.TokenGenerator, data []byte) int {
if token.OriginalDestConnectionID != nil || token.RetrySrcConnectionID != nil {
panic("didn't expect connection IDs")
}
checkAddr(token.RemoteAddr, addr)
return 1
}
@ -140,22 +138,5 @@ func newRetryToken(tg *handshake.TokenGenerator, data []byte) int {
if !token.RetrySrcConnectionID.Equal(retrySrcConnID) {
panic("retry src conn ID doesn't match")
}
checkAddr(token.RemoteAddr, addr)
return 1
}
func checkAddr(tokenAddr string, addr net.Addr) {
if udpAddr, ok := addr.(*net.UDPAddr); ok {
// For UDP addresses, we encode only the IP (not the port).
if ip := udpAddr.IP.String(); tokenAddr != ip {
fmt.Printf("%s vs %s", tokenAddr, ip)
panic("wrong remote address for a net.UDPAddr")
}
return
}
if tokenAddr != addr.String() {
fmt.Printf("%s vs %s", tokenAddr, addr.String())
panic("wrong remote address")
}
}

View file

@ -41,8 +41,8 @@ var _ = Describe("Handshake drop tests", func() {
HandshakeIdleTimeout: timeout,
Versions: []protocol.VersionNumber{version},
})
if !doRetry {
conf.AcceptToken = func(net.Addr, *quic.Token) bool { return true }
if doRetry {
conf.RequireAddressValidation = func(net.Addr) bool { return true }
}
var tlsConf *tls.Config
if longCertChain {

View file

@ -112,9 +112,7 @@ var _ = Describe("Handshake RTT tests", func() {
})
It("establishes a connection in 1 RTT when the server doesn't require a token", func() {
serverConfig.AcceptToken = func(_ net.Addr, _ *quic.Token) bool {
return true
}
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }
runServerAndProxy()
_, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
@ -126,9 +124,7 @@ var _ = Describe("Handshake RTT tests", func() {
})
It("establishes a connection in 2 RTTs if a HelloRetryRequest is performed", func() {
serverConfig.AcceptToken = func(_ net.Addr, _ *quic.Token) bool {
return true
}
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }
serverTLSConfig.CurvePreferences = []tls.CurveID{tls.CurveP384}
runServerAndProxy()
_, err := quic.DialAddr(
@ -139,21 +135,4 @@ var _ = Describe("Handshake RTT tests", func() {
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2)
})
It("doesn't complete the handshake when the server never accepts the token", func() {
serverConfig.AcceptToken = func(_ net.Addr, _ *quic.Token) bool {
return false
}
clientConfig.HandshakeIdleTimeout = 500 * time.Millisecond
runServerAndProxy()
_, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
getTLSClientConfig(),
clientConfig,
)
Expect(err).To(HaveOccurred())
nerr, ok := err.(net.Error)
Expect(ok).To(BeTrue())
Expect(nerr.Timeout()).To(BeTrue())
})
})

View file

@ -344,12 +344,7 @@ var _ = Describe("Handshake tests", func() {
}
BeforeEach(func() {
serverConfig.AcceptToken = func(addr net.Addr, token *quic.Token) bool {
if token != nil {
Expect(token.IsRetryToken).To(BeFalse())
}
return true
}
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }
var err error
// start the server, but don't call Accept
server, err = quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
@ -479,13 +474,7 @@ var _ = Describe("Handshake tests", func() {
Context("using tokens", func() {
It("uses tokens provided in NEW_TOKEN frames", func() {
tokenChan := make(chan *quic.Token, 100)
serverConfig.AcceptToken = func(addr net.Addr, token *quic.Token) bool {
if token != nil && !token.IsRetryToken {
tokenChan <- token
}
return true
}
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
@ -509,7 +498,6 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
Expect(gets).To(Receive())
Eventually(puts).Should(Receive())
Expect(tokenChan).ToNot(Receive())
// received a token. Close this connection.
Expect(conn.CloseWithError(0, "")).To(Succeed())
@ -529,17 +517,13 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
Expect(gets).To(Receive())
Expect(tokenChan).To(Receive())
Eventually(done).Should(BeClosed())
})
It("rejects invalid Retry token with the INVALID_TOKEN error", func() {
tokenChan := make(chan *quic.Token, 10)
serverConfig.AcceptToken = func(addr net.Addr, token *quic.Token) bool {
tokenChan <- token
return false
}
serverConfig.RequireAddressValidation = func(net.Addr) bool { return true }
serverConfig.MaxRetryTokenAge = time.Nanosecond
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
@ -554,18 +538,6 @@ var _ = Describe("Handshake tests", func() {
var transportErr *quic.TransportError
Expect(errors.As(err, &transportErr)).To(BeTrue())
Expect(transportErr.ErrorCode).To(Equal(quic.InvalidToken))
// Receiving a Retry might lead the client to measure a very small RTT.
// Then, it sometimes would retransmit the ClientHello before receiving the ServerHello.
Expect(len(tokenChan)).To(BeNumerically(">=", 2))
token := <-tokenChan
Expect(token).To(BeNil())
token = <-tokenChan
Expect(token).ToNot(BeNil())
// If the ClientHello was retransmitted, make sure that it contained the same Retry token.
for i := 2; i < len(tokenChan); i++ {
Expect(<-tokenChan).To(Equal(token))
}
Expect(token.IsRetryToken).To(BeTrue())
})
})

View file

@ -26,7 +26,7 @@ var _ = Describe("Packetization", func() {
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{
AcceptToken: func(net.Addr, *quic.Token) bool { return true },
RequireAddressValidation: func(net.Addr) bool { return false },
DisablePathMTUDiscovery: true,
Tracer: newTracer(func() logging.ConnectionTracer { return serverTracer }),
}),

View file

@ -56,7 +56,7 @@ var _ = Describe("0-RTT", func() {
tlsConf := getTLSConfig()
if serverConf == nil {
serverConf = getQuicConfig(&quic.Config{
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
RequireAddressValidation: func(net.Addr) bool { return false },
})
serverConf.Versions = []protocol.VersionNumber{version}
}
@ -198,7 +198,7 @@ var _ = Describe("0-RTT", func() {
tlsConf,
getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
RequireAddressValidation: func(net.Addr) bool { return false },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
@ -256,7 +256,7 @@ var _ = Describe("0-RTT", func() {
tlsConf,
getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
RequireAddressValidation: func(net.Addr) bool { return false },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
@ -453,7 +453,7 @@ var _ = Describe("0-RTT", func() {
const maxStreams = 1
tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{
MaxIncomingUniStreams: maxStreams,
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
RequireAddressValidation: func(net.Addr) bool { return false },
}))
tracer := newPacketTracer()
@ -462,7 +462,7 @@ var _ = Describe("0-RTT", func() {
tlsConf,
getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
RequireAddressValidation: func(net.Addr) bool { return false },
MaxIncomingUniStreams: maxStreams + 1,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
@ -499,7 +499,7 @@ var _ = Describe("0-RTT", func() {
const maxStreams = 42
tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{
MaxIncomingStreams: maxStreams,
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
RequireAddressValidation: func(net.Addr) bool { return false },
}))
tracer := newPacketTracer()
@ -508,7 +508,7 @@ var _ = Describe("0-RTT", func() {
tlsConf,
getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
RequireAddressValidation: func(net.Addr) bool { return false },
MaxIncomingStreams: maxStreams - 1,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
@ -538,7 +538,7 @@ var _ = Describe("0-RTT", func() {
tlsConf,
getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
RequireAddressValidation: func(net.Addr) bool { return false },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
@ -560,7 +560,7 @@ var _ = Describe("0-RTT", func() {
func(addFlowControlLimit func(*quic.Config, uint64)) {
tracer := newPacketTracer()
firstConf := getQuicConfig(&quic.Config{
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
RequireAddressValidation: func(net.Addr) bool { return false },
Versions: []protocol.VersionNumber{version},
})
addFlowControlLimit(firstConf, 3)
@ -568,7 +568,7 @@ var _ = Describe("0-RTT", func() {
secondConf := getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
RequireAddressValidation: func(net.Addr) bool { return false },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
})
addFlowControlLimit(secondConf, 100)
@ -723,7 +723,7 @@ var _ = Describe("0-RTT", func() {
tlsConf,
getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
RequireAddressValidation: func(net.Addr) bool { return false },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)

View file

@ -26,16 +26,6 @@ const (
Version2 = protocol.Version2
)
// A Token can be used to verify the ownership of the client address.
type Token struct {
// IsRetryToken encodes how the client received the token. There are two ways:
// * In a Retry packet sent when trying to establish a new connection.
// * In a NEW_TOKEN frame on a previous connection.
IsRetryToken bool
RemoteAddr string
SentTime time.Time
}
// A ClientToken is a token received by the client.
// It can be used to skip address validation on future connection attempts.
type ClientToken struct {
@ -233,14 +223,18 @@ type Config struct {
// If the timeout is exceeded, the connection is closed.
// If this value is zero, the timeout is set to 30 seconds.
MaxIdleTimeout time.Duration
// AcceptToken determines if a Token is accepted.
// It is called with token = nil if the client didn't send a token.
// If not set, a default verification function is used:
// * it verifies that the address matches, and
// * if the token is a retry token, that it was issued within the last 5 seconds
// * else, that it was issued within the last 24 hours.
// This option is only valid for the server.
AcceptToken func(clientAddr net.Addr, token *Token) bool
// RequireAddressValidation determines if a QUIC Retry packet is sent.
// This allows the server to verify the client's address, at the cost of increasing the handshake latency by 1 RTT.
// See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details.
// If not set, every client is forced to prove its remote address.
RequireAddressValidation func(net.Addr) bool
// MaxRetryTokenAge is the maximum age of a Retry token.
// If not set, it defaults to 5 seconds. Only valid for a server.
MaxRetryTokenAge time.Duration
// MaxTokenAge is the maximum age of the token presented during the handshake,
// for tokens that were issued on a previous connection.
// If not set, it defaults to 24 hours. Only valid for a server.
MaxTokenAge time.Duration
// The TokenStore stores tokens received from the server.
// Tokens are used to skip address validation on future connection attempts.
// The key used to store tokens is the ServerName from the tls.Config, if set

View file

@ -1,6 +1,7 @@
package handshake
import (
"bytes"
"encoding/asn1"
"fmt"
"io"
@ -18,13 +19,17 @@ const (
// A Token is derived from the client address and can be used to verify the ownership of this address.
type Token struct {
IsRetryToken bool
RemoteAddr string
SentTime time.Time
encodedRemoteAddr []byte
// only set for retry tokens
OriginalDestConnectionID protocol.ConnectionID
RetrySrcConnectionID protocol.ConnectionID
}
func (t *Token) ValidateRemoteAddr(addr net.Addr) bool {
return bytes.Equal(encodeRemoteAddr(addr), t.encodedRemoteAddr)
}
// token is the struct that is used for ASN1 serialization and deserialization
type token struct {
IsRetryToken bool
@ -102,8 +107,8 @@ func (g *TokenGenerator) DecodeToken(encrypted []byte) (*Token, error) {
}
token := &Token{
IsRetryToken: t.IsRetryToken,
RemoteAddr: decodeRemoteAddr(t.RemoteAddr),
SentTime: time.Unix(0, t.Timestamp),
encodedRemoteAddr: t.RemoteAddr,
}
if t.IsRetryToken {
token.OriginalDestConnectionID = protocol.ConnectionID(t.OriginalDestConnectionID)
@ -119,16 +124,3 @@ func encodeRemoteAddr(remoteAddr net.Addr) []byte {
}
return append([]byte{tokenPrefixString}, []byte(remoteAddr.String())...)
}
// decodeRemoteAddr decodes the remote address saved in the token
func decodeRemoteAddr(data []byte) string {
// data will never be empty for a token that we generated.
// Check it to be on the safe side
if len(data) == 0 {
return ""
}
if data[0] == tokenPrefixIP {
return net.IP(data[1:]).String()
}
return string(data[1:])
}

View file

@ -35,16 +35,13 @@ var _ = Describe("Token Generator", func() {
})
It("accepts a valid token", func() {
ip := net.IPv4(192, 168, 0, 1)
tokenEnc, err := tokenGen.NewRetryToken(
&net.UDPAddr{IP: ip, Port: 1337},
nil,
nil,
)
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
tokenEnc, err := tokenGen.NewRetryToken(addr, nil, nil)
Expect(err).ToNot(HaveOccurred())
token, err := tokenGen.DecodeToken(tokenEnc)
Expect(err).ToNot(HaveOccurred())
Expect(token.RemoteAddr).To(Equal("192.168.0.1"))
Expect(token.ValidateRemoteAddr(addr)).To(BeTrue())
Expect(token.ValidateRemoteAddr(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 2), Port: 1337})).To(BeFalse())
Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond))
Expect(token.OriginalDestConnectionID.Len()).To(BeZero())
Expect(token.RetrySrcConnectionID.Len()).To(BeZero())
@ -110,7 +107,7 @@ var _ = Describe("Token Generator", func() {
Expect(err).ToNot(HaveOccurred())
token, err := tokenGen.DecodeToken(tokenEnc)
Expect(err).ToNot(HaveOccurred())
Expect(token.RemoteAddr).To(Equal(ip.String()))
Expect(token.ValidateRemoteAddr(raddr)).To(BeTrue())
Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond))
}
})
@ -121,7 +118,8 @@ var _ = Describe("Token Generator", func() {
Expect(err).ToNot(HaveOccurred())
token, err := tokenGen.DecodeToken(tokenEnc)
Expect(err).ToNot(HaveOccurred())
Expect(token.RemoteAddr).To(Equal("192.168.13.37:1337"))
Expect(token.ValidateRemoteAddr(raddr)).To(BeTrue())
Expect(token.ValidateRemoteAddr(&net.TCPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1338})).To(BeFalse())
Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond))
})
})

View file

@ -44,7 +44,7 @@ func main() {
}
// a quic.Config that doesn't do a Retry
quicConf := &quic.Config{
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
RequireAddressValidation: func(net.Addr) bool { return testcase == "retry" },
Tracer: qlog.NewTracer(getLogWriter),
}
cert, err := tls.LoadX509KeyPair("/certs/cert.pem", "/certs/priv.key")
@ -58,15 +58,11 @@ func main() {
}
switch testcase {
case "versionnegotiation", "handshake", "transfer", "resumption", "zerortt", "multiconnect":
case "versionnegotiation", "handshake", "retry", "transfer", "resumption", "zerortt", "multiconnect":
err = runHTTP09Server(quicConf)
case "chacha20":
tlsConf.CipherSuites = []uint16{tls.TLS_CHACHA20_POLY1305_SHA256}
err = runHTTP09Server(quicConf)
case "retry":
// By default, quic-go performs a Retry on every incoming connection.
quicConf.AcceptToken = nil
err = runHTTP09Server(quicConf)
case "http3":
err = runHTTP3Server(quicConf)
default:

View file

@ -241,26 +241,6 @@ func (s *baseServer) run() {
}
}
var defaultAcceptToken = func(clientAddr net.Addr, token *Token) bool {
if token == nil {
return false
}
validity := protocol.TokenValidity
if token.IsRetryToken {
validity = protocol.RetryTokenValidity
}
if time.Now().After(token.SentTime.Add(validity)) {
return false
}
var sourceAddr string
if udpAddr, ok := clientAddr.(*net.UDPAddr); ok {
sourceAddr = udpAddr.IP.String()
} else {
sourceAddr = clientAddr.String()
}
return sourceAddr == token.RemoteAddr
}
// Accept returns connections that already completed the handshake.
// It is only valid if acceptEarlyConns is false.
func (s *baseServer) Accept(ctx context.Context) (Connection, error) {
@ -405,33 +385,45 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
}
var (
token *Token
token *handshake.Token
retrySrcConnID *protocol.ConnectionID
)
origDestConnID := hdr.DestConnectionID
if len(hdr.Token) > 0 {
c, err := s.tokenGenerator.DecodeToken(hdr.Token)
tok, err := s.tokenGenerator.DecodeToken(hdr.Token)
if err == nil {
token = &Token{
IsRetryToken: c.IsRetryToken,
RemoteAddr: c.RemoteAddr,
SentTime: c.SentTime,
if tok.IsRetryToken {
origDestConnID = tok.OriginalDestConnectionID
retrySrcConnID = &tok.RetrySrcConnectionID
}
if token.IsRetryToken {
origDestConnID = c.OriginalDestConnectionID
retrySrcConnID = &c.RetrySrcConnectionID
token = tok
}
}
}
if !s.config.AcceptToken(p.remoteAddr, token) {
if token != nil {
addrIsValid := token.ValidateRemoteAddr(p.remoteAddr)
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
// We just ignore them, and act as if there was no token on this packet at all.
// This also means we might send a Retry later.
if !token.IsRetryToken && (time.Since(token.SentTime) > s.config.MaxTokenAge || !addrIsValid) {
token = nil
} else if token.IsRetryToken && (time.Since(token.SentTime) > s.config.MaxRetryTokenAge || !addrIsValid) {
// For Retry tokens, we send an INVALID_ERROR if
// * the token is too old, or
// * the token is invalid, in case of a retry token.
go func() {
defer p.buffer.Release()
if token != nil && token.IsRetryToken {
if err := s.maybeSendInvalidToken(p, hdr); err != nil {
s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err)
}
return
}
}()
return nil
}
}
if token == nil && s.config.RequireAddressValidation(p.remoteAddr) {
go func() {
defer p.buffer.Release()
if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil {
s.logger.Debugf("Error sending Retry: %s", err)
}

View file

@ -126,22 +126,22 @@ var _ = Describe("Server", func() {
Expect(server.config.Versions).To(Equal(protocol.SupportedVersions))
Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
Expect(reflect.ValueOf(server.config.AcceptToken)).To(Equal(reflect.ValueOf(defaultAcceptToken)))
Expect(server.config.KeepAlivePeriod).To(Equal(0 * time.Second))
Expect(server.config.RequireAddressValidation).ToNot(BeNil())
Expect(server.config.KeepAlivePeriod).To(BeZero())
// stop the listener
Expect(ln.Close()).To(Succeed())
})
It("setups with the right values", func() {
supportedVersions := []protocol.VersionNumber{protocol.VersionTLS}
acceptToken := func(_ net.Addr, _ *Token) bool { return true }
requireAddrVal := func(net.Addr) bool { return true }
config := Config{
Versions: supportedVersions,
AcceptToken: acceptToken,
HandshakeIdleTimeout: 1337 * time.Hour,
MaxIdleTimeout: 42 * time.Minute,
KeepAlivePeriod: 5 * time.Second,
StatelessResetKey: []byte("foobar"),
RequireAddressValidation: requireAddrVal,
}
ln, err := Listen(conn, tlsConf, &config)
Expect(err).ToNot(HaveOccurred())
@ -150,7 +150,7 @@ var _ = Describe("Server", func() {
Expect(server.config.Versions).To(Equal(supportedVersions))
Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour))
Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute))
Expect(reflect.ValueOf(server.config.AcceptToken)).To(Equal(reflect.ValueOf(acceptToken)))
Expect(reflect.ValueOf(server.config.RequireAddressValidation)).To(Equal(reflect.ValueOf(requireAddrVal)))
Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second))
Expect(server.config.StatelessResetKey).To(Equal([]byte("foobar")))
// stop the listener
@ -239,60 +239,8 @@ var _ = Describe("Server", func() {
time.Sleep(50 * time.Millisecond)
})
It("decodes the token from the Token field", func() {
raddr := &net.UDPAddr{
IP: net.IPv4(192, 168, 13, 37),
Port: 1337,
}
done := make(chan struct{})
serv.config.AcceptToken = func(addr net.Addr, token *Token) bool {
Expect(addr).To(Equal(raddr))
Expect(token).ToNot(BeNil())
close(done)
return false
}
token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil)
Expect(err).ToNot(HaveOccurred())
packet := getPacket(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
Token: token,
Version: serv.config.Versions[0],
}, make([]byte, protocol.MinInitialPacketSize))
packet.remoteAddr = raddr
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1)
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
It("passes an empty token to the callback, if decoding fails", func() {
raddr := &net.UDPAddr{
IP: net.IPv4(192, 168, 13, 37),
Port: 1337,
}
done := make(chan struct{})
serv.config.AcceptToken = func(addr net.Addr, token *Token) bool {
Expect(addr).To(Equal(raddr))
Expect(token).To(BeNil())
close(done)
return false
}
packet := getPacket(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
Token: []byte("foobar"),
Version: serv.config.Versions[0],
}, make([]byte, protocol.MinInitialPacketSize))
packet.remoteAddr = raddr
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1)
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
It("creates a connection when the token is accepted", func() {
serv.config.AcceptToken = func(_ net.Addr, token *Token) bool { return true }
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
retryToken, err := serv.tokenGenerator.NewRetryToken(
&net.UDPAddr{},
protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde},
@ -469,8 +417,8 @@ var _ = Describe("Server", func() {
time.Sleep(scaleDuration(20 * time.Millisecond))
})
It("replies with a Retry packet, if a Token is required", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
It("replies with a Retry packet, if a token is required", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
@ -502,81 +450,8 @@ var _ = Describe("Server", func() {
Eventually(done).Should(BeClosed())
})
It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil)
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Token: token,
Version: protocol.VersionTLS,
}
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
packet.remoteAddr = raddr
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(frames).To(HaveLen(1))
Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
ccf := frames[0].(*logging.ConnectionCloseFrame)
Expect(ccf.IsApplicationError).To(BeFalse())
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
})
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer close(done)
replyHdr := parseHeader(b)
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
_, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version)
extHdr, err := unpackHeader(opener, replyHdr, b, hdr.Version)
Expect(err).ToNot(HaveOccurred())
data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()])
Expect(err).ToNot(HaveOccurred())
f, err := wire.NewFrameParser(false, hdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial)
Expect(err).ToNot(HaveOccurred())
Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
ccf := f.(*wire.ConnectionCloseFrame)
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
Expect(ccf.ReasonPhrase).To(BeEmpty())
return len(b), nil
})
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil)
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Token: token,
Version: protocol.VersionTLS,
}
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
done := make(chan struct{})
tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) })
serv.handlePacket(packet)
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
Eventually(done).Should(BeClosed())
})
It("creates a connection, if no Token is required", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
It("creates a connection, if no token is required", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
@ -659,7 +534,7 @@ var _ = Describe("Server", func() {
}).AnyTimes()
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes()
serv.config.AcceptToken = func(net.Addr, *Token) bool { return true }
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
acceptConn := make(chan struct{})
var counter uint32 // to be used as an atomic, so we query it in Eventually
serv.newConn = func(
@ -713,7 +588,7 @@ var _ = Describe("Server", func() {
})
It("only creates a single connection for a duplicate Initial", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
var createdConn bool
conn := NewMockQuicConn(mockCtrl)
serv.newConn = func(
@ -745,7 +620,7 @@ var _ = Describe("Server", func() {
})
It("rejects new connection attempts if the accept queue is full", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
serv.newConn = func(
_ sendConn,
@ -813,7 +688,7 @@ var _ = Describe("Server", func() {
})
It("doesn't accept new connections if they were closed in the mean time", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
ctx, cancel := context.WithCancel(context.Background())
@ -877,6 +752,200 @@ var _ = Describe("Server", func() {
})
})
Context("token validation", func() {
checkInvalidToken := func(b []byte, origHdr *wire.Header) {
replyHdr := parseHeader(b)
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID))
_, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version)
extHdr, err := unpackHeader(opener, replyHdr, b, origHdr.Version)
Expect(err).ToNot(HaveOccurred())
data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()])
Expect(err).ToNot(HaveOccurred())
f, err := wire.NewFrameParser(false, origHdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial)
Expect(err).ToNot(HaveOccurred())
Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
ccf := f.(*wire.ConnectionCloseFrame)
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
Expect(ccf.ReasonPhrase).To(BeEmpty())
}
It("decodes the token from the token field", func() {
raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337}
token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil)
Expect(err).ToNot(HaveOccurred())
packet := getPacket(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
Token: token,
Version: serv.config.Versions[0],
}, make([]byte, protocol.MinInitialPacketSize))
packet.remoteAddr = raddr
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1)
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
done := make(chan struct{})
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() packetHandler) { close(done) })
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil)
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Token: token,
Version: protocol.VersionTLS,
}
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
packet.remoteAddr = raddr
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(frames).To(HaveLen(1))
Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
ccf := frames[0].(*logging.ConnectionCloseFrame)
Expect(ccf.IsApplicationError).To(BeFalse())
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
})
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer close(done)
checkInvalidToken(b, hdr)
return len(b), nil
})
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
It("sends an INVALID_TOKEN error, if an expired retry token is received", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
serv.config.MaxRetryTokenAge = time.Millisecond
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil)
Expect(err).ToNot(HaveOccurred())
time.Sleep(2 * time.Millisecond) // make sure the token is expired
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Token: token,
Version: protocol.VersionTLS,
}
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
packet.remoteAddr = raddr
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(frames).To(HaveLen(1))
Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
ccf := frames[0].(*logging.ConnectionCloseFrame)
Expect(ccf.IsApplicationError).To(BeFalse())
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
})
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer close(done)
checkInvalidToken(b, hdr)
return len(b), nil
})
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337})
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Token: token,
Version: protocol.VersionTLS,
}
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
packet.remoteAddr = raddr
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer close(done)
replyHdr := parseHeader(b)
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
return len(b), nil
})
serv.handlePacket(packet)
// make sure there are no Write calls on the packet conn
Eventually(done).Should(BeClosed())
})
It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
serv.config.MaxTokenAge = time.Millisecond
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
token, err := serv.tokenGenerator.NewToken(raddr)
Expect(err).ToNot(HaveOccurred())
time.Sleep(2 * time.Millisecond) // make sure the token is expired
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Token: token,
Version: protocol.VersionTLS,
}
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
packet.remoteAddr = raddr
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
})
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer close(done)
return len(b), nil
})
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil)
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Token: token,
Version: protocol.VersionTLS,
}
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
done := make(chan struct{})
tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) })
serv.handlePacket(packet)
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
Eventually(done).Should(BeClosed())
})
})
Context("accepting connections", func() {
It("returns Accept when an error occurs", func() {
testErr := errors.New("test err")
@ -930,7 +999,7 @@ var _ = Describe("Server", func() {
}()
ctx, cancel := context.WithCancel(context.Background()) // handshake context
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
serv.newConn = func(
_ sendConn,
runner connRunner,
@ -1004,7 +1073,7 @@ var _ = Describe("Server", func() {
}()
ready := make(chan struct{})
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
serv.newConn = func(
_ sendConn,
runner connRunner,
@ -1045,7 +1114,7 @@ var _ = Describe("Server", func() {
})
It("rejects new connection attempts if the accept queue is full", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
serv.newConn = func(
@ -1106,7 +1175,7 @@ var _ = Describe("Server", func() {
})
It("doesn't accept new connections if they were closed in the mean time", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
ctx, cancel := context.WithCancel(context.Background())
@ -1166,72 +1235,3 @@ var _ = Describe("Server", func() {
})
})
})
var _ = Describe("default source address verification", func() {
It("accepts a token", func() {
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
token := &Token{
IsRetryToken: true,
RemoteAddr: "192.168.0.1",
SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(time.Second), // will expire in 1 second
}
Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue())
})
It("requests verification if no token is provided", func() {
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
Expect(defaultAcceptToken(remoteAddr, nil)).To(BeFalse())
})
It("rejects a token if the address doesn't match", func() {
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
token := &Token{
IsRetryToken: true,
RemoteAddr: "127.0.0.1",
SentTime: time.Now(),
}
Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse())
})
It("accepts a token for a remote address is not a UDP address", func() {
remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
token := &Token{
IsRetryToken: true,
RemoteAddr: "192.168.0.1:1337",
SentTime: time.Now(),
}
Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue())
})
It("rejects an invalid token for a remote address is not a UDP address", func() {
remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
token := &Token{
IsRetryToken: true,
RemoteAddr: "192.168.0.1:7331", // mismatching port
SentTime: time.Now(),
}
Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse())
})
It("rejects an expired token", func() {
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
token := &Token{
IsRetryToken: true,
RemoteAddr: "192.168.0.1",
SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(-time.Second), // expired 1 second ago
}
Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse())
})
It("accepts a non-retry token", func() {
Expect(protocol.RetryTokenValidity).To(BeNumerically("<", protocol.TokenValidity))
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
token := &Token{
IsRetryToken: false,
RemoteAddr: "192.168.0.1",
// if this was a retry token, it would have expired one second ago
SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(-time.Second),
}
Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue())
})
})