mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 05:07:36 +03:00
remove the RequireAddressValidation callback from the Config (#4253)
This commit is contained in:
parent
892851eb8c
commit
a2cf43d75c
14 changed files with 127 additions and 128 deletions
12
config.go
12
config.go
|
@ -2,7 +2,6 @@ package quic
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
@ -49,16 +48,6 @@ func validateConfig(config *Config) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// populateServerConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateServerConfig(config *Config) *Config {
|
||||
config = populateConfig(config)
|
||||
if config.RequireAddressValidation == nil {
|
||||
config.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
// populateConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateConfig(config *Config) *Config {
|
||||
|
@ -111,7 +100,6 @@ func populateConfig(config *Config) *Config {
|
|||
Versions: versions,
|
||||
HandshakeIdleTimeout: handshakeIdleTimeout,
|
||||
MaxIdleTimeout: idleTimeout,
|
||||
RequireAddressValidation: config.RequireAddressValidation,
|
||||
KeepAlivePeriod: config.KeepAlivePeriod,
|
||||
InitialStreamReceiveWindow: initialStreamReceiveWindow,
|
||||
MaxStreamReceiveWindow: maxStreamReceiveWindow,
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
|
@ -23,7 +22,7 @@ var _ = Describe("Config", func() {
|
|||
})
|
||||
|
||||
It("validates a config with normal values", func() {
|
||||
conf := populateServerConfig(&Config{
|
||||
conf := populateConfig(&Config{
|
||||
MaxIncomingStreams: 5,
|
||||
MaxStreamReceiveWindow: 10,
|
||||
})
|
||||
|
@ -118,19 +117,16 @@ var _ = Describe("Config", func() {
|
|||
|
||||
Context("cloning", func() {
|
||||
It("clones function fields", func() {
|
||||
var calledAddrValidation, calledAllowConnectionWindowIncrease, calledTracer bool
|
||||
var calledAllowConnectionWindowIncrease, calledTracer bool
|
||||
c1 := &Config{
|
||||
GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") },
|
||||
AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true },
|
||||
RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true },
|
||||
Tracer: func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer {
|
||||
calledTracer = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
c2 := c1.Clone()
|
||||
c2.RequireAddressValidation(&net.UDPAddr{})
|
||||
Expect(calledAddrValidation).To(BeTrue())
|
||||
c2.AllowConnectionWindowIncrease(nil, 1234)
|
||||
Expect(calledAllowConnectionWindowIncrease).To(BeTrue())
|
||||
_, err := c2.GetConfigForClient(&ClientHelloInfo{})
|
||||
|
@ -145,29 +141,15 @@ var _ = Describe("Config", func() {
|
|||
})
|
||||
|
||||
It("returns a copy", func() {
|
||||
c1 := &Config{
|
||||
MaxIncomingStreams: 100,
|
||||
RequireAddressValidation: func(net.Addr) bool { return true },
|
||||
}
|
||||
c1 := &Config{MaxIncomingStreams: 100}
|
||||
c2 := c1.Clone()
|
||||
c2.MaxIncomingStreams = 200
|
||||
c2.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||
|
||||
Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100))
|
||||
Expect(c1.RequireAddressValidation(&net.UDPAddr{})).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("populating", func() {
|
||||
It("populates function fields", func() {
|
||||
var calledAddrValidation bool
|
||||
c1 := &Config{}
|
||||
c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true }
|
||||
c2 := populateConfig(c1)
|
||||
c2.RequireAddressValidation(&net.UDPAddr{})
|
||||
Expect(calledAddrValidation).To(BeTrue())
|
||||
})
|
||||
|
||||
It("copies non-function fields", func() {
|
||||
c := configWithNonZeroNonFunctionFields()
|
||||
Expect(populateConfig(c)).To(Equal(c))
|
||||
|
@ -186,10 +168,5 @@ var _ = Describe("Config", func() {
|
|||
Expect(c.DisablePathMTUDiscovery).To(BeFalse())
|
||||
Expect(c.GetConfigForClient).To(BeNil())
|
||||
})
|
||||
|
||||
It("populates empty fields with default values, for the server", func() {
|
||||
c := populateServerConfig(&Config{})
|
||||
Expect(c.RequireAddressValidation).ToNot(BeNil())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -118,7 +118,7 @@ var _ = Describe("Connection", func() {
|
|||
srcConnID,
|
||||
&protocol.DefaultConnectionIDGenerator{},
|
||||
protocol.StatelessResetToken{},
|
||||
populateServerConfig(&Config{DisablePathMTUDiscovery: true}),
|
||||
populateConfig(&Config{DisablePathMTUDiscovery: true}),
|
||||
&tls.Config{},
|
||||
tokenGenerator,
|
||||
false,
|
||||
|
|
|
@ -26,23 +26,17 @@ var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.Di
|
|||
|
||||
type applicationProtocol struct {
|
||||
name string
|
||||
run func()
|
||||
run func(ln *quic.Listener, port int)
|
||||
}
|
||||
|
||||
var _ = Describe("Handshake drop tests", func() {
|
||||
var (
|
||||
proxy *quicproxy.QuicProxy
|
||||
ln *quic.Listener
|
||||
)
|
||||
|
||||
data := GeneratePRData(5000)
|
||||
const timeout = 2 * time.Minute
|
||||
|
||||
startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool) {
|
||||
startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool) (ln *quic.Listener, proxyPort int, closeFn func()) {
|
||||
conf := getQuicConfig(&quic.Config{
|
||||
MaxIdleTimeout: timeout,
|
||||
HandshakeIdleTimeout: timeout,
|
||||
RequireAddressValidation: func(net.Addr) bool { return doRetry },
|
||||
MaxIdleTimeout: timeout,
|
||||
HandshakeIdleTimeout: timeout,
|
||||
})
|
||||
var tlsConf *tls.Config
|
||||
if longCertChain {
|
||||
|
@ -50,11 +44,18 @@ var _ = Describe("Handshake drop tests", func() {
|
|||
} else {
|
||||
tlsConf = getTLSConfig()
|
||||
}
|
||||
var err error
|
||||
ln, err = quic.ListenAddr("localhost:0", tlsConf, conf)
|
||||
laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := net.ListenUDP("udp", laddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tr := &quic.Transport{Conn: conn}
|
||||
if doRetry {
|
||||
tr.MaxUnvalidatedHandshakes = -1
|
||||
}
|
||||
ln, err = tr.Listen(tlsConf, conf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||
proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
DropPacket: dropCallback,
|
||||
DelayPacket: func(dir quicproxy.Direction, packet []byte) time.Duration {
|
||||
|
@ -62,11 +63,18 @@ var _ = Describe("Handshake drop tests", func() {
|
|||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
return ln, proxy.LocalPort(), func() {
|
||||
ln.Close()
|
||||
tr.Close()
|
||||
conn.Close()
|
||||
proxy.Close()
|
||||
}
|
||||
}
|
||||
|
||||
clientSpeaksFirst := &applicationProtocol{
|
||||
name: "client speaks first",
|
||||
run: func() {
|
||||
run: func(ln *quic.Listener, port int) {
|
||||
serverConnChan := make(chan quic.Connection)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
@ -82,7 +90,7 @@ var _ = Describe("Handshake drop tests", func() {
|
|||
}()
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
fmt.Sprintf("localhost:%d", port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIdleTimeout: timeout,
|
||||
|
@ -105,7 +113,7 @@ var _ = Describe("Handshake drop tests", func() {
|
|||
|
||||
serverSpeaksFirst := &applicationProtocol{
|
||||
name: "server speaks first",
|
||||
run: func() {
|
||||
run: func(ln *quic.Listener, port int) {
|
||||
serverConnChan := make(chan quic.Connection)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
@ -120,7 +128,7 @@ var _ = Describe("Handshake drop tests", func() {
|
|||
}()
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
fmt.Sprintf("localhost:%d", port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIdleTimeout: timeout,
|
||||
|
@ -143,7 +151,7 @@ var _ = Describe("Handshake drop tests", func() {
|
|||
|
||||
nobodySpeaks := &applicationProtocol{
|
||||
name: "nobody speaks",
|
||||
run: func() {
|
||||
run: func(ln *quic.Listener, port int) {
|
||||
serverConnChan := make(chan quic.Connection)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
@ -153,7 +161,7 @@ var _ = Describe("Handshake drop tests", func() {
|
|||
}()
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
fmt.Sprintf("localhost:%d", port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIdleTimeout: timeout,
|
||||
|
@ -169,11 +177,6 @@ var _ = Describe("Handshake drop tests", func() {
|
|||
},
|
||||
}
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(ln.Close()).To(Succeed())
|
||||
Expect(proxy.Close()).To(Succeed())
|
||||
})
|
||||
|
||||
for _, d := range directions {
|
||||
direction := d
|
||||
|
||||
|
@ -195,7 +198,7 @@ var _ = Describe("Handshake drop tests", func() {
|
|||
Context(app.name, func() {
|
||||
It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() {
|
||||
var incoming, outgoing atomic.Int32
|
||||
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
|
||||
ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
|
||||
var p int32
|
||||
//nolint:exhaustive
|
||||
switch d {
|
||||
|
@ -206,12 +209,13 @@ var _ = Describe("Handshake drop tests", func() {
|
|||
}
|
||||
return p == 1 && d.Is(direction)
|
||||
}, doRetry, longCertChain)
|
||||
app.run()
|
||||
defer closeFn()
|
||||
app.run(ln, proxyPort)
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() {
|
||||
var incoming, outgoing atomic.Int32
|
||||
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
|
||||
ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
|
||||
var p int32
|
||||
//nolint:exhaustive
|
||||
switch d {
|
||||
|
@ -222,7 +226,8 @@ var _ = Describe("Handshake drop tests", func() {
|
|||
}
|
||||
return p == 2 && d.Is(direction)
|
||||
}, doRetry, longCertChain)
|
||||
app.run()
|
||||
defer closeFn()
|
||||
app.run(ln, proxyPort)
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() {
|
||||
|
@ -230,7 +235,7 @@ var _ = Describe("Handshake drop tests", func() {
|
|||
var mx sync.Mutex
|
||||
var incoming, outgoing int
|
||||
|
||||
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
|
||||
ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
|
||||
drop := mrand.Int63n(int64(3)) == 0
|
||||
|
||||
mx.Lock()
|
||||
|
@ -260,7 +265,8 @@ var _ = Describe("Handshake drop tests", func() {
|
|||
}
|
||||
return drop
|
||||
}, doRetry, longCertChain)
|
||||
app.run()
|
||||
defer closeFn()
|
||||
app.run(ln, proxyPort)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
@ -281,13 +287,14 @@ var _ = Describe("Handshake drop tests", func() {
|
|||
uint64(27+31*(1000+mrand.Int63()/31)) % quicvarint.Max: b,
|
||||
}
|
||||
|
||||
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
|
||||
ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
|
||||
if d == quicproxy.DirectionOutgoing {
|
||||
return false
|
||||
}
|
||||
return mrand.Intn(3) == 0
|
||||
}, false, false)
|
||||
clientSpeaksFirst.run()
|
||||
defer closeFn()
|
||||
clientSpeaksFirst.run(ln, proxyPort)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
|
|
@ -55,8 +55,17 @@ var _ = Describe("Handshake RTT tests", func() {
|
|||
// 1 RTT for verifying the source address
|
||||
// 1 RTT for the TLS handshake
|
||||
It("is forward-secure after 2 RTTs", func() {
|
||||
serverConfig.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig)
|
||||
laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
udpConn, err := net.ListenUDP("udp", laddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer udpConn.Close()
|
||||
tr := &quic.Transport{
|
||||
Conn: udpConn,
|
||||
MaxUnvalidatedHandshakes: -1,
|
||||
}
|
||||
defer tr.Close()
|
||||
ln, err := tr.Listen(serverTLSConfig, serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
|
||||
|
|
|
@ -701,14 +701,24 @@ var _ = Describe("Handshake tests", func() {
|
|||
|
||||
It("rejects invalid Retry token with the INVALID_TOKEN error", func() {
|
||||
const rtt = 10 * time.Millisecond
|
||||
serverConfig.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
|
||||
// The validity period of the retry token is the handshake timeout,
|
||||
// which is twice the handshake idle timeout.
|
||||
// By setting the handshake timeout shorter than the RTT, the token will have expired by the time
|
||||
// it reaches the server.
|
||||
serverConfig.HandshakeIdleTimeout = rtt / 5
|
||||
|
||||
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
||||
laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
udpConn, err := net.ListenUDP("udp", laddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer udpConn.Close()
|
||||
tr := &quic.Transport{
|
||||
Conn: udpConn,
|
||||
MaxUnvalidatedHandshakes: -1,
|
||||
}
|
||||
defer tr.Close()
|
||||
server, err := tr.Listen(getTLSConfig(), serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ var _ = Describe("MITM test", func() {
|
|||
serverConfig *quic.Config
|
||||
)
|
||||
|
||||
startServerAndProxy := func(delayCb quicproxy.DelayCallback, dropCb quicproxy.DropCallback) (proxyPort int, closeFn func()) {
|
||||
startServerAndProxy := func(delayCb quicproxy.DelayCallback, dropCb quicproxy.DropCallback, forceAddressValidation bool) (proxyPort int, closeFn func()) {
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
c, err := net.ListenUDP("udp", addr)
|
||||
|
@ -41,6 +41,9 @@ var _ = Describe("MITM test", func() {
|
|||
Conn: c,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
if forceAddressValidation {
|
||||
serverTransport.MaxUnvalidatedHandshakes = -1
|
||||
}
|
||||
ln, err := serverTransport.Listen(getTLSConfig(), serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
done := make(chan struct{})
|
||||
|
@ -153,7 +156,7 @@ var _ = Describe("MITM test", func() {
|
|||
}
|
||||
|
||||
runTest := func(delayCb quicproxy.DelayCallback) {
|
||||
proxyPort, closeFn := startServerAndProxy(delayCb, nil)
|
||||
proxyPort, closeFn := startServerAndProxy(delayCb, nil, false)
|
||||
defer closeFn()
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -196,7 +199,7 @@ var _ = Describe("MITM test", func() {
|
|||
})
|
||||
|
||||
runTest := func(dropCb quicproxy.DropCallback) {
|
||||
proxyPort, closeFn := startServerAndProxy(nil, dropCb)
|
||||
proxyPort, closeFn := startServerAndProxy(nil, dropCb, false)
|
||||
defer closeFn()
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -310,17 +313,16 @@ var _ = Describe("MITM test", func() {
|
|||
|
||||
const rtt = 20 * time.Millisecond
|
||||
|
||||
runTest := func(delayCb quicproxy.DelayCallback) (closeFn func(), err error) {
|
||||
proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil)
|
||||
runTest := func(proxyPort int) (closeFn func(), err error) {
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = clientTransport.Dial(
|
||||
context.Background(),
|
||||
raddr,
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{HandshakeIdleTimeout: 2 * time.Second}),
|
||||
getQuicConfig(&quic.Config{HandshakeIdleTimeout: scaleDuration(200 * time.Millisecond)}),
|
||||
)
|
||||
return func() { clientTransport.Close(); serverCloseFn() }, err
|
||||
return func() { clientTransport.Close() }, err
|
||||
}
|
||||
|
||||
// fails immediately because client connection closes when it can't find compatible version
|
||||
|
@ -352,7 +354,9 @@ var _ = Describe("MITM test", func() {
|
|||
}
|
||||
return rtt / 2
|
||||
}
|
||||
closeFn, err := runTest(delayCb)
|
||||
proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false)
|
||||
defer serverCloseFn()
|
||||
closeFn, err := runTest(proxyPort)
|
||||
defer closeFn()
|
||||
Expect(err).To(HaveOccurred())
|
||||
vnErr := &quic.VersionNegotiationError{}
|
||||
|
@ -363,8 +367,7 @@ var _ = Describe("MITM test", func() {
|
|||
// times out, because client doesn't accept subsequent real retry packets from server
|
||||
// as it has already accepted a retry.
|
||||
// TODO: determine behavior when server does not send Retry packets
|
||||
It("fails when a forged Retry packet with modified srcConnID is sent to client", func() {
|
||||
serverConfig.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
It("fails when a forged Retry packet with modified Source Connection ID is sent to client", func() {
|
||||
var initialPacketIntercepted bool
|
||||
done := make(chan struct{})
|
||||
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
|
||||
|
@ -388,7 +391,9 @@ var _ = Describe("MITM test", func() {
|
|||
}
|
||||
return rtt / 2
|
||||
}
|
||||
closeFn, err := runTest(delayCb)
|
||||
proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, true)
|
||||
defer serverCloseFn()
|
||||
closeFn, err := runTest(proxyPort)
|
||||
defer closeFn()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.(net.Error).Timeout()).To(BeTrue())
|
||||
|
@ -418,7 +423,9 @@ var _ = Describe("MITM test", func() {
|
|||
}
|
||||
return rtt
|
||||
}
|
||||
closeFn, err := runTest(delayCb)
|
||||
proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false)
|
||||
defer serverCloseFn()
|
||||
closeFn, err := runTest(proxyPort)
|
||||
defer closeFn()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.(net.Error).Timeout()).To(BeTrue())
|
||||
|
@ -448,7 +455,9 @@ var _ = Describe("MITM test", func() {
|
|||
}
|
||||
return rtt
|
||||
}
|
||||
closeFn, err := runTest(delayCb)
|
||||
proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false)
|
||||
defer serverCloseFn()
|
||||
closeFn, err := runTest(proxyPort)
|
||||
defer closeFn()
|
||||
Expect(err).To(HaveOccurred())
|
||||
var transportErr *quic.TransportError
|
||||
|
|
|
@ -464,14 +464,19 @@ var _ = Describe("0-RTT", func() {
|
|||
}
|
||||
|
||||
counter, tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
udpConn, err := net.ListenUDP("udp", laddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer udpConn.Close()
|
||||
tr := &quic.Transport{
|
||||
Conn: udpConn,
|
||||
MaxUnvalidatedHandshakes: -1,
|
||||
}
|
||||
defer tr.Close()
|
||||
ln, err := tr.ListenEarly(
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
RequireAddressValidation: func(net.Addr) bool { return true },
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
getQuicConfig(&quic.Config{Allow0RTT: true, Tracer: newTracer(tracer)}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
|
|
|
@ -267,11 +267,6 @@ type Config struct {
|
|||
// If the timeout is exceeded, the connection is closed.
|
||||
// If this value is zero, the timeout is set to 30 seconds.
|
||||
MaxIdleTimeout time.Duration
|
||||
// RequireAddressValidation determines if a QUIC Retry packet is sent.
|
||||
// This allows the server to verify the client's address, at the cost of increasing the handshake latency by 1 RTT.
|
||||
// See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details.
|
||||
// If not set, every client is forced to prove its remote address.
|
||||
RequireAddressValidation func(net.Addr) bool
|
||||
// The TokenStore stores tokens received from the server.
|
||||
// Tokens are used to skip address validation on future connection attempts.
|
||||
// The key used to store tokens is the ServerName from the tls.Config, if set
|
||||
|
|
|
@ -37,6 +37,7 @@ func (w *responseWriter) WriteHeader(int) {}
|
|||
type Server struct {
|
||||
*http.Server
|
||||
|
||||
ForceRetry bool
|
||||
QuicConfig *quic.Config
|
||||
|
||||
mutex sync.Mutex
|
||||
|
@ -68,7 +69,11 @@ func (s *Server) ListenAndServe() error {
|
|||
|
||||
tlsConf := s.TLSConfig.Clone()
|
||||
tlsConf.NextProtos = []string{h09alpn}
|
||||
ln, err := quic.ListenEarly(conn, tlsConf, s.QuicConfig)
|
||||
tr := quic.Transport{Conn: conn}
|
||||
if s.ForceRetry {
|
||||
tr.MaxUnvalidatedHandshakes = -1
|
||||
}
|
||||
ln, err := tr.ListenEarly(tlsConf, s.QuicConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
|
@ -38,9 +37,8 @@ func main() {
|
|||
testcase := os.Getenv("TESTCASE")
|
||||
|
||||
quicConf := &quic.Config{
|
||||
RequireAddressValidation: func(net.Addr) bool { return testcase == "retry" },
|
||||
Allow0RTT: testcase == "zerortt",
|
||||
Tracer: utils.NewQLOGConnectionTracer,
|
||||
Allow0RTT: testcase == "zerortt",
|
||||
Tracer: utils.NewQLOGConnectionTracer,
|
||||
}
|
||||
cert, err := tls.LoadX509KeyPair("/certs/cert.pem", "/certs/priv.key")
|
||||
if err != nil {
|
||||
|
@ -54,11 +52,11 @@ func main() {
|
|||
|
||||
switch testcase {
|
||||
case "versionnegotiation", "handshake", "retry", "transfer", "resumption", "multiconnect", "zerortt":
|
||||
err = runHTTP09Server(quicConf)
|
||||
err = runHTTP09Server(quicConf, testcase == "retry")
|
||||
case "chacha20":
|
||||
reset := qtls.SetCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256)
|
||||
defer reset()
|
||||
err = runHTTP09Server(quicConf)
|
||||
err = runHTTP09Server(quicConf, false)
|
||||
case "http3":
|
||||
err = runHTTP3Server(quicConf)
|
||||
default:
|
||||
|
@ -72,12 +70,13 @@ func main() {
|
|||
}
|
||||
}
|
||||
|
||||
func runHTTP09Server(quicConf *quic.Config) error {
|
||||
func runHTTP09Server(quicConf *quic.Config, forceRetry bool) error {
|
||||
server := http09.Server{
|
||||
Server: &http.Server{
|
||||
Addr: ":443",
|
||||
TLSConfig: tlsConf,
|
||||
},
|
||||
ForceRetry: forceRetry,
|
||||
QuicConfig: quicConf,
|
||||
}
|
||||
http.DefaultServeMux.Handle("/", http.FileServer(http.Dir("/www")))
|
||||
|
|
|
@ -617,7 +617,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
|||
}
|
||||
return nil
|
||||
}
|
||||
if token == nil && (s.config.RequireAddressValidation(p.remoteAddr) || numHandshakesUnvalidated >= int64(s.maxNumHandshakesUnvalidated)) {
|
||||
if token == nil && numHandshakesUnvalidated >= int64(s.maxNumHandshakesUnvalidated) {
|
||||
// Retry invalidates all 0-RTT packets sent.
|
||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||
select {
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"reflect"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
@ -138,7 +137,6 @@ var _ = Describe("Server", func() {
|
|||
Expect(server.config.Versions).To(Equal(protocol.SupportedVersions))
|
||||
Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
|
||||
Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
|
||||
Expect(server.config.RequireAddressValidation).ToNot(BeNil())
|
||||
Expect(server.config.KeepAlivePeriod).To(BeZero())
|
||||
// stop the listener
|
||||
Expect(ln.Close()).To(Succeed())
|
||||
|
@ -146,13 +144,11 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("setups with the right values", func() {
|
||||
supportedVersions := []protocol.VersionNumber{protocol.Version1}
|
||||
requireAddrVal := func(net.Addr) bool { return true }
|
||||
config := Config{
|
||||
Versions: supportedVersions,
|
||||
HandshakeIdleTimeout: 1337 * time.Hour,
|
||||
MaxIdleTimeout: 42 * time.Minute,
|
||||
KeepAlivePeriod: 5 * time.Second,
|
||||
RequireAddressValidation: requireAddrVal,
|
||||
Versions: supportedVersions,
|
||||
HandshakeIdleTimeout: 1337 * time.Hour,
|
||||
MaxIdleTimeout: 42 * time.Minute,
|
||||
KeepAlivePeriod: 5 * time.Second,
|
||||
}
|
||||
ln, err := Listen(conn, tlsConf, &config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -161,7 +157,6 @@ var _ = Describe("Server", func() {
|
|||
Expect(server.config.Versions).To(Equal(supportedVersions))
|
||||
Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour))
|
||||
Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute))
|
||||
Expect(reflect.ValueOf(server.config.RequireAddressValidation)).To(Equal(reflect.ValueOf(requireAddrVal)))
|
||||
Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second))
|
||||
// stop the listener
|
||||
Expect(ln.Close()).To(Succeed())
|
||||
|
@ -263,7 +258,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("creates a connection when the token is accepted", func() {
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
serv.maxNumHandshakesUnvalidated = 0
|
||||
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
retryToken, err := serv.tokenGenerator.NewRetryToken(
|
||||
raddr,
|
||||
|
@ -441,7 +436,7 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("replies with a Retry packet, if a token is required", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
serv.maxNumHandshakesUnvalidated = 0
|
||||
hdr := &wire.Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
|
||||
|
@ -846,7 +841,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
serv.maxNumHandshakesUnvalidated = 0
|
||||
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
hdr := &wire.Header{
|
||||
|
@ -882,7 +877,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("sends an INVALID_TOKEN error, if an expired retry token is received", func() {
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
serv.maxNumHandshakesUnvalidated = 0
|
||||
serv.config.HandshakeIdleTimeout = time.Millisecond / 2 // the maximum retry token age is equivalent to the handshake timeout
|
||||
Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond))
|
||||
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
|
@ -920,7 +915,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() {
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
serv.maxNumHandshakesUnvalidated = 0
|
||||
token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
hdr := &wire.Header{
|
||||
|
@ -949,7 +944,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() {
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
serv.maxNumHandshakesUnvalidated = 0
|
||||
serv.maxTokenAge = time.Millisecond
|
||||
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
token, err := serv.tokenGenerator.NewToken(raddr)
|
||||
|
@ -978,7 +973,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
serv.maxNumHandshakesUnvalidated = 0
|
||||
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
hdr := &wire.Header{
|
||||
|
@ -1086,7 +1081,7 @@ var _ = Describe("Server", func() {
|
|||
conn := NewMockQUICConn(mockCtrl)
|
||||
|
||||
conf := &Config{MaxIncomingStreams: 1234}
|
||||
serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }})
|
||||
serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }})
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
@ -1139,7 +1134,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("rejects a connection attempt when GetConfigClient returns an error", func() {
|
||||
serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }})
|
||||
serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }})
|
||||
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
|
|
|
@ -179,7 +179,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
|
|||
if t.server != nil {
|
||||
return nil, errListenerAlreadySet
|
||||
}
|
||||
conf = populateServerConfig(conf)
|
||||
conf = populateConfig(conf)
|
||||
if err := t.init(false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue