mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 13:17:36 +03:00
implement a more intuitive address validation API
This commit is contained in:
parent
556a6e2f99
commit
f2fa98c0dd
14 changed files with 352 additions and 437 deletions
18
config.go
18
config.go
|
@ -2,11 +2,11 @@ package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Clone clones a Config
|
// Clone clones a Config
|
||||||
|
@ -39,8 +39,14 @@ func populateServerConfig(config *Config) *Config {
|
||||||
if config.ConnectionIDLength == 0 {
|
if config.ConnectionIDLength == 0 {
|
||||||
config.ConnectionIDLength = protocol.DefaultConnectionIDLength
|
config.ConnectionIDLength = protocol.DefaultConnectionIDLength
|
||||||
}
|
}
|
||||||
if config.AcceptToken == nil {
|
if config.MaxTokenAge == 0 {
|
||||||
config.AcceptToken = defaultAcceptToken
|
config.MaxTokenAge = protocol.TokenValidity
|
||||||
|
}
|
||||||
|
if config.MaxRetryTokenAge == 0 {
|
||||||
|
config.MaxRetryTokenAge = protocol.RetryTokenValidity
|
||||||
|
}
|
||||||
|
if config.RequireAddressValidation == nil {
|
||||||
|
config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||||
}
|
}
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
@ -104,7 +110,9 @@ func populateConfig(config *Config) *Config {
|
||||||
Versions: versions,
|
Versions: versions,
|
||||||
HandshakeIdleTimeout: handshakeIdleTimeout,
|
HandshakeIdleTimeout: handshakeIdleTimeout,
|
||||||
MaxIdleTimeout: idleTimeout,
|
MaxIdleTimeout: idleTimeout,
|
||||||
AcceptToken: config.AcceptToken,
|
MaxTokenAge: config.MaxTokenAge,
|
||||||
|
MaxRetryTokenAge: config.MaxRetryTokenAge,
|
||||||
|
RequireAddressValidation: config.RequireAddressValidation,
|
||||||
KeepAlivePeriod: config.KeepAlivePeriod,
|
KeepAlivePeriod: config.KeepAlivePeriod,
|
||||||
InitialStreamReceiveWindow: initialStreamReceiveWindow,
|
InitialStreamReceiveWindow: initialStreamReceiveWindow,
|
||||||
MaxStreamReceiveWindow: maxStreamReceiveWindow,
|
MaxStreamReceiveWindow: maxStreamReceiveWindow,
|
||||||
|
|
|
@ -45,7 +45,7 @@ var _ = Describe("Config", func() {
|
||||||
}
|
}
|
||||||
|
|
||||||
switch fn := typ.Field(i).Name; fn {
|
switch fn := typ.Field(i).Name; fn {
|
||||||
case "AcceptToken", "GetLogWriter", "AllowConnectionWindowIncrease":
|
case "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease":
|
||||||
// Can't compare functions.
|
// Can't compare functions.
|
||||||
case "Versions":
|
case "Versions":
|
||||||
f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3}))
|
f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3}))
|
||||||
|
@ -55,6 +55,10 @@ var _ = Describe("Config", func() {
|
||||||
f.Set(reflect.ValueOf(time.Second))
|
f.Set(reflect.ValueOf(time.Second))
|
||||||
case "MaxIdleTimeout":
|
case "MaxIdleTimeout":
|
||||||
f.Set(reflect.ValueOf(time.Hour))
|
f.Set(reflect.ValueOf(time.Hour))
|
||||||
|
case "MaxTokenAge":
|
||||||
|
f.Set(reflect.ValueOf(2 * time.Hour))
|
||||||
|
case "MaxRetryTokenAge":
|
||||||
|
f.Set(reflect.ValueOf(2 * time.Minute))
|
||||||
case "TokenStore":
|
case "TokenStore":
|
||||||
f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3)))
|
f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3)))
|
||||||
case "InitialStreamReceiveWindow":
|
case "InitialStreamReceiveWindow":
|
||||||
|
@ -100,14 +104,14 @@ var _ = Describe("Config", func() {
|
||||||
|
|
||||||
Context("cloning", func() {
|
Context("cloning", func() {
|
||||||
It("clones function fields", func() {
|
It("clones function fields", func() {
|
||||||
var calledAcceptToken, calledAllowConnectionWindowIncrease bool
|
var calledAddrValidation, calledAllowConnectionWindowIncrease bool
|
||||||
c1 := &Config{
|
c1 := &Config{
|
||||||
AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true },
|
|
||||||
AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true },
|
AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true },
|
||||||
|
RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true },
|
||||||
}
|
}
|
||||||
c2 := c1.Clone()
|
c2 := c1.Clone()
|
||||||
c2.AcceptToken(&net.UDPAddr{}, &Token{})
|
c2.RequireAddressValidation(&net.UDPAddr{})
|
||||||
Expect(calledAcceptToken).To(BeTrue())
|
Expect(calledAddrValidation).To(BeTrue())
|
||||||
c2.AllowConnectionWindowIncrease(nil, 1234)
|
c2.AllowConnectionWindowIncrease(nil, 1234)
|
||||||
Expect(calledAllowConnectionWindowIncrease).To(BeTrue())
|
Expect(calledAllowConnectionWindowIncrease).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
@ -120,26 +124,25 @@ var _ = Describe("Config", func() {
|
||||||
It("returns a copy", func() {
|
It("returns a copy", func() {
|
||||||
c1 := &Config{
|
c1 := &Config{
|
||||||
MaxIncomingStreams: 100,
|
MaxIncomingStreams: 100,
|
||||||
AcceptToken: func(_ net.Addr, _ *Token) bool { return true },
|
RequireAddressValidation: func(net.Addr) bool { return true },
|
||||||
}
|
}
|
||||||
c2 := c1.Clone()
|
c2 := c1.Clone()
|
||||||
c2.MaxIncomingStreams = 200
|
c2.MaxIncomingStreams = 200
|
||||||
c2.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
|
c2.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||||
|
|
||||||
Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100))
|
Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100))
|
||||||
Expect(c1.AcceptToken(&net.UDPAddr{}, nil)).To(BeTrue())
|
Expect(c1.RequireAddressValidation(&net.UDPAddr{})).To(BeTrue())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("populating", func() {
|
Context("populating", func() {
|
||||||
It("populates function fields", func() {
|
It("populates function fields", func() {
|
||||||
var calledAcceptToken bool
|
var calledAddrValidation bool
|
||||||
c1 := &Config{
|
c1 := &Config{}
|
||||||
AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true },
|
c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true }
|
||||||
}
|
|
||||||
c2 := populateConfig(c1)
|
c2 := populateConfig(c1)
|
||||||
c2.AcceptToken(&net.UDPAddr{}, &Token{})
|
c2.RequireAddressValidation(&net.UDPAddr{})
|
||||||
Expect(calledAcceptToken).To(BeTrue())
|
Expect(calledAddrValidation).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("copies non-function fields", func() {
|
It("copies non-function fields", func() {
|
||||||
|
@ -164,7 +167,7 @@ var _ = Describe("Config", func() {
|
||||||
It("populates empty fields with default values, for the server", func() {
|
It("populates empty fields with default values, for the server", func() {
|
||||||
c := populateServerConfig(&Config{})
|
c := populateServerConfig(&Config{})
|
||||||
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
|
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
|
||||||
Expect(c.AcceptToken).ToNot(BeNil())
|
Expect(c.RequireAddressValidation).ToNot(BeNil())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("sets a default connection ID length if we didn't create the conn, for the client", func() {
|
It("sets a default connection ID length if we didn't create the conn, for the client", func() {
|
||||||
|
|
|
@ -2,7 +2,6 @@ package tokens
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
@ -77,7 +76,6 @@ func newToken(tg *handshake.TokenGenerator, data []byte) int {
|
||||||
if token.OriginalDestConnectionID != nil || token.RetrySrcConnectionID != nil {
|
if token.OriginalDestConnectionID != nil || token.RetrySrcConnectionID != nil {
|
||||||
panic("didn't expect connection IDs")
|
panic("didn't expect connection IDs")
|
||||||
}
|
}
|
||||||
checkAddr(token.RemoteAddr, addr)
|
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -140,22 +138,5 @@ func newRetryToken(tg *handshake.TokenGenerator, data []byte) int {
|
||||||
if !token.RetrySrcConnectionID.Equal(retrySrcConnID) {
|
if !token.RetrySrcConnectionID.Equal(retrySrcConnID) {
|
||||||
panic("retry src conn ID doesn't match")
|
panic("retry src conn ID doesn't match")
|
||||||
}
|
}
|
||||||
checkAddr(token.RemoteAddr, addr)
|
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkAddr(tokenAddr string, addr net.Addr) {
|
|
||||||
if udpAddr, ok := addr.(*net.UDPAddr); ok {
|
|
||||||
// For UDP addresses, we encode only the IP (not the port).
|
|
||||||
if ip := udpAddr.IP.String(); tokenAddr != ip {
|
|
||||||
fmt.Printf("%s vs %s", tokenAddr, ip)
|
|
||||||
panic("wrong remote address for a net.UDPAddr")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if tokenAddr != addr.String() {
|
|
||||||
fmt.Printf("%s vs %s", tokenAddr, addr.String())
|
|
||||||
panic("wrong remote address")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -41,8 +41,8 @@ var _ = Describe("Handshake drop tests", func() {
|
||||||
HandshakeIdleTimeout: timeout,
|
HandshakeIdleTimeout: timeout,
|
||||||
Versions: []protocol.VersionNumber{version},
|
Versions: []protocol.VersionNumber{version},
|
||||||
})
|
})
|
||||||
if !doRetry {
|
if doRetry {
|
||||||
conf.AcceptToken = func(net.Addr, *quic.Token) bool { return true }
|
conf.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||||
}
|
}
|
||||||
var tlsConf *tls.Config
|
var tlsConf *tls.Config
|
||||||
if longCertChain {
|
if longCertChain {
|
||||||
|
|
|
@ -112,9 +112,7 @@ var _ = Describe("Handshake RTT tests", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("establishes a connection in 1 RTT when the server doesn't require a token", func() {
|
It("establishes a connection in 1 RTT when the server doesn't require a token", func() {
|
||||||
serverConfig.AcceptToken = func(_ net.Addr, _ *quic.Token) bool {
|
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||||
return true
|
|
||||||
}
|
|
||||||
runServerAndProxy()
|
runServerAndProxy()
|
||||||
_, err := quic.DialAddr(
|
_, err := quic.DialAddr(
|
||||||
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
|
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
|
||||||
|
@ -126,9 +124,7 @@ var _ = Describe("Handshake RTT tests", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("establishes a connection in 2 RTTs if a HelloRetryRequest is performed", func() {
|
It("establishes a connection in 2 RTTs if a HelloRetryRequest is performed", func() {
|
||||||
serverConfig.AcceptToken = func(_ net.Addr, _ *quic.Token) bool {
|
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||||
return true
|
|
||||||
}
|
|
||||||
serverTLSConfig.CurvePreferences = []tls.CurveID{tls.CurveP384}
|
serverTLSConfig.CurvePreferences = []tls.CurveID{tls.CurveP384}
|
||||||
runServerAndProxy()
|
runServerAndProxy()
|
||||||
_, err := quic.DialAddr(
|
_, err := quic.DialAddr(
|
||||||
|
@ -139,21 +135,4 @@ var _ = Describe("Handshake RTT tests", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
expectDurationInRTTs(2)
|
expectDurationInRTTs(2)
|
||||||
})
|
})
|
||||||
|
|
||||||
It("doesn't complete the handshake when the server never accepts the token", func() {
|
|
||||||
serverConfig.AcceptToken = func(_ net.Addr, _ *quic.Token) bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
clientConfig.HandshakeIdleTimeout = 500 * time.Millisecond
|
|
||||||
runServerAndProxy()
|
|
||||||
_, err := quic.DialAddr(
|
|
||||||
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
|
|
||||||
getTLSClientConfig(),
|
|
||||||
clientConfig,
|
|
||||||
)
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
nerr, ok := err.(net.Error)
|
|
||||||
Expect(ok).To(BeTrue())
|
|
||||||
Expect(nerr.Timeout()).To(BeTrue())
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
|
|
@ -344,12 +344,7 @@ var _ = Describe("Handshake tests", func() {
|
||||||
}
|
}
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
serverConfig.AcceptToken = func(addr net.Addr, token *quic.Token) bool {
|
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||||
if token != nil {
|
|
||||||
Expect(token.IsRetryToken).To(BeFalse())
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
var err error
|
var err error
|
||||||
// start the server, but don't call Accept
|
// start the server, but don't call Accept
|
||||||
server, err = quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
server, err = quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
||||||
|
@ -479,13 +474,7 @@ var _ = Describe("Handshake tests", func() {
|
||||||
|
|
||||||
Context("using tokens", func() {
|
Context("using tokens", func() {
|
||||||
It("uses tokens provided in NEW_TOKEN frames", func() {
|
It("uses tokens provided in NEW_TOKEN frames", func() {
|
||||||
tokenChan := make(chan *quic.Token, 100)
|
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||||
serverConfig.AcceptToken = func(addr net.Addr, token *quic.Token) bool {
|
|
||||||
if token != nil && !token.IsRetryToken {
|
|
||||||
tokenChan <- token
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
@ -509,7 +498,6 @@ var _ = Describe("Handshake tests", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(gets).To(Receive())
|
Expect(gets).To(Receive())
|
||||||
Eventually(puts).Should(Receive())
|
Eventually(puts).Should(Receive())
|
||||||
Expect(tokenChan).ToNot(Receive())
|
|
||||||
// received a token. Close this connection.
|
// received a token. Close this connection.
|
||||||
Expect(conn.CloseWithError(0, "")).To(Succeed())
|
Expect(conn.CloseWithError(0, "")).To(Succeed())
|
||||||
|
|
||||||
|
@ -529,17 +517,13 @@ var _ = Describe("Handshake tests", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
defer conn.CloseWithError(0, "")
|
defer conn.CloseWithError(0, "")
|
||||||
Expect(gets).To(Receive())
|
Expect(gets).To(Receive())
|
||||||
Expect(tokenChan).To(Receive())
|
|
||||||
|
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("rejects invalid Retry token with the INVALID_TOKEN error", func() {
|
It("rejects invalid Retry token with the INVALID_TOKEN error", func() {
|
||||||
tokenChan := make(chan *quic.Token, 10)
|
serverConfig.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||||
serverConfig.AcceptToken = func(addr net.Addr, token *quic.Token) bool {
|
serverConfig.MaxRetryTokenAge = time.Nanosecond
|
||||||
tokenChan <- token
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
@ -554,18 +538,6 @@ var _ = Describe("Handshake tests", func() {
|
||||||
var transportErr *quic.TransportError
|
var transportErr *quic.TransportError
|
||||||
Expect(errors.As(err, &transportErr)).To(BeTrue())
|
Expect(errors.As(err, &transportErr)).To(BeTrue())
|
||||||
Expect(transportErr.ErrorCode).To(Equal(quic.InvalidToken))
|
Expect(transportErr.ErrorCode).To(Equal(quic.InvalidToken))
|
||||||
// Receiving a Retry might lead the client to measure a very small RTT.
|
|
||||||
// Then, it sometimes would retransmit the ClientHello before receiving the ServerHello.
|
|
||||||
Expect(len(tokenChan)).To(BeNumerically(">=", 2))
|
|
||||||
token := <-tokenChan
|
|
||||||
Expect(token).To(BeNil())
|
|
||||||
token = <-tokenChan
|
|
||||||
Expect(token).ToNot(BeNil())
|
|
||||||
// If the ClientHello was retransmitted, make sure that it contained the same Retry token.
|
|
||||||
for i := 2; i < len(tokenChan); i++ {
|
|
||||||
Expect(<-tokenChan).To(Equal(token))
|
|
||||||
}
|
|
||||||
Expect(token.IsRetryToken).To(BeTrue())
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ var _ = Describe("Packetization", func() {
|
||||||
"localhost:0",
|
"localhost:0",
|
||||||
getTLSConfig(),
|
getTLSConfig(),
|
||||||
getQuicConfig(&quic.Config{
|
getQuicConfig(&quic.Config{
|
||||||
AcceptToken: func(net.Addr, *quic.Token) bool { return true },
|
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||||
DisablePathMTUDiscovery: true,
|
DisablePathMTUDiscovery: true,
|
||||||
Tracer: newTracer(func() logging.ConnectionTracer { return serverTracer }),
|
Tracer: newTracer(func() logging.ConnectionTracer { return serverTracer }),
|
||||||
}),
|
}),
|
||||||
|
|
|
@ -56,7 +56,7 @@ var _ = Describe("0-RTT", func() {
|
||||||
tlsConf := getTLSConfig()
|
tlsConf := getTLSConfig()
|
||||||
if serverConf == nil {
|
if serverConf == nil {
|
||||||
serverConf = getQuicConfig(&quic.Config{
|
serverConf = getQuicConfig(&quic.Config{
|
||||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||||
})
|
})
|
||||||
serverConf.Versions = []protocol.VersionNumber{version}
|
serverConf.Versions = []protocol.VersionNumber{version}
|
||||||
}
|
}
|
||||||
|
@ -198,7 +198,7 @@ var _ = Describe("0-RTT", func() {
|
||||||
tlsConf,
|
tlsConf,
|
||||||
getQuicConfig(&quic.Config{
|
getQuicConfig(&quic.Config{
|
||||||
Versions: []protocol.VersionNumber{version},
|
Versions: []protocol.VersionNumber{version},
|
||||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
@ -256,7 +256,7 @@ var _ = Describe("0-RTT", func() {
|
||||||
tlsConf,
|
tlsConf,
|
||||||
getQuicConfig(&quic.Config{
|
getQuicConfig(&quic.Config{
|
||||||
Versions: []protocol.VersionNumber{version},
|
Versions: []protocol.VersionNumber{version},
|
||||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
@ -453,7 +453,7 @@ var _ = Describe("0-RTT", func() {
|
||||||
const maxStreams = 1
|
const maxStreams = 1
|
||||||
tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{
|
tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{
|
||||||
MaxIncomingUniStreams: maxStreams,
|
MaxIncomingUniStreams: maxStreams,
|
||||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||||
}))
|
}))
|
||||||
|
|
||||||
tracer := newPacketTracer()
|
tracer := newPacketTracer()
|
||||||
|
@ -462,7 +462,7 @@ var _ = Describe("0-RTT", func() {
|
||||||
tlsConf,
|
tlsConf,
|
||||||
getQuicConfig(&quic.Config{
|
getQuicConfig(&quic.Config{
|
||||||
Versions: []protocol.VersionNumber{version},
|
Versions: []protocol.VersionNumber{version},
|
||||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||||
MaxIncomingUniStreams: maxStreams + 1,
|
MaxIncomingUniStreams: maxStreams + 1,
|
||||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||||
}),
|
}),
|
||||||
|
@ -499,7 +499,7 @@ var _ = Describe("0-RTT", func() {
|
||||||
const maxStreams = 42
|
const maxStreams = 42
|
||||||
tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{
|
tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{
|
||||||
MaxIncomingStreams: maxStreams,
|
MaxIncomingStreams: maxStreams,
|
||||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||||
}))
|
}))
|
||||||
|
|
||||||
tracer := newPacketTracer()
|
tracer := newPacketTracer()
|
||||||
|
@ -508,7 +508,7 @@ var _ = Describe("0-RTT", func() {
|
||||||
tlsConf,
|
tlsConf,
|
||||||
getQuicConfig(&quic.Config{
|
getQuicConfig(&quic.Config{
|
||||||
Versions: []protocol.VersionNumber{version},
|
Versions: []protocol.VersionNumber{version},
|
||||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||||
MaxIncomingStreams: maxStreams - 1,
|
MaxIncomingStreams: maxStreams - 1,
|
||||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||||
}),
|
}),
|
||||||
|
@ -538,7 +538,7 @@ var _ = Describe("0-RTT", func() {
|
||||||
tlsConf,
|
tlsConf,
|
||||||
getQuicConfig(&quic.Config{
|
getQuicConfig(&quic.Config{
|
||||||
Versions: []protocol.VersionNumber{version},
|
Versions: []protocol.VersionNumber{version},
|
||||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
@ -560,7 +560,7 @@ var _ = Describe("0-RTT", func() {
|
||||||
func(addFlowControlLimit func(*quic.Config, uint64)) {
|
func(addFlowControlLimit func(*quic.Config, uint64)) {
|
||||||
tracer := newPacketTracer()
|
tracer := newPacketTracer()
|
||||||
firstConf := getQuicConfig(&quic.Config{
|
firstConf := getQuicConfig(&quic.Config{
|
||||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||||
Versions: []protocol.VersionNumber{version},
|
Versions: []protocol.VersionNumber{version},
|
||||||
})
|
})
|
||||||
addFlowControlLimit(firstConf, 3)
|
addFlowControlLimit(firstConf, 3)
|
||||||
|
@ -568,7 +568,7 @@ var _ = Describe("0-RTT", func() {
|
||||||
|
|
||||||
secondConf := getQuicConfig(&quic.Config{
|
secondConf := getQuicConfig(&quic.Config{
|
||||||
Versions: []protocol.VersionNumber{version},
|
Versions: []protocol.VersionNumber{version},
|
||||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||||
})
|
})
|
||||||
addFlowControlLimit(secondConf, 100)
|
addFlowControlLimit(secondConf, 100)
|
||||||
|
@ -723,7 +723,7 @@ var _ = Describe("0-RTT", func() {
|
||||||
tlsConf,
|
tlsConf,
|
||||||
getQuicConfig(&quic.Config{
|
getQuicConfig(&quic.Config{
|
||||||
Versions: []protocol.VersionNumber{version},
|
Versions: []protocol.VersionNumber{version},
|
||||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
RequireAddressValidation: func(net.Addr) bool { return false },
|
||||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
30
interface.go
30
interface.go
|
@ -26,16 +26,6 @@ const (
|
||||||
Version2 = protocol.Version2
|
Version2 = protocol.Version2
|
||||||
)
|
)
|
||||||
|
|
||||||
// A Token can be used to verify the ownership of the client address.
|
|
||||||
type Token struct {
|
|
||||||
// IsRetryToken encodes how the client received the token. There are two ways:
|
|
||||||
// * In a Retry packet sent when trying to establish a new connection.
|
|
||||||
// * In a NEW_TOKEN frame on a previous connection.
|
|
||||||
IsRetryToken bool
|
|
||||||
RemoteAddr string
|
|
||||||
SentTime time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// A ClientToken is a token received by the client.
|
// A ClientToken is a token received by the client.
|
||||||
// It can be used to skip address validation on future connection attempts.
|
// It can be used to skip address validation on future connection attempts.
|
||||||
type ClientToken struct {
|
type ClientToken struct {
|
||||||
|
@ -233,14 +223,18 @@ type Config struct {
|
||||||
// If the timeout is exceeded, the connection is closed.
|
// If the timeout is exceeded, the connection is closed.
|
||||||
// If this value is zero, the timeout is set to 30 seconds.
|
// If this value is zero, the timeout is set to 30 seconds.
|
||||||
MaxIdleTimeout time.Duration
|
MaxIdleTimeout time.Duration
|
||||||
// AcceptToken determines if a Token is accepted.
|
// RequireAddressValidation determines if a QUIC Retry packet is sent.
|
||||||
// It is called with token = nil if the client didn't send a token.
|
// This allows the server to verify the client's address, at the cost of increasing the handshake latency by 1 RTT.
|
||||||
// If not set, a default verification function is used:
|
// See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details.
|
||||||
// * it verifies that the address matches, and
|
// If not set, every client is forced to prove its remote address.
|
||||||
// * if the token is a retry token, that it was issued within the last 5 seconds
|
RequireAddressValidation func(net.Addr) bool
|
||||||
// * else, that it was issued within the last 24 hours.
|
// MaxRetryTokenAge is the maximum age of a Retry token.
|
||||||
// This option is only valid for the server.
|
// If not set, it defaults to 5 seconds. Only valid for a server.
|
||||||
AcceptToken func(clientAddr net.Addr, token *Token) bool
|
MaxRetryTokenAge time.Duration
|
||||||
|
// MaxTokenAge is the maximum age of the token presented during the handshake,
|
||||||
|
// for tokens that were issued on a previous connection.
|
||||||
|
// If not set, it defaults to 24 hours. Only valid for a server.
|
||||||
|
MaxTokenAge time.Duration
|
||||||
// The TokenStore stores tokens received from the server.
|
// The TokenStore stores tokens received from the server.
|
||||||
// Tokens are used to skip address validation on future connection attempts.
|
// Tokens are used to skip address validation on future connection attempts.
|
||||||
// The key used to store tokens is the ServerName from the tls.Config, if set
|
// The key used to store tokens is the ServerName from the tls.Config, if set
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package handshake
|
package handshake
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/asn1"
|
"encoding/asn1"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -18,13 +19,17 @@ const (
|
||||||
// A Token is derived from the client address and can be used to verify the ownership of this address.
|
// A Token is derived from the client address and can be used to verify the ownership of this address.
|
||||||
type Token struct {
|
type Token struct {
|
||||||
IsRetryToken bool
|
IsRetryToken bool
|
||||||
RemoteAddr string
|
|
||||||
SentTime time.Time
|
SentTime time.Time
|
||||||
|
encodedRemoteAddr []byte
|
||||||
// only set for retry tokens
|
// only set for retry tokens
|
||||||
OriginalDestConnectionID protocol.ConnectionID
|
OriginalDestConnectionID protocol.ConnectionID
|
||||||
RetrySrcConnectionID protocol.ConnectionID
|
RetrySrcConnectionID protocol.ConnectionID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Token) ValidateRemoteAddr(addr net.Addr) bool {
|
||||||
|
return bytes.Equal(encodeRemoteAddr(addr), t.encodedRemoteAddr)
|
||||||
|
}
|
||||||
|
|
||||||
// token is the struct that is used for ASN1 serialization and deserialization
|
// token is the struct that is used for ASN1 serialization and deserialization
|
||||||
type token struct {
|
type token struct {
|
||||||
IsRetryToken bool
|
IsRetryToken bool
|
||||||
|
@ -102,8 +107,8 @@ func (g *TokenGenerator) DecodeToken(encrypted []byte) (*Token, error) {
|
||||||
}
|
}
|
||||||
token := &Token{
|
token := &Token{
|
||||||
IsRetryToken: t.IsRetryToken,
|
IsRetryToken: t.IsRetryToken,
|
||||||
RemoteAddr: decodeRemoteAddr(t.RemoteAddr),
|
|
||||||
SentTime: time.Unix(0, t.Timestamp),
|
SentTime: time.Unix(0, t.Timestamp),
|
||||||
|
encodedRemoteAddr: t.RemoteAddr,
|
||||||
}
|
}
|
||||||
if t.IsRetryToken {
|
if t.IsRetryToken {
|
||||||
token.OriginalDestConnectionID = protocol.ConnectionID(t.OriginalDestConnectionID)
|
token.OriginalDestConnectionID = protocol.ConnectionID(t.OriginalDestConnectionID)
|
||||||
|
@ -119,16 +124,3 @@ func encodeRemoteAddr(remoteAddr net.Addr) []byte {
|
||||||
}
|
}
|
||||||
return append([]byte{tokenPrefixString}, []byte(remoteAddr.String())...)
|
return append([]byte{tokenPrefixString}, []byte(remoteAddr.String())...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// decodeRemoteAddr decodes the remote address saved in the token
|
|
||||||
func decodeRemoteAddr(data []byte) string {
|
|
||||||
// data will never be empty for a token that we generated.
|
|
||||||
// Check it to be on the safe side
|
|
||||||
if len(data) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
if data[0] == tokenPrefixIP {
|
|
||||||
return net.IP(data[1:]).String()
|
|
||||||
}
|
|
||||||
return string(data[1:])
|
|
||||||
}
|
|
||||||
|
|
|
@ -35,16 +35,13 @@ var _ = Describe("Token Generator", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("accepts a valid token", func() {
|
It("accepts a valid token", func() {
|
||||||
ip := net.IPv4(192, 168, 0, 1)
|
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||||
tokenEnc, err := tokenGen.NewRetryToken(
|
tokenEnc, err := tokenGen.NewRetryToken(addr, nil, nil)
|
||||||
&net.UDPAddr{IP: ip, Port: 1337},
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
token, err := tokenGen.DecodeToken(tokenEnc)
|
token, err := tokenGen.DecodeToken(tokenEnc)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(token.RemoteAddr).To(Equal("192.168.0.1"))
|
Expect(token.ValidateRemoteAddr(addr)).To(BeTrue())
|
||||||
|
Expect(token.ValidateRemoteAddr(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 2), Port: 1337})).To(BeFalse())
|
||||||
Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond))
|
Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond))
|
||||||
Expect(token.OriginalDestConnectionID.Len()).To(BeZero())
|
Expect(token.OriginalDestConnectionID.Len()).To(BeZero())
|
||||||
Expect(token.RetrySrcConnectionID.Len()).To(BeZero())
|
Expect(token.RetrySrcConnectionID.Len()).To(BeZero())
|
||||||
|
@ -110,7 +107,7 @@ var _ = Describe("Token Generator", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
token, err := tokenGen.DecodeToken(tokenEnc)
|
token, err := tokenGen.DecodeToken(tokenEnc)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(token.RemoteAddr).To(Equal(ip.String()))
|
Expect(token.ValidateRemoteAddr(raddr)).To(BeTrue())
|
||||||
Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond))
|
Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -121,7 +118,8 @@ var _ = Describe("Token Generator", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
token, err := tokenGen.DecodeToken(tokenEnc)
|
token, err := tokenGen.DecodeToken(tokenEnc)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(token.RemoteAddr).To(Equal("192.168.13.37:1337"))
|
Expect(token.ValidateRemoteAddr(raddr)).To(BeTrue())
|
||||||
|
Expect(token.ValidateRemoteAddr(&net.TCPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1338})).To(BeFalse())
|
||||||
Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond))
|
Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
|
@ -44,7 +44,7 @@ func main() {
|
||||||
}
|
}
|
||||||
// a quic.Config that doesn't do a Retry
|
// a quic.Config that doesn't do a Retry
|
||||||
quicConf := &quic.Config{
|
quicConf := &quic.Config{
|
||||||
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
|
RequireAddressValidation: func(net.Addr) bool { return testcase == "retry" },
|
||||||
Tracer: qlog.NewTracer(getLogWriter),
|
Tracer: qlog.NewTracer(getLogWriter),
|
||||||
}
|
}
|
||||||
cert, err := tls.LoadX509KeyPair("/certs/cert.pem", "/certs/priv.key")
|
cert, err := tls.LoadX509KeyPair("/certs/cert.pem", "/certs/priv.key")
|
||||||
|
@ -58,15 +58,11 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
switch testcase {
|
switch testcase {
|
||||||
case "versionnegotiation", "handshake", "transfer", "resumption", "zerortt", "multiconnect":
|
case "versionnegotiation", "handshake", "retry", "transfer", "resumption", "zerortt", "multiconnect":
|
||||||
err = runHTTP09Server(quicConf)
|
err = runHTTP09Server(quicConf)
|
||||||
case "chacha20":
|
case "chacha20":
|
||||||
tlsConf.CipherSuites = []uint16{tls.TLS_CHACHA20_POLY1305_SHA256}
|
tlsConf.CipherSuites = []uint16{tls.TLS_CHACHA20_POLY1305_SHA256}
|
||||||
err = runHTTP09Server(quicConf)
|
err = runHTTP09Server(quicConf)
|
||||||
case "retry":
|
|
||||||
// By default, quic-go performs a Retry on every incoming connection.
|
|
||||||
quicConf.AcceptToken = nil
|
|
||||||
err = runHTTP09Server(quicConf)
|
|
||||||
case "http3":
|
case "http3":
|
||||||
err = runHTTP3Server(quicConf)
|
err = runHTTP3Server(quicConf)
|
||||||
default:
|
default:
|
||||||
|
|
56
server.go
56
server.go
|
@ -241,26 +241,6 @@ func (s *baseServer) run() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultAcceptToken = func(clientAddr net.Addr, token *Token) bool {
|
|
||||||
if token == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
validity := protocol.TokenValidity
|
|
||||||
if token.IsRetryToken {
|
|
||||||
validity = protocol.RetryTokenValidity
|
|
||||||
}
|
|
||||||
if time.Now().After(token.SentTime.Add(validity)) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
var sourceAddr string
|
|
||||||
if udpAddr, ok := clientAddr.(*net.UDPAddr); ok {
|
|
||||||
sourceAddr = udpAddr.IP.String()
|
|
||||||
} else {
|
|
||||||
sourceAddr = clientAddr.String()
|
|
||||||
}
|
|
||||||
return sourceAddr == token.RemoteAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Accept returns connections that already completed the handshake.
|
// Accept returns connections that already completed the handshake.
|
||||||
// It is only valid if acceptEarlyConns is false.
|
// It is only valid if acceptEarlyConns is false.
|
||||||
func (s *baseServer) Accept(ctx context.Context) (Connection, error) {
|
func (s *baseServer) Accept(ctx context.Context) (Connection, error) {
|
||||||
|
@ -405,33 +385,45 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
token *Token
|
token *handshake.Token
|
||||||
retrySrcConnID *protocol.ConnectionID
|
retrySrcConnID *protocol.ConnectionID
|
||||||
)
|
)
|
||||||
origDestConnID := hdr.DestConnectionID
|
origDestConnID := hdr.DestConnectionID
|
||||||
if len(hdr.Token) > 0 {
|
if len(hdr.Token) > 0 {
|
||||||
c, err := s.tokenGenerator.DecodeToken(hdr.Token)
|
tok, err := s.tokenGenerator.DecodeToken(hdr.Token)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
token = &Token{
|
if tok.IsRetryToken {
|
||||||
IsRetryToken: c.IsRetryToken,
|
origDestConnID = tok.OriginalDestConnectionID
|
||||||
RemoteAddr: c.RemoteAddr,
|
retrySrcConnID = &tok.RetrySrcConnectionID
|
||||||
SentTime: c.SentTime,
|
|
||||||
}
|
}
|
||||||
if token.IsRetryToken {
|
token = tok
|
||||||
origDestConnID = c.OriginalDestConnectionID
|
|
||||||
retrySrcConnID = &c.RetrySrcConnectionID
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
if token != nil {
|
||||||
if !s.config.AcceptToken(p.remoteAddr, token) {
|
addrIsValid := token.ValidateRemoteAddr(p.remoteAddr)
|
||||||
|
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
|
||||||
|
// We just ignore them, and act as if there was no token on this packet at all.
|
||||||
|
// This also means we might send a Retry later.
|
||||||
|
if !token.IsRetryToken && (time.Since(token.SentTime) > s.config.MaxTokenAge || !addrIsValid) {
|
||||||
|
token = nil
|
||||||
|
} else if token.IsRetryToken && (time.Since(token.SentTime) > s.config.MaxRetryTokenAge || !addrIsValid) {
|
||||||
|
// For Retry tokens, we send an INVALID_ERROR if
|
||||||
|
// * the token is too old, or
|
||||||
|
// * the token is invalid, in case of a retry token.
|
||||||
go func() {
|
go func() {
|
||||||
defer p.buffer.Release()
|
defer p.buffer.Release()
|
||||||
if token != nil && token.IsRetryToken {
|
if token != nil && token.IsRetryToken {
|
||||||
if err := s.maybeSendInvalidToken(p, hdr); err != nil {
|
if err := s.maybeSendInvalidToken(p, hdr); err != nil {
|
||||||
s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err)
|
s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err)
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if token == nil && s.config.RequireAddressValidation(p.remoteAddr) {
|
||||||
|
go func() {
|
||||||
|
defer p.buffer.Release()
|
||||||
if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil {
|
if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil {
|
||||||
s.logger.Debugf("Error sending Retry: %s", err)
|
s.logger.Debugf("Error sending Retry: %s", err)
|
||||||
}
|
}
|
||||||
|
|
424
server_test.go
424
server_test.go
|
@ -126,22 +126,22 @@ var _ = Describe("Server", func() {
|
||||||
Expect(server.config.Versions).To(Equal(protocol.SupportedVersions))
|
Expect(server.config.Versions).To(Equal(protocol.SupportedVersions))
|
||||||
Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
|
Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
|
||||||
Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
|
Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
|
||||||
Expect(reflect.ValueOf(server.config.AcceptToken)).To(Equal(reflect.ValueOf(defaultAcceptToken)))
|
Expect(server.config.RequireAddressValidation).ToNot(BeNil())
|
||||||
Expect(server.config.KeepAlivePeriod).To(Equal(0 * time.Second))
|
Expect(server.config.KeepAlivePeriod).To(BeZero())
|
||||||
// stop the listener
|
// stop the listener
|
||||||
Expect(ln.Close()).To(Succeed())
|
Expect(ln.Close()).To(Succeed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("setups with the right values", func() {
|
It("setups with the right values", func() {
|
||||||
supportedVersions := []protocol.VersionNumber{protocol.VersionTLS}
|
supportedVersions := []protocol.VersionNumber{protocol.VersionTLS}
|
||||||
acceptToken := func(_ net.Addr, _ *Token) bool { return true }
|
requireAddrVal := func(net.Addr) bool { return true }
|
||||||
config := Config{
|
config := Config{
|
||||||
Versions: supportedVersions,
|
Versions: supportedVersions,
|
||||||
AcceptToken: acceptToken,
|
|
||||||
HandshakeIdleTimeout: 1337 * time.Hour,
|
HandshakeIdleTimeout: 1337 * time.Hour,
|
||||||
MaxIdleTimeout: 42 * time.Minute,
|
MaxIdleTimeout: 42 * time.Minute,
|
||||||
KeepAlivePeriod: 5 * time.Second,
|
KeepAlivePeriod: 5 * time.Second,
|
||||||
StatelessResetKey: []byte("foobar"),
|
StatelessResetKey: []byte("foobar"),
|
||||||
|
RequireAddressValidation: requireAddrVal,
|
||||||
}
|
}
|
||||||
ln, err := Listen(conn, tlsConf, &config)
|
ln, err := Listen(conn, tlsConf, &config)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
@ -150,7 +150,7 @@ var _ = Describe("Server", func() {
|
||||||
Expect(server.config.Versions).To(Equal(supportedVersions))
|
Expect(server.config.Versions).To(Equal(supportedVersions))
|
||||||
Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour))
|
Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour))
|
||||||
Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute))
|
Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute))
|
||||||
Expect(reflect.ValueOf(server.config.AcceptToken)).To(Equal(reflect.ValueOf(acceptToken)))
|
Expect(reflect.ValueOf(server.config.RequireAddressValidation)).To(Equal(reflect.ValueOf(requireAddrVal)))
|
||||||
Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second))
|
Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second))
|
||||||
Expect(server.config.StatelessResetKey).To(Equal([]byte("foobar")))
|
Expect(server.config.StatelessResetKey).To(Equal([]byte("foobar")))
|
||||||
// stop the listener
|
// stop the listener
|
||||||
|
@ -239,60 +239,8 @@ var _ = Describe("Server", func() {
|
||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
})
|
})
|
||||||
|
|
||||||
It("decodes the token from the Token field", func() {
|
|
||||||
raddr := &net.UDPAddr{
|
|
||||||
IP: net.IPv4(192, 168, 13, 37),
|
|
||||||
Port: 1337,
|
|
||||||
}
|
|
||||||
done := make(chan struct{})
|
|
||||||
serv.config.AcceptToken = func(addr net.Addr, token *Token) bool {
|
|
||||||
Expect(addr).To(Equal(raddr))
|
|
||||||
Expect(token).ToNot(BeNil())
|
|
||||||
close(done)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
packet := getPacket(&wire.Header{
|
|
||||||
IsLongHeader: true,
|
|
||||||
Type: protocol.PacketTypeInitial,
|
|
||||||
Token: token,
|
|
||||||
Version: serv.config.Versions[0],
|
|
||||||
}, make([]byte, protocol.MinInitialPacketSize))
|
|
||||||
packet.remoteAddr = raddr
|
|
||||||
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1)
|
|
||||||
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
|
|
||||||
serv.handlePacket(packet)
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("passes an empty token to the callback, if decoding fails", func() {
|
|
||||||
raddr := &net.UDPAddr{
|
|
||||||
IP: net.IPv4(192, 168, 13, 37),
|
|
||||||
Port: 1337,
|
|
||||||
}
|
|
||||||
done := make(chan struct{})
|
|
||||||
serv.config.AcceptToken = func(addr net.Addr, token *Token) bool {
|
|
||||||
Expect(addr).To(Equal(raddr))
|
|
||||||
Expect(token).To(BeNil())
|
|
||||||
close(done)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
packet := getPacket(&wire.Header{
|
|
||||||
IsLongHeader: true,
|
|
||||||
Type: protocol.PacketTypeInitial,
|
|
||||||
Token: []byte("foobar"),
|
|
||||||
Version: serv.config.Versions[0],
|
|
||||||
}, make([]byte, protocol.MinInitialPacketSize))
|
|
||||||
packet.remoteAddr = raddr
|
|
||||||
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1)
|
|
||||||
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
|
|
||||||
serv.handlePacket(packet)
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("creates a connection when the token is accepted", func() {
|
It("creates a connection when the token is accepted", func() {
|
||||||
serv.config.AcceptToken = func(_ net.Addr, token *Token) bool { return true }
|
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||||
retryToken, err := serv.tokenGenerator.NewRetryToken(
|
retryToken, err := serv.tokenGenerator.NewRetryToken(
|
||||||
&net.UDPAddr{},
|
&net.UDPAddr{},
|
||||||
protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde},
|
protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde},
|
||||||
|
@ -469,8 +417,8 @@ var _ = Describe("Server", func() {
|
||||||
time.Sleep(scaleDuration(20 * time.Millisecond))
|
time.Sleep(scaleDuration(20 * time.Millisecond))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("replies with a Retry packet, if a Token is required", func() {
|
It("replies with a Retry packet, if a token is required", func() {
|
||||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
|
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||||
hdr := &wire.Header{
|
hdr := &wire.Header{
|
||||||
IsLongHeader: true,
|
IsLongHeader: true,
|
||||||
Type: protocol.PacketTypeInitial,
|
Type: protocol.PacketTypeInitial,
|
||||||
|
@ -502,81 +450,8 @@ var _ = Describe("Server", func() {
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
|
It("creates a connection, if no token is required", func() {
|
||||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
|
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||||
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
hdr := &wire.Header{
|
|
||||||
IsLongHeader: true,
|
|
||||||
Type: protocol.PacketTypeInitial,
|
|
||||||
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
|
||||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
|
||||||
Token: token,
|
|
||||||
Version: protocol.VersionTLS,
|
|
||||||
}
|
|
||||||
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
|
||||||
packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet
|
|
||||||
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
|
||||||
packet.remoteAddr = raddr
|
|
||||||
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
|
|
||||||
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
|
||||||
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
|
|
||||||
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
|
||||||
Expect(frames).To(HaveLen(1))
|
|
||||||
Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
|
|
||||||
ccf := frames[0].(*logging.ConnectionCloseFrame)
|
|
||||||
Expect(ccf.IsApplicationError).To(BeFalse())
|
|
||||||
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
|
|
||||||
})
|
|
||||||
done := make(chan struct{})
|
|
||||||
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
|
||||||
defer close(done)
|
|
||||||
replyHdr := parseHeader(b)
|
|
||||||
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
|
||||||
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
|
|
||||||
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
|
||||||
_, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version)
|
|
||||||
extHdr, err := unpackHeader(opener, replyHdr, b, hdr.Version)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()])
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
f, err := wire.NewFrameParser(false, hdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
|
|
||||||
ccf := f.(*wire.ConnectionCloseFrame)
|
|
||||||
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
|
|
||||||
Expect(ccf.ReasonPhrase).To(BeEmpty())
|
|
||||||
return len(b), nil
|
|
||||||
})
|
|
||||||
serv.handlePacket(packet)
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
|
|
||||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
|
|
||||||
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
hdr := &wire.Header{
|
|
||||||
IsLongHeader: true,
|
|
||||||
Type: protocol.PacketTypeInitial,
|
|
||||||
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
|
||||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
|
||||||
Token: token,
|
|
||||||
Version: protocol.VersionTLS,
|
|
||||||
}
|
|
||||||
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
|
||||||
packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
|
|
||||||
packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
|
||||||
done := make(chan struct{})
|
|
||||||
tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) })
|
|
||||||
serv.handlePacket(packet)
|
|
||||||
// make sure there are no Write calls on the packet conn
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("creates a connection, if no Token is required", func() {
|
|
||||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
|
||||||
hdr := &wire.Header{
|
hdr := &wire.Header{
|
||||||
IsLongHeader: true,
|
IsLongHeader: true,
|
||||||
Type: protocol.PacketTypeInitial,
|
Type: protocol.PacketTypeInitial,
|
||||||
|
@ -659,7 +534,7 @@ var _ = Describe("Server", func() {
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes()
|
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes()
|
||||||
|
|
||||||
serv.config.AcceptToken = func(net.Addr, *Token) bool { return true }
|
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||||
acceptConn := make(chan struct{})
|
acceptConn := make(chan struct{})
|
||||||
var counter uint32 // to be used as an atomic, so we query it in Eventually
|
var counter uint32 // to be used as an atomic, so we query it in Eventually
|
||||||
serv.newConn = func(
|
serv.newConn = func(
|
||||||
|
@ -713,7 +588,7 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("only creates a single connection for a duplicate Initial", func() {
|
It("only creates a single connection for a duplicate Initial", func() {
|
||||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||||
var createdConn bool
|
var createdConn bool
|
||||||
conn := NewMockQuicConn(mockCtrl)
|
conn := NewMockQuicConn(mockCtrl)
|
||||||
serv.newConn = func(
|
serv.newConn = func(
|
||||||
|
@ -745,7 +620,7 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("rejects new connection attempts if the accept queue is full", func() {
|
It("rejects new connection attempts if the accept queue is full", func() {
|
||||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||||
|
|
||||||
serv.newConn = func(
|
serv.newConn = func(
|
||||||
_ sendConn,
|
_ sendConn,
|
||||||
|
@ -813,7 +688,7 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("doesn't accept new connections if they were closed in the mean time", func() {
|
It("doesn't accept new connections if they were closed in the mean time", func() {
|
||||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||||
|
|
||||||
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
|
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
@ -877,6 +752,200 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Context("token validation", func() {
|
||||||
|
checkInvalidToken := func(b []byte, origHdr *wire.Header) {
|
||||||
|
replyHdr := parseHeader(b)
|
||||||
|
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||||
|
Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID))
|
||||||
|
Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID))
|
||||||
|
_, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version)
|
||||||
|
extHdr, err := unpackHeader(opener, replyHdr, b, origHdr.Version)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()])
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
f, err := wire.NewFrameParser(false, origHdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
|
||||||
|
ccf := f.(*wire.ConnectionCloseFrame)
|
||||||
|
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
|
||||||
|
Expect(ccf.ReasonPhrase).To(BeEmpty())
|
||||||
|
}
|
||||||
|
|
||||||
|
It("decodes the token from the token field", func() {
|
||||||
|
raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337}
|
||||||
|
token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
packet := getPacket(&wire.Header{
|
||||||
|
IsLongHeader: true,
|
||||||
|
Type: protocol.PacketTypeInitial,
|
||||||
|
Token: token,
|
||||||
|
Version: serv.config.Versions[0],
|
||||||
|
}, make([]byte, protocol.MinInitialPacketSize))
|
||||||
|
packet.remoteAddr = raddr
|
||||||
|
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||||
|
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() packetHandler) { close(done) })
|
||||||
|
serv.handlePacket(packet)
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
|
||||||
|
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||||
|
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
hdr := &wire.Header{
|
||||||
|
IsLongHeader: true,
|
||||||
|
Type: protocol.PacketTypeInitial,
|
||||||
|
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||||
|
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||||
|
Token: token,
|
||||||
|
Version: protocol.VersionTLS,
|
||||||
|
}
|
||||||
|
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||||
|
packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet
|
||||||
|
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||||
|
packet.remoteAddr = raddr
|
||||||
|
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
|
||||||
|
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||||
|
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
|
||||||
|
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
||||||
|
Expect(frames).To(HaveLen(1))
|
||||||
|
Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
|
||||||
|
ccf := frames[0].(*logging.ConnectionCloseFrame)
|
||||||
|
Expect(ccf.IsApplicationError).To(BeFalse())
|
||||||
|
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
|
||||||
|
})
|
||||||
|
done := make(chan struct{})
|
||||||
|
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
||||||
|
defer close(done)
|
||||||
|
checkInvalidToken(b, hdr)
|
||||||
|
return len(b), nil
|
||||||
|
})
|
||||||
|
serv.handlePacket(packet)
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("sends an INVALID_TOKEN error, if an expired retry token is received", func() {
|
||||||
|
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||||
|
serv.config.MaxRetryTokenAge = time.Millisecond
|
||||||
|
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||||
|
token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
time.Sleep(2 * time.Millisecond) // make sure the token is expired
|
||||||
|
hdr := &wire.Header{
|
||||||
|
IsLongHeader: true,
|
||||||
|
Type: protocol.PacketTypeInitial,
|
||||||
|
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||||
|
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||||
|
Token: token,
|
||||||
|
Version: protocol.VersionTLS,
|
||||||
|
}
|
||||||
|
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||||
|
packet.remoteAddr = raddr
|
||||||
|
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
|
||||||
|
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||||
|
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
|
||||||
|
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
||||||
|
Expect(frames).To(HaveLen(1))
|
||||||
|
Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
|
||||||
|
ccf := frames[0].(*logging.ConnectionCloseFrame)
|
||||||
|
Expect(ccf.IsApplicationError).To(BeFalse())
|
||||||
|
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
|
||||||
|
})
|
||||||
|
done := make(chan struct{})
|
||||||
|
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
||||||
|
defer close(done)
|
||||||
|
checkInvalidToken(b, hdr)
|
||||||
|
return len(b), nil
|
||||||
|
})
|
||||||
|
serv.handlePacket(packet)
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() {
|
||||||
|
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||||
|
token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
hdr := &wire.Header{
|
||||||
|
IsLongHeader: true,
|
||||||
|
Type: protocol.PacketTypeInitial,
|
||||||
|
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||||
|
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||||
|
Token: token,
|
||||||
|
Version: protocol.VersionTLS,
|
||||||
|
}
|
||||||
|
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||||
|
packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
|
||||||
|
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||||
|
packet.remoteAddr = raddr
|
||||||
|
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||||
|
done := make(chan struct{})
|
||||||
|
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
||||||
|
defer close(done)
|
||||||
|
replyHdr := parseHeader(b)
|
||||||
|
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
|
||||||
|
return len(b), nil
|
||||||
|
})
|
||||||
|
serv.handlePacket(packet)
|
||||||
|
// make sure there are no Write calls on the packet conn
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() {
|
||||||
|
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||||
|
serv.config.MaxTokenAge = time.Millisecond
|
||||||
|
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||||
|
token, err := serv.tokenGenerator.NewToken(raddr)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
time.Sleep(2 * time.Millisecond) // make sure the token is expired
|
||||||
|
hdr := &wire.Header{
|
||||||
|
IsLongHeader: true,
|
||||||
|
Type: protocol.PacketTypeInitial,
|
||||||
|
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||||
|
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||||
|
Token: token,
|
||||||
|
Version: protocol.VersionTLS,
|
||||||
|
}
|
||||||
|
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||||
|
packet.remoteAddr = raddr
|
||||||
|
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
|
||||||
|
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
|
||||||
|
})
|
||||||
|
done := make(chan struct{})
|
||||||
|
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
||||||
|
defer close(done)
|
||||||
|
return len(b), nil
|
||||||
|
})
|
||||||
|
serv.handlePacket(packet)
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
|
||||||
|
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||||
|
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
hdr := &wire.Header{
|
||||||
|
IsLongHeader: true,
|
||||||
|
Type: protocol.PacketTypeInitial,
|
||||||
|
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||||
|
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||||
|
Token: token,
|
||||||
|
Version: protocol.VersionTLS,
|
||||||
|
}
|
||||||
|
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||||
|
packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
|
||||||
|
packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||||
|
done := make(chan struct{})
|
||||||
|
tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) })
|
||||||
|
serv.handlePacket(packet)
|
||||||
|
// make sure there are no Write calls on the packet conn
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
Context("accepting connections", func() {
|
Context("accepting connections", func() {
|
||||||
It("returns Accept when an error occurs", func() {
|
It("returns Accept when an error occurs", func() {
|
||||||
testErr := errors.New("test err")
|
testErr := errors.New("test err")
|
||||||
|
@ -930,7 +999,7 @@ var _ = Describe("Server", func() {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background()) // handshake context
|
ctx, cancel := context.WithCancel(context.Background()) // handshake context
|
||||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||||
serv.newConn = func(
|
serv.newConn = func(
|
||||||
_ sendConn,
|
_ sendConn,
|
||||||
runner connRunner,
|
runner connRunner,
|
||||||
|
@ -1004,7 +1073,7 @@ var _ = Describe("Server", func() {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
ready := make(chan struct{})
|
ready := make(chan struct{})
|
||||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||||
serv.newConn = func(
|
serv.newConn = func(
|
||||||
_ sendConn,
|
_ sendConn,
|
||||||
runner connRunner,
|
runner connRunner,
|
||||||
|
@ -1045,7 +1114,7 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("rejects new connection attempts if the accept queue is full", func() {
|
It("rejects new connection attempts if the accept queue is full", func() {
|
||||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||||
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
|
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
|
||||||
|
|
||||||
serv.newConn = func(
|
serv.newConn = func(
|
||||||
|
@ -1106,7 +1175,7 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("doesn't accept new connections if they were closed in the mean time", func() {
|
It("doesn't accept new connections if they were closed in the mean time", func() {
|
||||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||||
|
|
||||||
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
|
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
@ -1166,72 +1235,3 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
var _ = Describe("default source address verification", func() {
|
|
||||||
It("accepts a token", func() {
|
|
||||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
|
|
||||||
token := &Token{
|
|
||||||
IsRetryToken: true,
|
|
||||||
RemoteAddr: "192.168.0.1",
|
|
||||||
SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(time.Second), // will expire in 1 second
|
|
||||||
}
|
|
||||||
Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("requests verification if no token is provided", func() {
|
|
||||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
|
|
||||||
Expect(defaultAcceptToken(remoteAddr, nil)).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects a token if the address doesn't match", func() {
|
|
||||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
|
|
||||||
token := &Token{
|
|
||||||
IsRetryToken: true,
|
|
||||||
RemoteAddr: "127.0.0.1",
|
|
||||||
SentTime: time.Now(),
|
|
||||||
}
|
|
||||||
Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("accepts a token for a remote address is not a UDP address", func() {
|
|
||||||
remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
|
||||||
token := &Token{
|
|
||||||
IsRetryToken: true,
|
|
||||||
RemoteAddr: "192.168.0.1:1337",
|
|
||||||
SentTime: time.Now(),
|
|
||||||
}
|
|
||||||
Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects an invalid token for a remote address is not a UDP address", func() {
|
|
||||||
remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
|
||||||
token := &Token{
|
|
||||||
IsRetryToken: true,
|
|
||||||
RemoteAddr: "192.168.0.1:7331", // mismatching port
|
|
||||||
SentTime: time.Now(),
|
|
||||||
}
|
|
||||||
Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects an expired token", func() {
|
|
||||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
|
|
||||||
token := &Token{
|
|
||||||
IsRetryToken: true,
|
|
||||||
RemoteAddr: "192.168.0.1",
|
|
||||||
SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(-time.Second), // expired 1 second ago
|
|
||||||
}
|
|
||||||
Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("accepts a non-retry token", func() {
|
|
||||||
Expect(protocol.RetryTokenValidity).To(BeNumerically("<", protocol.TokenValidity))
|
|
||||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
|
|
||||||
token := &Token{
|
|
||||||
IsRetryToken: false,
|
|
||||||
RemoteAddr: "192.168.0.1",
|
|
||||||
// if this was a retry token, it would have expired one second ago
|
|
||||||
SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(-time.Second),
|
|
||||||
}
|
|
||||||
Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue