mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
implement a more intuitive address validation API
This commit is contained in:
parent
556a6e2f99
commit
f2fa98c0dd
14 changed files with 352 additions and 437 deletions
18
config.go
18
config.go
|
@ -2,11 +2,11 @@ package quic
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// Clone clones a Config
|
||||
|
@ -39,8 +39,14 @@ func populateServerConfig(config *Config) *Config {
|
|||
if config.ConnectionIDLength == 0 {
|
||||
config.ConnectionIDLength = protocol.DefaultConnectionIDLength
|
||||
}
|
||||
if config.AcceptToken == nil {
|
||||
config.AcceptToken = defaultAcceptToken
|
||||
if config.MaxTokenAge == 0 {
|
||||
config.MaxTokenAge = protocol.TokenValidity
|
||||
}
|
||||
if config.MaxRetryTokenAge == 0 {
|
||||
config.MaxRetryTokenAge = protocol.RetryTokenValidity
|
||||
}
|
||||
if config.RequireAddressValidation == nil {
|
||||
config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
@ -104,7 +110,9 @@ func populateConfig(config *Config) *Config {
|
|||
Versions: versions,
|
||||
HandshakeIdleTimeout: handshakeIdleTimeout,
|
||||
MaxIdleTimeout: idleTimeout,
|
||||
AcceptToken: config.AcceptToken,
|
||||
MaxTokenAge: config.MaxTokenAge,
|
||||
MaxRetryTokenAge: config.MaxRetryTokenAge,
|
||||
RequireAddressValidation: config.RequireAddressValidation,
|
||||
KeepAlivePeriod: config.KeepAlivePeriod,
|
||||
InitialStreamReceiveWindow: initialStreamReceiveWindow,
|
||||
MaxStreamReceiveWindow: maxStreamReceiveWindow,
|
||||
|
|
|
@ -45,7 +45,7 @@ var _ = Describe("Config", func() {
|
|||
}
|
||||
|
||||
switch fn := typ.Field(i).Name; fn {
|
||||
case "AcceptToken", "GetLogWriter", "AllowConnectionWindowIncrease":
|
||||
case "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease":
|
||||
// Can't compare functions.
|
||||
case "Versions":
|
||||
f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3}))
|
||||
|
@ -55,6 +55,10 @@ var _ = Describe("Config", func() {
|
|||
f.Set(reflect.ValueOf(time.Second))
|
||||
case "MaxIdleTimeout":
|
||||
f.Set(reflect.ValueOf(time.Hour))
|
||||
case "MaxTokenAge":
|
||||
f.Set(reflect.ValueOf(2 * time.Hour))
|
||||
case "MaxRetryTokenAge":
|
||||
f.Set(reflect.ValueOf(2 * time.Minute))
|
||||
case "TokenStore":
|
||||
f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3)))
|
||||
case "InitialStreamReceiveWindow":
|
||||
|
@ -100,14 +104,14 @@ var _ = Describe("Config", func() {
|
|||
|
||||
Context("cloning", func() {
|
||||
It("clones function fields", func() {
|
||||
var calledAcceptToken, calledAllowConnectionWindowIncrease bool
|
||||
var calledAddrValidation, calledAllowConnectionWindowIncrease bool
|
||||
c1 := &Config{
|
||||
AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true },
|
||||
AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true },
|
||||
RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true },
|
||||
}
|
||||
c2 := c1.Clone()
|
||||
c2.AcceptToken(&net.UDPAddr{}, &Token{})
|
||||
Expect(calledAcceptToken).To(BeTrue())
|
||||
c2.RequireAddressValidation(&net.UDPAddr{})
|
||||
Expect(calledAddrValidation).To(BeTrue())
|
||||
c2.AllowConnectionWindowIncrease(nil, 1234)
|
||||
Expect(calledAllowConnectionWindowIncrease).To(BeTrue())
|
||||
})
|
||||
|
@ -120,26 +124,25 @@ var _ = Describe("Config", func() {
|
|||
It("returns a copy", func() {
|
||||
c1 := &Config{
|
||||
MaxIncomingStreams: 100,
|
||||
AcceptToken: func(_ net.Addr, _ *Token) bool { return true },
|
||||
RequireAddressValidation: func(net.Addr) bool { return true },
|
||||
}
|
||||
c2 := c1.Clone()
|
||||
c2.MaxIncomingStreams = 200
|
||||
c2.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
|
||||
c2.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||
|
||||
Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100))
|
||||
Expect(c1.AcceptToken(&net.UDPAddr{}, nil)).To(BeTrue())
|
||||
Expect(c1.RequireAddressValidation(&net.UDPAddr{})).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("populating", func() {
|
||||
It("populates function fields", func() {
|
||||
var calledAcceptToken bool
|
||||
c1 := &Config{
|
||||
AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true },
|
||||
}
|
||||
var calledAddrValidation bool
|
||||
c1 := &Config{}
|
||||
c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true }
|
||||
c2 := populateConfig(c1)
|
||||
c2.AcceptToken(&net.UDPAddr{}, &Token{})
|
||||
Expect(calledAcceptToken).To(BeTrue())
|
||||
c2.RequireAddressValidation(&net.UDPAddr{})
|
||||
Expect(calledAddrValidation).To(BeTrue())
|
||||
})
|
||||
|
||||
It("copies non-function fields", func() {
|
||||
|
@ -164,7 +167,7 @@ var _ = Describe("Config", func() {
|
|||
It("populates empty fields with default values, for the server", func() {
|
||||
c := populateServerConfig(&Config{})
|
||||
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
|
||||
Expect(c.AcceptToken).ToNot(BeNil())
|
||||
Expect(c.RequireAddressValidation).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("sets a default connection ID length if we didn't create the conn, for the client", func() {
|
||||
|
|
|
@ -2,7 +2,6 @@ package tokens
|
|||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"time"
|
||||
|
@ -77,7 +76,6 @@ func newToken(tg *handshake.TokenGenerator, data []byte) int {
|
|||
if token.OriginalDestConnectionID != nil || token.RetrySrcConnectionID != nil {
|
||||
panic("didn't expect connection IDs")
|
||||
}
|
||||
checkAddr(token.RemoteAddr, addr)
|
||||
return 1
|
||||
}
|
||||
|
||||
|
@ -140,22 +138,5 @@ func newRetryToken(tg *handshake.TokenGenerator, data []byte) int {
|
|||
if !token.RetrySrcConnectionID.Equal(retrySrcConnID) {
|
||||
panic("retry src conn ID doesn't match")
|
||||
}
|
||||
checkAddr(token.RemoteAddr, addr)
|
||||
return 1
|
||||
}
|
||||
|
||||
func checkAddr(tokenAddr string, addr net.Addr) {
|
||||
if udpAddr, ok := addr.(*net.UDPAddr); ok {
|
||||
// For UDP addresses, we encode only the IP (not the port).
|
||||
if ip := udpAddr.IP.String(); tokenAddr != ip {
|
||||
fmt.Printf("%s vs %s", tokenAddr, ip)
|
||||
panic("wrong remote address for a net.UDPAddr")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if tokenAddr != addr.String() {
|
||||
fmt.Printf("%s vs %s", tokenAddr, addr.String())
|
||||
panic("wrong remote address")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -41,8 +41,8 @@ var _ = Describe("Handshake drop tests", func() {
|
|||
HandshakeIdleTimeout: timeout,
|
||||
Versions: []protocol.VersionNumber{version},
|
||||
})
|
||||
if !doRetry {
|
||||
conf.AcceptToken = func(net.Addr, *quic.Token) bool { return true }
|
||||
if doRetry {
|
||||
conf.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
}
|
||||
var tlsConf *tls.Config
|
||||
if longCertChain {
|
||||
|
|
|
@ -112,9 +112,7 @@ var _ = Describe("Handshake RTT tests", func() {
|
|||
})
|
||||
|
||||
It("establishes a connection in 1 RTT when the server doesn't require a token", func() {
|
||||
serverConfig.AcceptToken = func(_ net.Addr, _ *quic.Token) bool {
|
||||
return true
|
||||
}
|
||||
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
|
||||
|
@ -126,9 +124,7 @@ var _ = Describe("Handshake RTT tests", func() {
|
|||
})
|
||||
|
||||
It("establishes a connection in 2 RTTs if a HelloRetryRequest is performed", func() {
|
||||
serverConfig.AcceptToken = func(_ net.Addr, _ *quic.Token) bool {
|
||||
return true
|
||||
}
|
||||
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||
serverTLSConfig.CurvePreferences = []tls.CurveID{tls.CurveP384}
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(
|
||||
|
@ -139,21 +135,4 @@ var _ = Describe("Handshake RTT tests", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
expectDurationInRTTs(2)
|
||||
})
|
||||
|
||||
It("doesn't complete the handshake when the server never accepts the token", func() {
|
||||
serverConfig.AcceptToken = func(_ net.Addr, _ *quic.Token) bool {
|
||||
return false
|
||||
}
|
||||
clientConfig.HandshakeIdleTimeout = 500 * time.Millisecond
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
clientConfig,
|
||||
)
|
||||
Expect(err).To(HaveOccurred())
|
||||
nerr, ok := err.(net.Error)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(nerr.Timeout()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
|
|
@ -344,12 +344,7 @@ var _ = Describe("Handshake tests", func() {
|
|||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
serverConfig.AcceptToken = func(addr net.Addr, token *quic.Token) bool {
|
||||
if token != nil {
|
||||
Expect(token.IsRetryToken).To(BeFalse())
|
||||
}
|
||||
return true
|
||||
}
|
||||
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||
var err error
|
||||
// start the server, but don't call Accept
|
||||
server, err = quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
||||
|
@ -479,13 +474,7 @@ var _ = Describe("Handshake tests", func() {
|
|||
|
||||
Context("using tokens", func() {
|
||||
It("uses tokens provided in NEW_TOKEN frames", func() {
|
||||
tokenChan := make(chan *quic.Token, 100)
|
||||
serverConfig.AcceptToken = func(addr net.Addr, token *quic.Token) bool {
|
||||
if token != nil && !token.IsRetryToken {
|
||||
tokenChan <- token
|
||||
}
|
||||
return true
|
||||
}
|
||||
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||
|
||||
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -509,7 +498,6 @@ var _ = Describe("Handshake tests", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(gets).To(Receive())
|
||||
Eventually(puts).Should(Receive())
|
||||
Expect(tokenChan).ToNot(Receive())
|
||||
// received a token. Close this connection.
|
||||
Expect(conn.CloseWithError(0, "")).To(Succeed())
|
||||
|
||||
|
@ -529,17 +517,13 @@ var _ = Describe("Handshake tests", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
Expect(gets).To(Receive())
|
||||
Expect(tokenChan).To(Receive())
|
||||
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("rejects invalid Retry token with the INVALID_TOKEN error", func() {
|
||||
tokenChan := make(chan *quic.Token, 10)
|
||||
serverConfig.AcceptToken = func(addr net.Addr, token *quic.Token) bool {
|
||||
tokenChan <- token
|
||||
return false
|
||||
}
|
||||
serverConfig.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
serverConfig.MaxRetryTokenAge = time.Nanosecond
|
||||
|
||||
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -554,18 +538,6 @@ var _ = Describe("Handshake tests", func() {
|
|||
var transportErr *quic.TransportError
|
||||
Expect(errors.As(err, &transportErr)).To(BeTrue())
|
||||
Expect(transportErr.ErrorCode).To(Equal(quic.InvalidToken))
|
||||
// Receiving a Retry might lead the client to measure a very small RTT.
|
||||
// Then, it sometimes would retransmit the ClientHello before receiving the ServerHello.
|
||||
Expect(len(tokenChan)).To(BeNumerically(">=", 2))
|
||||
token := <-tokenChan
|
||||
Expect(token).To(BeNil())
|
||||
token = <-tokenChan
|
||||
Expect(token).ToNot(BeNil())
|
||||
// If the ClientHello was retransmitted, make sure that it contained the same Retry token.
|
||||
for i := 2; i < len(tokenChan); i++ {
|
||||
Expect(<-tokenChan).To(Equal(token))
|
||||
}
|
||||
Expect(token.IsRetryToken).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ var _ = Describe("Packetization", func() {
|
|||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
AcceptToken: func(net.Addr, *quic.Token) bool { return true },
|
||||
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||
DisablePathMTUDiscovery: true,
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return serverTracer }),
|
||||
}),
|
||||
|
|
|
@ -56,7 +56,7 @@ var _ = Describe("0-RTT", func() {
|
|||
tlsConf := getTLSConfig()
|
||||
if serverConf == nil {
|
||||
serverConf = getQuicConfig(&quic.Config{
|
||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
||||
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||
})
|
||||
serverConf.Versions = []protocol.VersionNumber{version}
|
||||
}
|
||||
|
@ -198,7 +198,7 @@ var _ = Describe("0-RTT", func() {
|
|||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Versions: []protocol.VersionNumber{version},
|
||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
||||
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
}),
|
||||
)
|
||||
|
@ -256,7 +256,7 @@ var _ = Describe("0-RTT", func() {
|
|||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Versions: []protocol.VersionNumber{version},
|
||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
||||
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
}),
|
||||
)
|
||||
|
@ -453,7 +453,7 @@ var _ = Describe("0-RTT", func() {
|
|||
const maxStreams = 1
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{
|
||||
MaxIncomingUniStreams: maxStreams,
|
||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
||||
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||
}))
|
||||
|
||||
tracer := newPacketTracer()
|
||||
|
@ -462,7 +462,7 @@ var _ = Describe("0-RTT", func() {
|
|||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Versions: []protocol.VersionNumber{version},
|
||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
||||
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||
MaxIncomingUniStreams: maxStreams + 1,
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
}),
|
||||
|
@ -499,7 +499,7 @@ var _ = Describe("0-RTT", func() {
|
|||
const maxStreams = 42
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{
|
||||
MaxIncomingStreams: maxStreams,
|
||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
||||
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||
}))
|
||||
|
||||
tracer := newPacketTracer()
|
||||
|
@ -508,7 +508,7 @@ var _ = Describe("0-RTT", func() {
|
|||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Versions: []protocol.VersionNumber{version},
|
||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
||||
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||
MaxIncomingStreams: maxStreams - 1,
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
}),
|
||||
|
@ -538,7 +538,7 @@ var _ = Describe("0-RTT", func() {
|
|||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Versions: []protocol.VersionNumber{version},
|
||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
||||
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
}),
|
||||
)
|
||||
|
@ -560,7 +560,7 @@ var _ = Describe("0-RTT", func() {
|
|||
func(addFlowControlLimit func(*quic.Config, uint64)) {
|
||||
tracer := newPacketTracer()
|
||||
firstConf := getQuicConfig(&quic.Config{
|
||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
||||
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||
Versions: []protocol.VersionNumber{version},
|
||||
})
|
||||
addFlowControlLimit(firstConf, 3)
|
||||
|
@ -568,7 +568,7 @@ var _ = Describe("0-RTT", func() {
|
|||
|
||||
secondConf := getQuicConfig(&quic.Config{
|
||||
Versions: []protocol.VersionNumber{version},
|
||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
||||
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
})
|
||||
addFlowControlLimit(secondConf, 100)
|
||||
|
@ -723,7 +723,7 @@ var _ = Describe("0-RTT", func() {
|
|||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Versions: []protocol.VersionNumber{version},
|
||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
||||
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
}),
|
||||
)
|
||||
|
|
30
interface.go
30
interface.go
|
@ -26,16 +26,6 @@ const (
|
|||
Version2 = protocol.Version2
|
||||
)
|
||||
|
||||
// A Token can be used to verify the ownership of the client address.
|
||||
type Token struct {
|
||||
// IsRetryToken encodes how the client received the token. There are two ways:
|
||||
// * In a Retry packet sent when trying to establish a new connection.
|
||||
// * In a NEW_TOKEN frame on a previous connection.
|
||||
IsRetryToken bool
|
||||
RemoteAddr string
|
||||
SentTime time.Time
|
||||
}
|
||||
|
||||
// A ClientToken is a token received by the client.
|
||||
// It can be used to skip address validation on future connection attempts.
|
||||
type ClientToken struct {
|
||||
|
@ -233,14 +223,18 @@ type Config struct {
|
|||
// If the timeout is exceeded, the connection is closed.
|
||||
// If this value is zero, the timeout is set to 30 seconds.
|
||||
MaxIdleTimeout time.Duration
|
||||
// AcceptToken determines if a Token is accepted.
|
||||
// It is called with token = nil if the client didn't send a token.
|
||||
// If not set, a default verification function is used:
|
||||
// * it verifies that the address matches, and
|
||||
// * if the token is a retry token, that it was issued within the last 5 seconds
|
||||
// * else, that it was issued within the last 24 hours.
|
||||
// This option is only valid for the server.
|
||||
AcceptToken func(clientAddr net.Addr, token *Token) bool
|
||||
// RequireAddressValidation determines if a QUIC Retry packet is sent.
|
||||
// This allows the server to verify the client's address, at the cost of increasing the handshake latency by 1 RTT.
|
||||
// See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details.
|
||||
// If not set, every client is forced to prove its remote address.
|
||||
RequireAddressValidation func(net.Addr) bool
|
||||
// MaxRetryTokenAge is the maximum age of a Retry token.
|
||||
// If not set, it defaults to 5 seconds. Only valid for a server.
|
||||
MaxRetryTokenAge time.Duration
|
||||
// MaxTokenAge is the maximum age of the token presented during the handshake,
|
||||
// for tokens that were issued on a previous connection.
|
||||
// If not set, it defaults to 24 hours. Only valid for a server.
|
||||
MaxTokenAge time.Duration
|
||||
// The TokenStore stores tokens received from the server.
|
||||
// Tokens are used to skip address validation on future connection attempts.
|
||||
// The key used to store tokens is the ServerName from the tls.Config, if set
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/asn1"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -18,13 +19,17 @@ const (
|
|||
// A Token is derived from the client address and can be used to verify the ownership of this address.
|
||||
type Token struct {
|
||||
IsRetryToken bool
|
||||
RemoteAddr string
|
||||
SentTime time.Time
|
||||
encodedRemoteAddr []byte
|
||||
// only set for retry tokens
|
||||
OriginalDestConnectionID protocol.ConnectionID
|
||||
RetrySrcConnectionID protocol.ConnectionID
|
||||
}
|
||||
|
||||
func (t *Token) ValidateRemoteAddr(addr net.Addr) bool {
|
||||
return bytes.Equal(encodeRemoteAddr(addr), t.encodedRemoteAddr)
|
||||
}
|
||||
|
||||
// token is the struct that is used for ASN1 serialization and deserialization
|
||||
type token struct {
|
||||
IsRetryToken bool
|
||||
|
@ -102,8 +107,8 @@ func (g *TokenGenerator) DecodeToken(encrypted []byte) (*Token, error) {
|
|||
}
|
||||
token := &Token{
|
||||
IsRetryToken: t.IsRetryToken,
|
||||
RemoteAddr: decodeRemoteAddr(t.RemoteAddr),
|
||||
SentTime: time.Unix(0, t.Timestamp),
|
||||
encodedRemoteAddr: t.RemoteAddr,
|
||||
}
|
||||
if t.IsRetryToken {
|
||||
token.OriginalDestConnectionID = protocol.ConnectionID(t.OriginalDestConnectionID)
|
||||
|
@ -119,16 +124,3 @@ func encodeRemoteAddr(remoteAddr net.Addr) []byte {
|
|||
}
|
||||
return append([]byte{tokenPrefixString}, []byte(remoteAddr.String())...)
|
||||
}
|
||||
|
||||
// decodeRemoteAddr decodes the remote address saved in the token
|
||||
func decodeRemoteAddr(data []byte) string {
|
||||
// data will never be empty for a token that we generated.
|
||||
// Check it to be on the safe side
|
||||
if len(data) == 0 {
|
||||
return ""
|
||||
}
|
||||
if data[0] == tokenPrefixIP {
|
||||
return net.IP(data[1:]).String()
|
||||
}
|
||||
return string(data[1:])
|
||||
}
|
||||
|
|
|
@ -35,16 +35,13 @@ var _ = Describe("Token Generator", func() {
|
|||
})
|
||||
|
||||
It("accepts a valid token", func() {
|
||||
ip := net.IPv4(192, 168, 0, 1)
|
||||
tokenEnc, err := tokenGen.NewRetryToken(
|
||||
&net.UDPAddr{IP: ip, Port: 1337},
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||
tokenEnc, err := tokenGen.NewRetryToken(addr, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
token, err := tokenGen.DecodeToken(tokenEnc)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(token.RemoteAddr).To(Equal("192.168.0.1"))
|
||||
Expect(token.ValidateRemoteAddr(addr)).To(BeTrue())
|
||||
Expect(token.ValidateRemoteAddr(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 2), Port: 1337})).To(BeFalse())
|
||||
Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond))
|
||||
Expect(token.OriginalDestConnectionID.Len()).To(BeZero())
|
||||
Expect(token.RetrySrcConnectionID.Len()).To(BeZero())
|
||||
|
@ -110,7 +107,7 @@ var _ = Describe("Token Generator", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
token, err := tokenGen.DecodeToken(tokenEnc)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(token.RemoteAddr).To(Equal(ip.String()))
|
||||
Expect(token.ValidateRemoteAddr(raddr)).To(BeTrue())
|
||||
Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond))
|
||||
}
|
||||
})
|
||||
|
@ -121,7 +118,8 @@ var _ = Describe("Token Generator", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
token, err := tokenGen.DecodeToken(tokenEnc)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(token.RemoteAddr).To(Equal("192.168.13.37:1337"))
|
||||
Expect(token.ValidateRemoteAddr(raddr)).To(BeTrue())
|
||||
Expect(token.ValidateRemoteAddr(&net.TCPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1338})).To(BeFalse())
|
||||
Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond))
|
||||
})
|
||||
})
|
||||
|
|
|
@ -44,7 +44,7 @@ func main() {
|
|||
}
|
||||
// a quic.Config that doesn't do a Retry
|
||||
quicConf := &quic.Config{
|
||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
||||
RequireAddressValidation: func(net.Addr) bool { return testcase == "retry" },
|
||||
Tracer: qlog.NewTracer(getLogWriter),
|
||||
}
|
||||
cert, err := tls.LoadX509KeyPair("/certs/cert.pem", "/certs/priv.key")
|
||||
|
@ -58,15 +58,11 @@ func main() {
|
|||
}
|
||||
|
||||
switch testcase {
|
||||
case "versionnegotiation", "handshake", "transfer", "resumption", "zerortt", "multiconnect":
|
||||
case "versionnegotiation", "handshake", "retry", "transfer", "resumption", "zerortt", "multiconnect":
|
||||
err = runHTTP09Server(quicConf)
|
||||
case "chacha20":
|
||||
tlsConf.CipherSuites = []uint16{tls.TLS_CHACHA20_POLY1305_SHA256}
|
||||
err = runHTTP09Server(quicConf)
|
||||
case "retry":
|
||||
// By default, quic-go performs a Retry on every incoming connection.
|
||||
quicConf.AcceptToken = nil
|
||||
err = runHTTP09Server(quicConf)
|
||||
case "http3":
|
||||
err = runHTTP3Server(quicConf)
|
||||
default:
|
||||
|
|
56
server.go
56
server.go
|
@ -241,26 +241,6 @@ func (s *baseServer) run() {
|
|||
}
|
||||
}
|
||||
|
||||
var defaultAcceptToken = func(clientAddr net.Addr, token *Token) bool {
|
||||
if token == nil {
|
||||
return false
|
||||
}
|
||||
validity := protocol.TokenValidity
|
||||
if token.IsRetryToken {
|
||||
validity = protocol.RetryTokenValidity
|
||||
}
|
||||
if time.Now().After(token.SentTime.Add(validity)) {
|
||||
return false
|
||||
}
|
||||
var sourceAddr string
|
||||
if udpAddr, ok := clientAddr.(*net.UDPAddr); ok {
|
||||
sourceAddr = udpAddr.IP.String()
|
||||
} else {
|
||||
sourceAddr = clientAddr.String()
|
||||
}
|
||||
return sourceAddr == token.RemoteAddr
|
||||
}
|
||||
|
||||
// Accept returns connections that already completed the handshake.
|
||||
// It is only valid if acceptEarlyConns is false.
|
||||
func (s *baseServer) Accept(ctx context.Context) (Connection, error) {
|
||||
|
@ -405,33 +385,45 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
|
|||
}
|
||||
|
||||
var (
|
||||
token *Token
|
||||
token *handshake.Token
|
||||
retrySrcConnID *protocol.ConnectionID
|
||||
)
|
||||
origDestConnID := hdr.DestConnectionID
|
||||
if len(hdr.Token) > 0 {
|
||||
c, err := s.tokenGenerator.DecodeToken(hdr.Token)
|
||||
tok, err := s.tokenGenerator.DecodeToken(hdr.Token)
|
||||
if err == nil {
|
||||
token = &Token{
|
||||
IsRetryToken: c.IsRetryToken,
|
||||
RemoteAddr: c.RemoteAddr,
|
||||
SentTime: c.SentTime,
|
||||
if tok.IsRetryToken {
|
||||
origDestConnID = tok.OriginalDestConnectionID
|
||||
retrySrcConnID = &tok.RetrySrcConnectionID
|
||||
}
|
||||
if token.IsRetryToken {
|
||||
origDestConnID = c.OriginalDestConnectionID
|
||||
retrySrcConnID = &c.RetrySrcConnectionID
|
||||
token = tok
|
||||
}
|
||||
}
|
||||
}
|
||||
if !s.config.AcceptToken(p.remoteAddr, token) {
|
||||
if token != nil {
|
||||
addrIsValid := token.ValidateRemoteAddr(p.remoteAddr)
|
||||
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
|
||||
// We just ignore them, and act as if there was no token on this packet at all.
|
||||
// This also means we might send a Retry later.
|
||||
if !token.IsRetryToken && (time.Since(token.SentTime) > s.config.MaxTokenAge || !addrIsValid) {
|
||||
token = nil
|
||||
} else if token.IsRetryToken && (time.Since(token.SentTime) > s.config.MaxRetryTokenAge || !addrIsValid) {
|
||||
// For Retry tokens, we send an INVALID_ERROR if
|
||||
// * the token is too old, or
|
||||
// * the token is invalid, in case of a retry token.
|
||||
go func() {
|
||||
defer p.buffer.Release()
|
||||
if token != nil && token.IsRetryToken {
|
||||
if err := s.maybeSendInvalidToken(p, hdr); err != nil {
|
||||
s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if token == nil && s.config.RequireAddressValidation(p.remoteAddr) {
|
||||
go func() {
|
||||
defer p.buffer.Release()
|
||||
if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil {
|
||||
s.logger.Debugf("Error sending Retry: %s", err)
|
||||
}
|
||||
|
|
424
server_test.go
424
server_test.go
|
@ -126,22 +126,22 @@ var _ = Describe("Server", func() {
|
|||
Expect(server.config.Versions).To(Equal(protocol.SupportedVersions))
|
||||
Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
|
||||
Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
|
||||
Expect(reflect.ValueOf(server.config.AcceptToken)).To(Equal(reflect.ValueOf(defaultAcceptToken)))
|
||||
Expect(server.config.KeepAlivePeriod).To(Equal(0 * time.Second))
|
||||
Expect(server.config.RequireAddressValidation).ToNot(BeNil())
|
||||
Expect(server.config.KeepAlivePeriod).To(BeZero())
|
||||
// stop the listener
|
||||
Expect(ln.Close()).To(Succeed())
|
||||
})
|
||||
|
||||
It("setups with the right values", func() {
|
||||
supportedVersions := []protocol.VersionNumber{protocol.VersionTLS}
|
||||
acceptToken := func(_ net.Addr, _ *Token) bool { return true }
|
||||
requireAddrVal := func(net.Addr) bool { return true }
|
||||
config := Config{
|
||||
Versions: supportedVersions,
|
||||
AcceptToken: acceptToken,
|
||||
HandshakeIdleTimeout: 1337 * time.Hour,
|
||||
MaxIdleTimeout: 42 * time.Minute,
|
||||
KeepAlivePeriod: 5 * time.Second,
|
||||
StatelessResetKey: []byte("foobar"),
|
||||
RequireAddressValidation: requireAddrVal,
|
||||
}
|
||||
ln, err := Listen(conn, tlsConf, &config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -150,7 +150,7 @@ var _ = Describe("Server", func() {
|
|||
Expect(server.config.Versions).To(Equal(supportedVersions))
|
||||
Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour))
|
||||
Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute))
|
||||
Expect(reflect.ValueOf(server.config.AcceptToken)).To(Equal(reflect.ValueOf(acceptToken)))
|
||||
Expect(reflect.ValueOf(server.config.RequireAddressValidation)).To(Equal(reflect.ValueOf(requireAddrVal)))
|
||||
Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second))
|
||||
Expect(server.config.StatelessResetKey).To(Equal([]byte("foobar")))
|
||||
// stop the listener
|
||||
|
@ -239,60 +239,8 @@ var _ = Describe("Server", func() {
|
|||
time.Sleep(50 * time.Millisecond)
|
||||
})
|
||||
|
||||
It("decodes the token from the Token field", func() {
|
||||
raddr := &net.UDPAddr{
|
||||
IP: net.IPv4(192, 168, 13, 37),
|
||||
Port: 1337,
|
||||
}
|
||||
done := make(chan struct{})
|
||||
serv.config.AcceptToken = func(addr net.Addr, token *Token) bool {
|
||||
Expect(addr).To(Equal(raddr))
|
||||
Expect(token).ToNot(BeNil())
|
||||
close(done)
|
||||
return false
|
||||
}
|
||||
token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
packet := getPacket(&wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
Token: token,
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, protocol.MinInitialPacketSize))
|
||||
packet.remoteAddr = raddr
|
||||
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||
serv.handlePacket(packet)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("passes an empty token to the callback, if decoding fails", func() {
|
||||
raddr := &net.UDPAddr{
|
||||
IP: net.IPv4(192, 168, 13, 37),
|
||||
Port: 1337,
|
||||
}
|
||||
done := make(chan struct{})
|
||||
serv.config.AcceptToken = func(addr net.Addr, token *Token) bool {
|
||||
Expect(addr).To(Equal(raddr))
|
||||
Expect(token).To(BeNil())
|
||||
close(done)
|
||||
return false
|
||||
}
|
||||
packet := getPacket(&wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
Token: []byte("foobar"),
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, protocol.MinInitialPacketSize))
|
||||
packet.remoteAddr = raddr
|
||||
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||
serv.handlePacket(packet)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("creates a connection when the token is accepted", func() {
|
||||
serv.config.AcceptToken = func(_ net.Addr, token *Token) bool { return true }
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
retryToken, err := serv.tokenGenerator.NewRetryToken(
|
||||
&net.UDPAddr{},
|
||||
protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde},
|
||||
|
@ -469,8 +417,8 @@ var _ = Describe("Server", func() {
|
|||
time.Sleep(scaleDuration(20 * time.Millisecond))
|
||||
})
|
||||
|
||||
It("replies with a Retry packet, if a Token is required", func() {
|
||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
|
||||
It("replies with a Retry packet, if a token is required", func() {
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
|
@ -502,81 +450,8 @@ var _ = Describe("Server", func() {
|
|||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
|
||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
|
||||
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
Token: token,
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet
|
||||
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
packet.remoteAddr = raddr
|
||||
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
|
||||
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
|
||||
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
||||
Expect(frames).To(HaveLen(1))
|
||||
Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
|
||||
ccf := frames[0].(*logging.ConnectionCloseFrame)
|
||||
Expect(ccf.IsApplicationError).To(BeFalse())
|
||||
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
|
||||
})
|
||||
done := make(chan struct{})
|
||||
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
||||
defer close(done)
|
||||
replyHdr := parseHeader(b)
|
||||
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
|
||||
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
||||
_, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version)
|
||||
extHdr, err := unpackHeader(opener, replyHdr, b, hdr.Version)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()])
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
f, err := wire.NewFrameParser(false, hdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
|
||||
ccf := f.(*wire.ConnectionCloseFrame)
|
||||
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
|
||||
Expect(ccf.ReasonPhrase).To(BeEmpty())
|
||||
return len(b), nil
|
||||
})
|
||||
serv.handlePacket(packet)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
|
||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
|
||||
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
Token: token,
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
|
||||
packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
done := make(chan struct{})
|
||||
tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) })
|
||||
serv.handlePacket(packet)
|
||||
// make sure there are no Write calls on the packet conn
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("creates a connection, if no Token is required", func() {
|
||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
||||
It("creates a connection, if no token is required", func() {
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
|
@ -659,7 +534,7 @@ var _ = Describe("Server", func() {
|
|||
}).AnyTimes()
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes()
|
||||
|
||||
serv.config.AcceptToken = func(net.Addr, *Token) bool { return true }
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||
acceptConn := make(chan struct{})
|
||||
var counter uint32 // to be used as an atomic, so we query it in Eventually
|
||||
serv.newConn = func(
|
||||
|
@ -713,7 +588,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("only creates a single connection for a duplicate Initial", func() {
|
||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||
var createdConn bool
|
||||
conn := NewMockQuicConn(mockCtrl)
|
||||
serv.newConn = func(
|
||||
|
@ -745,7 +620,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("rejects new connection attempts if the accept queue is full", func() {
|
||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||
|
||||
serv.newConn = func(
|
||||
_ sendConn,
|
||||
|
@ -813,7 +688,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("doesn't accept new connections if they were closed in the mean time", func() {
|
||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||
|
||||
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
@ -877,6 +752,200 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
})
|
||||
|
||||
Context("token validation", func() {
|
||||
checkInvalidToken := func(b []byte, origHdr *wire.Header) {
|
||||
replyHdr := parseHeader(b)
|
||||
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID))
|
||||
Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID))
|
||||
_, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version)
|
||||
extHdr, err := unpackHeader(opener, replyHdr, b, origHdr.Version)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()])
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
f, err := wire.NewFrameParser(false, origHdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
|
||||
ccf := f.(*wire.ConnectionCloseFrame)
|
||||
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
|
||||
Expect(ccf.ReasonPhrase).To(BeEmpty())
|
||||
}
|
||||
|
||||
It("decodes the token from the token field", func() {
|
||||
raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337}
|
||||
token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
packet := getPacket(&wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
Token: token,
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, protocol.MinInitialPacketSize))
|
||||
packet.remoteAddr = raddr
|
||||
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||
|
||||
done := make(chan struct{})
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() packetHandler) { close(done) })
|
||||
serv.handlePacket(packet)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
Token: token,
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet
|
||||
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
packet.remoteAddr = raddr
|
||||
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
|
||||
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
|
||||
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
||||
Expect(frames).To(HaveLen(1))
|
||||
Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
|
||||
ccf := frames[0].(*logging.ConnectionCloseFrame)
|
||||
Expect(ccf.IsApplicationError).To(BeFalse())
|
||||
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
|
||||
})
|
||||
done := make(chan struct{})
|
||||
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
||||
defer close(done)
|
||||
checkInvalidToken(b, hdr)
|
||||
return len(b), nil
|
||||
})
|
||||
serv.handlePacket(packet)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("sends an INVALID_TOKEN error, if an expired retry token is received", func() {
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
serv.config.MaxRetryTokenAge = time.Millisecond
|
||||
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
time.Sleep(2 * time.Millisecond) // make sure the token is expired
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
Token: token,
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
packet.remoteAddr = raddr
|
||||
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
|
||||
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
|
||||
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
||||
Expect(frames).To(HaveLen(1))
|
||||
Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
|
||||
ccf := frames[0].(*logging.ConnectionCloseFrame)
|
||||
Expect(ccf.IsApplicationError).To(BeFalse())
|
||||
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
|
||||
})
|
||||
done := make(chan struct{})
|
||||
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
||||
defer close(done)
|
||||
checkInvalidToken(b, hdr)
|
||||
return len(b), nil
|
||||
})
|
||||
serv.handlePacket(packet)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() {
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
Token: token,
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
|
||||
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
packet.remoteAddr = raddr
|
||||
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||
done := make(chan struct{})
|
||||
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
||||
defer close(done)
|
||||
replyHdr := parseHeader(b)
|
||||
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
|
||||
return len(b), nil
|
||||
})
|
||||
serv.handlePacket(packet)
|
||||
// make sure there are no Write calls on the packet conn
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() {
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
serv.config.MaxTokenAge = time.Millisecond
|
||||
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
token, err := serv.tokenGenerator.NewToken(raddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
time.Sleep(2 * time.Millisecond) // make sure the token is expired
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
Token: token,
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
packet.remoteAddr = raddr
|
||||
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
|
||||
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
|
||||
})
|
||||
done := make(chan struct{})
|
||||
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
||||
defer close(done)
|
||||
return len(b), nil
|
||||
})
|
||||
serv.handlePacket(packet)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
Token: token,
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
|
||||
packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
done := make(chan struct{})
|
||||
tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) })
|
||||
serv.handlePacket(packet)
|
||||
// make sure there are no Write calls on the packet conn
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
Context("accepting connections", func() {
|
||||
It("returns Accept when an error occurs", func() {
|
||||
testErr := errors.New("test err")
|
||||
|
@ -930,7 +999,7 @@ var _ = Describe("Server", func() {
|
|||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background()) // handshake context
|
||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||
serv.newConn = func(
|
||||
_ sendConn,
|
||||
runner connRunner,
|
||||
|
@ -1004,7 +1073,7 @@ var _ = Describe("Server", func() {
|
|||
}()
|
||||
|
||||
ready := make(chan struct{})
|
||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||
serv.newConn = func(
|
||||
_ sendConn,
|
||||
runner connRunner,
|
||||
|
@ -1045,7 +1114,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("rejects new connection attempts if the accept queue is full", func() {
|
||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
|
||||
|
||||
serv.newConn = func(
|
||||
|
@ -1106,7 +1175,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("doesn't accept new connections if they were closed in the mean time", func() {
|
||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||
|
||||
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
@ -1166,72 +1235,3 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("default source address verification", func() {
|
||||
It("accepts a token", func() {
|
||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
|
||||
token := &Token{
|
||||
IsRetryToken: true,
|
||||
RemoteAddr: "192.168.0.1",
|
||||
SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(time.Second), // will expire in 1 second
|
||||
}
|
||||
Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("requests verification if no token is provided", func() {
|
||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
|
||||
Expect(defaultAcceptToken(remoteAddr, nil)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("rejects a token if the address doesn't match", func() {
|
||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
|
||||
token := &Token{
|
||||
IsRetryToken: true,
|
||||
RemoteAddr: "127.0.0.1",
|
||||
SentTime: time.Now(),
|
||||
}
|
||||
Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("accepts a token for a remote address is not a UDP address", func() {
|
||||
remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||
token := &Token{
|
||||
IsRetryToken: true,
|
||||
RemoteAddr: "192.168.0.1:1337",
|
||||
SentTime: time.Now(),
|
||||
}
|
||||
Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("rejects an invalid token for a remote address is not a UDP address", func() {
|
||||
remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||
token := &Token{
|
||||
IsRetryToken: true,
|
||||
RemoteAddr: "192.168.0.1:7331", // mismatching port
|
||||
SentTime: time.Now(),
|
||||
}
|
||||
Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("rejects an expired token", func() {
|
||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
|
||||
token := &Token{
|
||||
IsRetryToken: true,
|
||||
RemoteAddr: "192.168.0.1",
|
||||
SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(-time.Second), // expired 1 second ago
|
||||
}
|
||||
Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("accepts a non-retry token", func() {
|
||||
Expect(protocol.RetryTokenValidity).To(BeNumerically("<", protocol.TokenValidity))
|
||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
|
||||
token := &Token{
|
||||
IsRetryToken: false,
|
||||
RemoteAddr: "192.168.0.1",
|
||||
// if this was a retry token, it would have expired one second ago
|
||||
SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(-time.Second),
|
||||
}
|
||||
Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue