remove the RequireAddressValidation callback from the Config (#4253)

This commit is contained in:
Marten Seemann 2024-01-22 21:24:07 -08:00 committed by GitHub
parent 892851eb8c
commit a2cf43d75c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 127 additions and 128 deletions

View file

@ -2,7 +2,6 @@ package quic
import ( import (
"fmt" "fmt"
"net"
"time" "time"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
@ -49,16 +48,6 @@ func validateConfig(config *Config) error {
return nil return nil
} }
// populateServerConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateServerConfig(config *Config) *Config {
config = populateConfig(config)
if config.RequireAddressValidation == nil {
config.RequireAddressValidation = func(net.Addr) bool { return false }
}
return config
}
// populateConfig populates fields in the quic.Config with their default values, if none are set // populateConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil // it may be called with nil
func populateConfig(config *Config) *Config { func populateConfig(config *Config) *Config {
@ -111,7 +100,6 @@ func populateConfig(config *Config) *Config {
Versions: versions, Versions: versions,
HandshakeIdleTimeout: handshakeIdleTimeout, HandshakeIdleTimeout: handshakeIdleTimeout,
MaxIdleTimeout: idleTimeout, MaxIdleTimeout: idleTimeout,
RequireAddressValidation: config.RequireAddressValidation,
KeepAlivePeriod: config.KeepAlivePeriod, KeepAlivePeriod: config.KeepAlivePeriod,
InitialStreamReceiveWindow: initialStreamReceiveWindow, InitialStreamReceiveWindow: initialStreamReceiveWindow,
MaxStreamReceiveWindow: maxStreamReceiveWindow, MaxStreamReceiveWindow: maxStreamReceiveWindow,

View file

@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net"
"reflect" "reflect"
"time" "time"
@ -23,7 +22,7 @@ var _ = Describe("Config", func() {
}) })
It("validates a config with normal values", func() { It("validates a config with normal values", func() {
conf := populateServerConfig(&Config{ conf := populateConfig(&Config{
MaxIncomingStreams: 5, MaxIncomingStreams: 5,
MaxStreamReceiveWindow: 10, MaxStreamReceiveWindow: 10,
}) })
@ -118,19 +117,16 @@ var _ = Describe("Config", func() {
Context("cloning", func() { Context("cloning", func() {
It("clones function fields", func() { It("clones function fields", func() {
var calledAddrValidation, calledAllowConnectionWindowIncrease, calledTracer bool var calledAllowConnectionWindowIncrease, calledTracer bool
c1 := &Config{ c1 := &Config{
GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") }, GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") },
AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true }, AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true },
RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true },
Tracer: func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer { Tracer: func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer {
calledTracer = true calledTracer = true
return nil return nil
}, },
} }
c2 := c1.Clone() c2 := c1.Clone()
c2.RequireAddressValidation(&net.UDPAddr{})
Expect(calledAddrValidation).To(BeTrue())
c2.AllowConnectionWindowIncrease(nil, 1234) c2.AllowConnectionWindowIncrease(nil, 1234)
Expect(calledAllowConnectionWindowIncrease).To(BeTrue()) Expect(calledAllowConnectionWindowIncrease).To(BeTrue())
_, err := c2.GetConfigForClient(&ClientHelloInfo{}) _, err := c2.GetConfigForClient(&ClientHelloInfo{})
@ -145,29 +141,15 @@ var _ = Describe("Config", func() {
}) })
It("returns a copy", func() { It("returns a copy", func() {
c1 := &Config{ c1 := &Config{MaxIncomingStreams: 100}
MaxIncomingStreams: 100,
RequireAddressValidation: func(net.Addr) bool { return true },
}
c2 := c1.Clone() c2 := c1.Clone()
c2.MaxIncomingStreams = 200 c2.MaxIncomingStreams = 200
c2.RequireAddressValidation = func(net.Addr) bool { return false }
Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100)) Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100))
Expect(c1.RequireAddressValidation(&net.UDPAddr{})).To(BeTrue())
}) })
}) })
Context("populating", func() { Context("populating", func() {
It("populates function fields", func() {
var calledAddrValidation bool
c1 := &Config{}
c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true }
c2 := populateConfig(c1)
c2.RequireAddressValidation(&net.UDPAddr{})
Expect(calledAddrValidation).To(BeTrue())
})
It("copies non-function fields", func() { It("copies non-function fields", func() {
c := configWithNonZeroNonFunctionFields() c := configWithNonZeroNonFunctionFields()
Expect(populateConfig(c)).To(Equal(c)) Expect(populateConfig(c)).To(Equal(c))
@ -186,10 +168,5 @@ var _ = Describe("Config", func() {
Expect(c.DisablePathMTUDiscovery).To(BeFalse()) Expect(c.DisablePathMTUDiscovery).To(BeFalse())
Expect(c.GetConfigForClient).To(BeNil()) Expect(c.GetConfigForClient).To(BeNil())
}) })
It("populates empty fields with default values, for the server", func() {
c := populateServerConfig(&Config{})
Expect(c.RequireAddressValidation).ToNot(BeNil())
})
}) })
}) })

View file

@ -118,7 +118,7 @@ var _ = Describe("Connection", func() {
srcConnID, srcConnID,
&protocol.DefaultConnectionIDGenerator{}, &protocol.DefaultConnectionIDGenerator{},
protocol.StatelessResetToken{}, protocol.StatelessResetToken{},
populateServerConfig(&Config{DisablePathMTUDiscovery: true}), populateConfig(&Config{DisablePathMTUDiscovery: true}),
&tls.Config{}, &tls.Config{},
tokenGenerator, tokenGenerator,
false, false,

View file

@ -26,23 +26,17 @@ var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.Di
type applicationProtocol struct { type applicationProtocol struct {
name string name string
run func() run func(ln *quic.Listener, port int)
} }
var _ = Describe("Handshake drop tests", func() { var _ = Describe("Handshake drop tests", func() {
var (
proxy *quicproxy.QuicProxy
ln *quic.Listener
)
data := GeneratePRData(5000) data := GeneratePRData(5000)
const timeout = 2 * time.Minute const timeout = 2 * time.Minute
startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool) { startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool) (ln *quic.Listener, proxyPort int, closeFn func()) {
conf := getQuicConfig(&quic.Config{ conf := getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout, MaxIdleTimeout: timeout,
HandshakeIdleTimeout: timeout, HandshakeIdleTimeout: timeout,
RequireAddressValidation: func(net.Addr) bool { return doRetry },
}) })
var tlsConf *tls.Config var tlsConf *tls.Config
if longCertChain { if longCertChain {
@ -50,11 +44,18 @@ var _ = Describe("Handshake drop tests", func() {
} else { } else {
tlsConf = getTLSConfig() tlsConf = getTLSConfig()
} }
var err error laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
ln, err = quic.ListenAddr("localhost:0", tlsConf, conf) Expect(err).ToNot(HaveOccurred())
conn, err := net.ListenUDP("udp", laddr)
Expect(err).ToNot(HaveOccurred())
tr := &quic.Transport{Conn: conn}
if doRetry {
tr.MaxUnvalidatedHandshakes = -1
}
ln, err = tr.Listen(tlsConf, conf)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
serverPort := ln.Addr().(*net.UDPAddr).Port serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DropPacket: dropCallback, DropPacket: dropCallback,
DelayPacket: func(dir quicproxy.Direction, packet []byte) time.Duration { DelayPacket: func(dir quicproxy.Direction, packet []byte) time.Duration {
@ -62,11 +63,18 @@ var _ = Describe("Handshake drop tests", func() {
}, },
}) })
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
return ln, proxy.LocalPort(), func() {
ln.Close()
tr.Close()
conn.Close()
proxy.Close()
}
} }
clientSpeaksFirst := &applicationProtocol{ clientSpeaksFirst := &applicationProtocol{
name: "client speaks first", name: "client speaks first",
run: func() { run: func(ln *quic.Listener, port int) {
serverConnChan := make(chan quic.Connection) serverConnChan := make(chan quic.Connection)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -82,7 +90,7 @@ var _ = Describe("Handshake drop tests", func() {
}() }()
conn, err := quic.DialAddr( conn, err := quic.DialAddr(
context.Background(), context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()), fmt.Sprintf("localhost:%d", port),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfig(&quic.Config{ getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout, MaxIdleTimeout: timeout,
@ -105,7 +113,7 @@ var _ = Describe("Handshake drop tests", func() {
serverSpeaksFirst := &applicationProtocol{ serverSpeaksFirst := &applicationProtocol{
name: "server speaks first", name: "server speaks first",
run: func() { run: func(ln *quic.Listener, port int) {
serverConnChan := make(chan quic.Connection) serverConnChan := make(chan quic.Connection)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -120,7 +128,7 @@ var _ = Describe("Handshake drop tests", func() {
}() }()
conn, err := quic.DialAddr( conn, err := quic.DialAddr(
context.Background(), context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()), fmt.Sprintf("localhost:%d", port),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfig(&quic.Config{ getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout, MaxIdleTimeout: timeout,
@ -143,7 +151,7 @@ var _ = Describe("Handshake drop tests", func() {
nobodySpeaks := &applicationProtocol{ nobodySpeaks := &applicationProtocol{
name: "nobody speaks", name: "nobody speaks",
run: func() { run: func(ln *quic.Listener, port int) {
serverConnChan := make(chan quic.Connection) serverConnChan := make(chan quic.Connection)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -153,7 +161,7 @@ var _ = Describe("Handshake drop tests", func() {
}() }()
conn, err := quic.DialAddr( conn, err := quic.DialAddr(
context.Background(), context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()), fmt.Sprintf("localhost:%d", port),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfig(&quic.Config{ getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout, MaxIdleTimeout: timeout,
@ -169,11 +177,6 @@ var _ = Describe("Handshake drop tests", func() {
}, },
} }
AfterEach(func() {
Expect(ln.Close()).To(Succeed())
Expect(proxy.Close()).To(Succeed())
})
for _, d := range directions { for _, d := range directions {
direction := d direction := d
@ -195,7 +198,7 @@ var _ = Describe("Handshake drop tests", func() {
Context(app.name, func() { Context(app.name, func() {
It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() { It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() {
var incoming, outgoing atomic.Int32 var incoming, outgoing atomic.Int32
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
var p int32 var p int32
//nolint:exhaustive //nolint:exhaustive
switch d { switch d {
@ -206,12 +209,13 @@ var _ = Describe("Handshake drop tests", func() {
} }
return p == 1 && d.Is(direction) return p == 1 && d.Is(direction)
}, doRetry, longCertChain) }, doRetry, longCertChain)
app.run() defer closeFn()
app.run(ln, proxyPort)
}) })
It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() { It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() {
var incoming, outgoing atomic.Int32 var incoming, outgoing atomic.Int32
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
var p int32 var p int32
//nolint:exhaustive //nolint:exhaustive
switch d { switch d {
@ -222,7 +226,8 @@ var _ = Describe("Handshake drop tests", func() {
} }
return p == 2 && d.Is(direction) return p == 2 && d.Is(direction)
}, doRetry, longCertChain) }, doRetry, longCertChain)
app.run() defer closeFn()
app.run(ln, proxyPort)
}) })
It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() { It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() {
@ -230,7 +235,7 @@ var _ = Describe("Handshake drop tests", func() {
var mx sync.Mutex var mx sync.Mutex
var incoming, outgoing int var incoming, outgoing int
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
drop := mrand.Int63n(int64(3)) == 0 drop := mrand.Int63n(int64(3)) == 0
mx.Lock() mx.Lock()
@ -260,7 +265,8 @@ var _ = Describe("Handshake drop tests", func() {
} }
return drop return drop
}, doRetry, longCertChain) }, doRetry, longCertChain)
app.run() defer closeFn()
app.run(ln, proxyPort)
}) })
}) })
} }
@ -281,13 +287,14 @@ var _ = Describe("Handshake drop tests", func() {
uint64(27+31*(1000+mrand.Int63()/31)) % quicvarint.Max: b, uint64(27+31*(1000+mrand.Int63()/31)) % quicvarint.Max: b,
} }
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
if d == quicproxy.DirectionOutgoing { if d == quicproxy.DirectionOutgoing {
return false return false
} }
return mrand.Intn(3) == 0 return mrand.Intn(3) == 0
}, false, false) }, false, false)
clientSpeaksFirst.run() defer closeFn()
clientSpeaksFirst.run(ln, proxyPort)
}) })
} }
}) })

View file

@ -55,8 +55,17 @@ var _ = Describe("Handshake RTT tests", func() {
// 1 RTT for verifying the source address // 1 RTT for verifying the source address
// 1 RTT for the TLS handshake // 1 RTT for the TLS handshake
It("is forward-secure after 2 RTTs", func() { It("is forward-secure after 2 RTTs", func() {
serverConfig.RequireAddressValidation = func(net.Addr) bool { return true } laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig) Expect(err).ToNot(HaveOccurred())
udpConn, err := net.ListenUDP("udp", laddr)
Expect(err).ToNot(HaveOccurred())
defer udpConn.Close()
tr := &quic.Transport{
Conn: udpConn,
MaxUnvalidatedHandshakes: -1,
}
defer tr.Close()
ln, err := tr.Listen(serverTLSConfig, serverConfig)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer ln.Close() defer ln.Close()

View file

@ -701,14 +701,24 @@ var _ = Describe("Handshake tests", func() {
It("rejects invalid Retry token with the INVALID_TOKEN error", func() { It("rejects invalid Retry token with the INVALID_TOKEN error", func() {
const rtt = 10 * time.Millisecond const rtt = 10 * time.Millisecond
serverConfig.RequireAddressValidation = func(net.Addr) bool { return true }
// The validity period of the retry token is the handshake timeout, // The validity period of the retry token is the handshake timeout,
// which is twice the handshake idle timeout. // which is twice the handshake idle timeout.
// By setting the handshake timeout shorter than the RTT, the token will have expired by the time // By setting the handshake timeout shorter than the RTT, the token will have expired by the time
// it reaches the server. // it reaches the server.
serverConfig.HandshakeIdleTimeout = rtt / 5 serverConfig.HandshakeIdleTimeout = rtt / 5
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
udpConn, err := net.ListenUDP("udp", laddr)
Expect(err).ToNot(HaveOccurred())
defer udpConn.Close()
tr := &quic.Transport{
Conn: udpConn,
MaxUnvalidatedHandshakes: -1,
}
defer tr.Close()
server, err := tr.Listen(getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer server.Close() defer server.Close()

View file

@ -32,7 +32,7 @@ var _ = Describe("MITM test", func() {
serverConfig *quic.Config serverConfig *quic.Config
) )
startServerAndProxy := func(delayCb quicproxy.DelayCallback, dropCb quicproxy.DropCallback) (proxyPort int, closeFn func()) { startServerAndProxy := func(delayCb quicproxy.DelayCallback, dropCb quicproxy.DropCallback, forceAddressValidation bool) (proxyPort int, closeFn func()) {
addr, err := net.ResolveUDPAddr("udp", "localhost:0") addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
c, err := net.ListenUDP("udp", addr) c, err := net.ListenUDP("udp", addr)
@ -41,6 +41,9 @@ var _ = Describe("MITM test", func() {
Conn: c, Conn: c,
ConnectionIDLength: connIDLen, ConnectionIDLength: connIDLen,
} }
if forceAddressValidation {
serverTransport.MaxUnvalidatedHandshakes = -1
}
ln, err := serverTransport.Listen(getTLSConfig(), serverConfig) ln, err := serverTransport.Listen(getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
done := make(chan struct{}) done := make(chan struct{})
@ -153,7 +156,7 @@ var _ = Describe("MITM test", func() {
} }
runTest := func(delayCb quicproxy.DelayCallback) { runTest := func(delayCb quicproxy.DelayCallback) {
proxyPort, closeFn := startServerAndProxy(delayCb, nil) proxyPort, closeFn := startServerAndProxy(delayCb, nil, false)
defer closeFn() defer closeFn()
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -196,7 +199,7 @@ var _ = Describe("MITM test", func() {
}) })
runTest := func(dropCb quicproxy.DropCallback) { runTest := func(dropCb quicproxy.DropCallback) {
proxyPort, closeFn := startServerAndProxy(nil, dropCb) proxyPort, closeFn := startServerAndProxy(nil, dropCb, false)
defer closeFn() defer closeFn()
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -310,17 +313,16 @@ var _ = Describe("MITM test", func() {
const rtt = 20 * time.Millisecond const rtt = 20 * time.Millisecond
runTest := func(delayCb quicproxy.DelayCallback) (closeFn func(), err error) { runTest := func(proxyPort int) (closeFn func(), err error) {
proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil)
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, err = clientTransport.Dial( _, err = clientTransport.Dial(
context.Background(), context.Background(),
raddr, raddr,
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfig(&quic.Config{HandshakeIdleTimeout: 2 * time.Second}), getQuicConfig(&quic.Config{HandshakeIdleTimeout: scaleDuration(200 * time.Millisecond)}),
) )
return func() { clientTransport.Close(); serverCloseFn() }, err return func() { clientTransport.Close() }, err
} }
// fails immediately because client connection closes when it can't find compatible version // fails immediately because client connection closes when it can't find compatible version
@ -352,7 +354,9 @@ var _ = Describe("MITM test", func() {
} }
return rtt / 2 return rtt / 2
} }
closeFn, err := runTest(delayCb) proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false)
defer serverCloseFn()
closeFn, err := runTest(proxyPort)
defer closeFn() defer closeFn()
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
vnErr := &quic.VersionNegotiationError{} vnErr := &quic.VersionNegotiationError{}
@ -363,8 +367,7 @@ var _ = Describe("MITM test", func() {
// times out, because client doesn't accept subsequent real retry packets from server // times out, because client doesn't accept subsequent real retry packets from server
// as it has already accepted a retry. // as it has already accepted a retry.
// TODO: determine behavior when server does not send Retry packets // TODO: determine behavior when server does not send Retry packets
It("fails when a forged Retry packet with modified srcConnID is sent to client", func() { It("fails when a forged Retry packet with modified Source Connection ID is sent to client", func() {
serverConfig.RequireAddressValidation = func(net.Addr) bool { return true }
var initialPacketIntercepted bool var initialPacketIntercepted bool
done := make(chan struct{}) done := make(chan struct{})
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
@ -388,7 +391,9 @@ var _ = Describe("MITM test", func() {
} }
return rtt / 2 return rtt / 2
} }
closeFn, err := runTest(delayCb) proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, true)
defer serverCloseFn()
closeFn, err := runTest(proxyPort)
defer closeFn() defer closeFn()
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue()) Expect(err.(net.Error).Timeout()).To(BeTrue())
@ -418,7 +423,9 @@ var _ = Describe("MITM test", func() {
} }
return rtt return rtt
} }
closeFn, err := runTest(delayCb) proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false)
defer serverCloseFn()
closeFn, err := runTest(proxyPort)
defer closeFn() defer closeFn()
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue()) Expect(err.(net.Error).Timeout()).To(BeTrue())
@ -448,7 +455,9 @@ var _ = Describe("MITM test", func() {
} }
return rtt return rtt
} }
closeFn, err := runTest(delayCb) proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false)
defer serverCloseFn()
closeFn, err := runTest(proxyPort)
defer closeFn() defer closeFn()
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
var transportErr *quic.TransportError var transportErr *quic.TransportError

View file

@ -464,14 +464,19 @@ var _ = Describe("0-RTT", func() {
} }
counter, tracer := newPacketTracer() counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly( laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
"localhost:0", Expect(err).ToNot(HaveOccurred())
udpConn, err := net.ListenUDP("udp", laddr)
Expect(err).ToNot(HaveOccurred())
defer udpConn.Close()
tr := &quic.Transport{
Conn: udpConn,
MaxUnvalidatedHandshakes: -1,
}
defer tr.Close()
ln, err := tr.ListenEarly(
tlsConf, tlsConf,
getQuicConfig(&quic.Config{ getQuicConfig(&quic.Config{Allow0RTT: true, Tracer: newTracer(tracer)}),
RequireAddressValidation: func(net.Addr) bool { return true },
Allow0RTT: true,
Tracer: newTracer(tracer),
}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer ln.Close() defer ln.Close()

View file

@ -267,11 +267,6 @@ type Config struct {
// If the timeout is exceeded, the connection is closed. // If the timeout is exceeded, the connection is closed.
// If this value is zero, the timeout is set to 30 seconds. // If this value is zero, the timeout is set to 30 seconds.
MaxIdleTimeout time.Duration MaxIdleTimeout time.Duration
// RequireAddressValidation determines if a QUIC Retry packet is sent.
// This allows the server to verify the client's address, at the cost of increasing the handshake latency by 1 RTT.
// See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details.
// If not set, every client is forced to prove its remote address.
RequireAddressValidation func(net.Addr) bool
// The TokenStore stores tokens received from the server. // The TokenStore stores tokens received from the server.
// Tokens are used to skip address validation on future connection attempts. // Tokens are used to skip address validation on future connection attempts.
// The key used to store tokens is the ServerName from the tls.Config, if set // The key used to store tokens is the ServerName from the tls.Config, if set

View file

@ -37,6 +37,7 @@ func (w *responseWriter) WriteHeader(int) {}
type Server struct { type Server struct {
*http.Server *http.Server
ForceRetry bool
QuicConfig *quic.Config QuicConfig *quic.Config
mutex sync.Mutex mutex sync.Mutex
@ -68,7 +69,11 @@ func (s *Server) ListenAndServe() error {
tlsConf := s.TLSConfig.Clone() tlsConf := s.TLSConfig.Clone()
tlsConf.NextProtos = []string{h09alpn} tlsConf.NextProtos = []string{h09alpn}
ln, err := quic.ListenEarly(conn, tlsConf, s.QuicConfig) tr := quic.Transport{Conn: conn}
if s.ForceRetry {
tr.MaxUnvalidatedHandshakes = -1
}
ln, err := tr.ListenEarly(tlsConf, s.QuicConfig)
if err != nil { if err != nil {
return err return err
} }

View file

@ -4,7 +4,6 @@ import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"log" "log"
"net"
"net/http" "net/http"
"os" "os"
@ -38,9 +37,8 @@ func main() {
testcase := os.Getenv("TESTCASE") testcase := os.Getenv("TESTCASE")
quicConf := &quic.Config{ quicConf := &quic.Config{
RequireAddressValidation: func(net.Addr) bool { return testcase == "retry" }, Allow0RTT: testcase == "zerortt",
Allow0RTT: testcase == "zerortt", Tracer: utils.NewQLOGConnectionTracer,
Tracer: utils.NewQLOGConnectionTracer,
} }
cert, err := tls.LoadX509KeyPair("/certs/cert.pem", "/certs/priv.key") cert, err := tls.LoadX509KeyPair("/certs/cert.pem", "/certs/priv.key")
if err != nil { if err != nil {
@ -54,11 +52,11 @@ func main() {
switch testcase { switch testcase {
case "versionnegotiation", "handshake", "retry", "transfer", "resumption", "multiconnect", "zerortt": case "versionnegotiation", "handshake", "retry", "transfer", "resumption", "multiconnect", "zerortt":
err = runHTTP09Server(quicConf) err = runHTTP09Server(quicConf, testcase == "retry")
case "chacha20": case "chacha20":
reset := qtls.SetCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256) reset := qtls.SetCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256)
defer reset() defer reset()
err = runHTTP09Server(quicConf) err = runHTTP09Server(quicConf, false)
case "http3": case "http3":
err = runHTTP3Server(quicConf) err = runHTTP3Server(quicConf)
default: default:
@ -72,12 +70,13 @@ func main() {
} }
} }
func runHTTP09Server(quicConf *quic.Config) error { func runHTTP09Server(quicConf *quic.Config, forceRetry bool) error {
server := http09.Server{ server := http09.Server{
Server: &http.Server{ Server: &http.Server{
Addr: ":443", Addr: ":443",
TLSConfig: tlsConf, TLSConfig: tlsConf,
}, },
ForceRetry: forceRetry,
QuicConfig: quicConf, QuicConfig: quicConf,
} }
http.DefaultServeMux.Handle("/", http.FileServer(http.Dir("/www"))) http.DefaultServeMux.Handle("/", http.FileServer(http.Dir("/www")))

View file

@ -617,7 +617,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
} }
return nil return nil
} }
if token == nil && (s.config.RequireAddressValidation(p.remoteAddr) || numHandshakesUnvalidated >= int64(s.maxNumHandshakesUnvalidated)) { if token == nil && numHandshakesUnvalidated >= int64(s.maxNumHandshakesUnvalidated) {
// Retry invalidates all 0-RTT packets sent. // Retry invalidates all 0-RTT packets sent.
delete(s.zeroRTTQueues, hdr.DestConnectionID) delete(s.zeroRTTQueues, hdr.DestConnectionID)
select { select {

View file

@ -6,7 +6,6 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"net" "net"
"reflect"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -138,7 +137,6 @@ var _ = Describe("Server", func() {
Expect(server.config.Versions).To(Equal(protocol.SupportedVersions)) Expect(server.config.Versions).To(Equal(protocol.SupportedVersions))
Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
Expect(server.config.RequireAddressValidation).ToNot(BeNil())
Expect(server.config.KeepAlivePeriod).To(BeZero()) Expect(server.config.KeepAlivePeriod).To(BeZero())
// stop the listener // stop the listener
Expect(ln.Close()).To(Succeed()) Expect(ln.Close()).To(Succeed())
@ -146,13 +144,11 @@ var _ = Describe("Server", func() {
It("setups with the right values", func() { It("setups with the right values", func() {
supportedVersions := []protocol.VersionNumber{protocol.Version1} supportedVersions := []protocol.VersionNumber{protocol.Version1}
requireAddrVal := func(net.Addr) bool { return true }
config := Config{ config := Config{
Versions: supportedVersions, Versions: supportedVersions,
HandshakeIdleTimeout: 1337 * time.Hour, HandshakeIdleTimeout: 1337 * time.Hour,
MaxIdleTimeout: 42 * time.Minute, MaxIdleTimeout: 42 * time.Minute,
KeepAlivePeriod: 5 * time.Second, KeepAlivePeriod: 5 * time.Second,
RequireAddressValidation: requireAddrVal,
} }
ln, err := Listen(conn, tlsConf, &config) ln, err := Listen(conn, tlsConf, &config)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -161,7 +157,6 @@ var _ = Describe("Server", func() {
Expect(server.config.Versions).To(Equal(supportedVersions)) Expect(server.config.Versions).To(Equal(supportedVersions))
Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour)) Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour))
Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute)) Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute))
Expect(reflect.ValueOf(server.config.RequireAddressValidation)).To(Equal(reflect.ValueOf(requireAddrVal)))
Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second)) Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second))
// stop the listener // stop the listener
Expect(ln.Close()).To(Succeed()) Expect(ln.Close()).To(Succeed())
@ -263,7 +258,7 @@ var _ = Describe("Server", func() {
}) })
It("creates a connection when the token is accepted", func() { It("creates a connection when the token is accepted", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return true } serv.maxNumHandshakesUnvalidated = 0
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
retryToken, err := serv.tokenGenerator.NewRetryToken( retryToken, err := serv.tokenGenerator.NewRetryToken(
raddr, raddr,
@ -441,7 +436,7 @@ var _ = Describe("Server", func() {
It("replies with a Retry packet, if a token is required", func() { It("replies with a Retry packet, if a token is required", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
serv.config.RequireAddressValidation = func(net.Addr) bool { return true } serv.maxNumHandshakesUnvalidated = 0
hdr := &wire.Header{ hdr := &wire.Header{
Type: protocol.PacketTypeInitial, Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
@ -846,7 +841,7 @@ var _ = Describe("Server", func() {
}) })
It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() { It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return true } serv.maxNumHandshakesUnvalidated = 0
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{}) token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{ hdr := &wire.Header{
@ -882,7 +877,7 @@ var _ = Describe("Server", func() {
}) })
It("sends an INVALID_TOKEN error, if an expired retry token is received", func() { It("sends an INVALID_TOKEN error, if an expired retry token is received", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return true } serv.maxNumHandshakesUnvalidated = 0
serv.config.HandshakeIdleTimeout = time.Millisecond / 2 // the maximum retry token age is equivalent to the handshake timeout serv.config.HandshakeIdleTimeout = time.Millisecond / 2 // the maximum retry token age is equivalent to the handshake timeout
Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond)) Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond))
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
@ -920,7 +915,7 @@ var _ = Describe("Server", func() {
}) })
It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() { It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return true } serv.maxNumHandshakesUnvalidated = 0
token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}) token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{ hdr := &wire.Header{
@ -949,7 +944,7 @@ var _ = Describe("Server", func() {
}) })
It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() { It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return true } serv.maxNumHandshakesUnvalidated = 0
serv.maxTokenAge = time.Millisecond serv.maxTokenAge = time.Millisecond
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
token, err := serv.tokenGenerator.NewToken(raddr) token, err := serv.tokenGenerator.NewToken(raddr)
@ -978,7 +973,7 @@ var _ = Describe("Server", func() {
}) })
It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() { It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return true } serv.maxNumHandshakesUnvalidated = 0
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{}) token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{ hdr := &wire.Header{
@ -1086,7 +1081,7 @@ var _ = Describe("Server", func() {
conn := NewMockQUICConn(mockCtrl) conn := NewMockQUICConn(mockCtrl)
conf := &Config{MaxIncomingStreams: 1234} conf := &Config{MaxIncomingStreams: 1234}
serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }}) serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }})
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -1139,7 +1134,7 @@ var _ = Describe("Server", func() {
}) })
It("rejects a connection attempt when GetConfigClient returns an error", func() { It("rejects a connection attempt when GetConfigClient returns an error", func() {
serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }}) serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }})
phm.EXPECT().Get(gomock.Any()) phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {

View file

@ -179,7 +179,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
if t.server != nil { if t.server != nil {
return nil, errListenerAlreadySet return nil, errListenerAlreadySet
} }
conf = populateServerConfig(conf) conf = populateConfig(conf)
if err := t.init(false); err != nil { if err := t.init(false); err != nil {
return nil, err return nil, err
} }