sync: quic-go 0.42.0

Signed-off-by: Gaukas Wang <i@gaukas.wang>
This commit is contained in:
Gaukas Wang 2024-04-23 22:34:55 -06:00
parent d40dde9b9b
commit 4973374ea5
No known key found for this signature in database
GPG key ID: 6F0DF52D710D8189
252 changed files with 13121 additions and 5437 deletions

View file

@ -5,13 +5,14 @@ import (
"crypto/rand"
"errors"
"net"
"reflect"
"sync"
"sync/atomic"
"time"
tls "github.com/refraction-networking/utls"
"golang.org/x/time/rate"
"github.com/refraction-networking/uquic/internal/handshake"
mocklogging "github.com/refraction-networking/uquic/internal/mocks/logging"
"github.com/refraction-networking/uquic/internal/protocol"
@ -50,6 +51,7 @@ var _ = Describe("Server", func() {
data = data[:len(data)+16]
sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n])
return receivedPacket{
rcvTime: time.Now(),
remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456},
data: data,
buffer: buf,
@ -84,6 +86,25 @@ var _ = Describe("Server", func() {
return hdr
}
checkConnectionCloseError := func(b []byte, origHdr *wire.Header, errorCode qerr.TransportErrorCode) {
replyHdr := parseHeader(b)
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID))
_, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version)
extHdr, err := unpackLongHeader(opener, replyHdr, b, origHdr.Version)
Expect(err).ToNot(HaveOccurred())
data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()])
Expect(err).ToNot(HaveOccurred())
_, f, err := wire.NewFrameParser(false).ParseNext(data, protocol.EncryptionInitial, origHdr.Version)
Expect(err).ToNot(HaveOccurred())
Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
ccf := f.(*wire.ConnectionCloseFrame)
Expect(ccf.IsApplicationError).To(BeFalse())
Expect(ccf.ErrorCode).To(BeEquivalentTo(errorCode))
Expect(ccf.ReasonPhrase).To(BeEmpty())
}
BeforeEach(func() {
conn = NewMockPacketConn(mockCtrl)
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
@ -92,9 +113,10 @@ var _ = Describe("Server", func() {
<-wait
return 0, nil, errors.New("done")
}).MaxTimes(1)
conn.EXPECT().SetReadDeadline(gomock.Any()).Do(func(time.Time) {
conn.EXPECT().SetReadDeadline(gomock.Any()).Do(func(time.Time) error {
close(wait)
conn.EXPECT().SetReadDeadline(time.Time{})
return nil
}).MaxTimes(1)
tlsConf = testdata.GetTLSConfig()
tlsConf.NextProtos = []string{"proto1"}
@ -107,8 +129,8 @@ var _ = Describe("Server", func() {
})
It("errors when the Config contains an invalid version", func() {
version := protocol.VersionNumber(0x1234)
_, err := Listen(nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}})
version := protocol.Version(0x1234)
_, err := Listen(nil, tlsConf, &Config{Versions: []protocol.Version{version}})
Expect(err).To(MatchError("invalid QUIC version: 0x1234"))
})
@ -119,21 +141,18 @@ var _ = Describe("Server", func() {
Expect(server.config.Versions).To(Equal(protocol.SupportedVersions))
Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
Expect(server.config.RequireAddressValidation).ToNot(BeNil())
Expect(server.config.KeepAlivePeriod).To(BeZero())
// stop the listener
Expect(ln.Close()).To(Succeed())
})
It("setups with the right values", func() {
supportedVersions := []protocol.VersionNumber{protocol.Version1}
requireAddrVal := func(net.Addr) bool { return true }
supportedVersions := []protocol.Version{protocol.Version1}
config := Config{
Versions: supportedVersions,
HandshakeIdleTimeout: 1337 * time.Hour,
MaxIdleTimeout: 42 * time.Minute,
KeepAlivePeriod: 5 * time.Second,
RequireAddressValidation: requireAddrVal,
Versions: supportedVersions,
HandshakeIdleTimeout: 1337 * time.Hour,
MaxIdleTimeout: 42 * time.Minute,
KeepAlivePeriod: 5 * time.Second,
}
ln, err := Listen(conn, tlsConf, &config)
Expect(err).ToNot(HaveOccurred())
@ -142,7 +161,6 @@ var _ = Describe("Server", func() {
Expect(server.config.Versions).To(Equal(supportedVersions))
Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour))
Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute))
Expect(reflect.ValueOf(server.config.RequireAddressValidation)).To(Equal(reflect.ValueOf(requireAddrVal)))
Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second))
// stop the listener
Expect(ln.Close()).To(Succeed())
@ -189,6 +207,7 @@ var _ = Describe("Server", func() {
})
AfterEach(func() {
tracer.EXPECT().Close()
tr.Close()
})
@ -244,7 +263,7 @@ var _ = Describe("Server", func() {
})
It("creates a connection when the token is accepted", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
serv.verifySourceAddress = func(net.Addr) bool { return true }
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
retryToken, err := serv.tokenGenerator.NewRetryToken(
raddr,
@ -267,17 +286,6 @@ var _ = Describe("Server", func() {
rand.Read(token[:])
var newConnID protocol.ConnectionID
phm.EXPECT().Get(connID)
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
newConnID = c
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
newConnID = c
return token
})
_, ok := fn()
return ok
})
conn := NewMockQUICConn(mockCtrl)
serv.newConn = func(
_ sendConn,
@ -296,7 +304,7 @@ var _ = Describe("Server", func() {
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
_ protocol.Version,
) quicConn {
Expect(origDestConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde})))
Expect(*retrySrcConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad})))
@ -305,14 +313,20 @@ var _ = Describe("Server", func() {
// make sure we're using a server-generated connection ID
Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
Expect(srcConnID).To(Equal(newConnID))
newConnID = srcConnID
Expect(tokenP).To(Equal(token))
conn.EXPECT().handlePacket(p)
conn.EXPECT().run().Do(func() { close(run) })
conn.EXPECT().run().Do(func() error { close(run); return nil })
conn.EXPECT().Context().Return(context.Background())
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
return conn
}
phm.EXPECT().Get(connID)
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token)
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, cid protocol.ConnectionID, h packetHandler) bool {
Expect(cid).To(Equal(newConnID))
return true
})
done := make(chan struct{})
go func() {
@ -326,6 +340,8 @@ var _ = Describe("Server", func() {
// make sure we're using a server-generated connection ID
Eventually(run).Should(BeClosed())
Eventually(done).Should(BeClosed())
// shutdown
conn.EXPECT().closeWithTransportError(gomock.Any())
})
It("sends a Version Negotiation Packet for unsupported versions", func() {
@ -339,7 +355,7 @@ var _ = Describe("Server", func() {
}, make([]byte, protocol.MinUnknownVersionPacketSize))
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
packet.remoteAddr = raddr
tracer.EXPECT().SentVersionNegotiationPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, src, dest protocol.ArbitraryLenConnectionID, _ []protocol.VersionNumber) {
tracer.EXPECT().SentVersionNegotiationPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, src, dest protocol.ArbitraryLenConnectionID, _ []protocol.Version) {
Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes())))
Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes())))
})
@ -351,7 +367,7 @@ var _ = Describe("Server", func() {
Expect(err).ToNot(HaveOccurred())
Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes())))
Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes())))
Expect(versions).ToNot(ContainElement(protocol.VersionNumber(0x42)))
Expect(versions).ToNot(ContainElement(protocol.Version(0x42)))
return len(b), nil
})
serv.handlePacket(packet)
@ -371,7 +387,6 @@ var _ = Describe("Server", func() {
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
packet.remoteAddr = raddr
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), raddr).Do(func() { close(done) }).Times(0)
serv.handlePacket(packet)
Consistently(done, 50*time.Millisecond).ShouldNot(BeClosed())
})
@ -380,7 +395,7 @@ var _ = Describe("Server", func() {
data := wire.ComposeVersionNegotiation(
protocol.ArbitraryLenConnectionID{1, 2, 3, 4},
protocol.ArbitraryLenConnectionID{4, 3, 2, 1},
[]protocol.VersionNumber{1, 2, 3},
[]protocol.Version{1, 2, 3},
)
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
done := make(chan struct{})
@ -421,7 +436,13 @@ var _ = Describe("Server", func() {
It("replies with a Retry packet, if a token is required", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
var called bool
serv.verifySourceAddress = func(addr net.Addr) bool {
Expect(addr).To(Equal(raddr))
called = true
return true
}
hdr := &wire.Header{
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
@ -429,7 +450,6 @@ var _ = Describe("Server", func() {
Version: protocol.Version1,
}
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
packet.remoteAddr = raddr
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) {
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
@ -451,6 +471,7 @@ var _ = Describe("Server", func() {
phm.EXPECT().Get(connID)
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
Expect(called).To(BeTrue())
})
It("creates a connection, if no token is required", func() {
@ -467,19 +488,6 @@ var _ = Describe("Server", func() {
rand.Read(token[:])
var newConnID protocol.ConnectionID
gomock.InOrder(
phm.EXPECT().Get(connID),
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
newConnID = c
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
newConnID = c
return token
})
_, ok := fn()
return ok
}),
)
conn := NewMockQUICConn(mockCtrl)
serv.newConn = func(
_ sendConn,
@ -498,7 +506,7 @@ var _ = Describe("Server", func() {
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
_ protocol.Version,
) quicConn {
Expect(origDestConnID).To(Equal(hdr.DestConnectionID))
Expect(retrySrcConnID).To(BeNil())
@ -507,14 +515,22 @@ var _ = Describe("Server", func() {
// make sure we're using a server-generated connection ID
Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
Expect(srcConnID).To(Equal(newConnID))
newConnID = srcConnID
Expect(tokenP).To(Equal(token))
conn.EXPECT().handlePacket(p)
conn.EXPECT().run().Do(func() { close(run) })
conn.EXPECT().run().Do(func() error { close(run); return nil })
conn.EXPECT().Context().Return(context.Background())
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
return conn
}
gomock.InOrder(
phm.EXPECT().Get(connID),
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token),
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, h packetHandler) bool {
Expect(c).To(Equal(newConnID))
return true
}),
)
done := make(chan struct{})
go func() {
@ -528,18 +544,19 @@ var _ = Describe("Server", func() {
// make sure we're using a server-generated connection ID
Eventually(run).Should(BeClosed())
Eventually(done).Should(BeClosed())
// shutdown
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1)
})
It("drops packets if the receive queue is full", func() {
serv.verifySourceAddress = func(net.Addr) bool { return false }
phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
_, ok := fn()
return ok
}).AnyTimes()
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
acceptConn := make(chan struct{})
var counter uint32 // to be used as an atomic, so we query it in Eventually
var counter atomic.Uint32
serv.newConn = func(
_ sendConn,
runner connRunner,
@ -557,15 +574,17 @@ var _ = Describe("Server", func() {
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
_ protocol.Version,
) quicConn {
<-acceptConn
atomic.AddUint32(&counter, 1)
counter.Add(1)
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1)
conn.EXPECT().run().MaxTimes(1)
conn.EXPECT().Context().Return(context.Background()).MaxTimes(1)
conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1)
// shutdown
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1)
return conn
}
@ -585,14 +604,14 @@ var _ = Describe("Server", func() {
close(acceptConn)
Eventually(
func() uint32 { return atomic.LoadUint32(&counter) },
func() uint32 { return counter.Load() },
scaleDuration(100*time.Millisecond),
).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
Consistently(func() uint32 { return atomic.LoadUint32(&counter) }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
Consistently(func() uint32 { return counter.Load() }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
})
It("only creates a single connection for a duplicate Initial", func() {
var createdConn bool
done := make(chan struct{})
serv.newConn = func(
_ sendConn,
runner connRunner,
@ -610,25 +629,38 @@ var _ = Describe("Server", func() {
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
_ protocol.Version,
) quicConn {
createdConn = true
return NewMockQUICConn(mockCtrl)
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().handlePacket(gomock.Any())
conn.EXPECT().closeWithTransportError(qerr.ConnectionRefused).Do(func(qerr.TransportErrorCode) {
close(done)
})
return conn
}
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
p := getInitial(connID)
phm.EXPECT().Get(connID)
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false)
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func([]byte, net.Addr) { close(done) })
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false) // connection ID collision
Expect(serv.handlePacketImpl(p)).To(BeTrue())
Expect(createdConn).To(BeFalse())
Eventually(done).Should(BeClosed())
})
It("rejects new connection attempts if the accept queue is full", func() {
It("limits the number of unvalidated handshakes", func() {
const limit = 3
limiter := rate.NewLimiter(0, limit)
serv.verifySourceAddress = func(net.Addr) bool { return !limiter.Allow() }
phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
connChan := make(chan *MockQUICConn, 1)
var wg sync.WaitGroup
wg.Add(limit)
done := make(chan struct{})
serv.newConn = func(
_ sendConn,
runner connRunner,
@ -646,63 +678,53 @@ var _ = Describe("Server", func() {
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
_ protocol.Version,
) quicConn {
conn := NewMockQUICConn(mockCtrl)
conn := <-connChan
conn.EXPECT().handlePacket(gomock.Any())
conn.EXPECT().run()
conn.EXPECT().Context().Return(context.Background())
c := make(chan struct{})
close(c)
conn.EXPECT().HandshakeComplete().Return(c)
conn.EXPECT().HandshakeComplete().DoAndReturn(func() <-chan struct{} { wg.Done(); return done })
return conn
}
phm.EXPECT().Get(gomock.Any()).Times(protocol.MaxAcceptQueueSize + 1)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
_, ok := fn()
return ok
}).Times(protocol.MaxAcceptQueueSize)
var wg sync.WaitGroup
wg.Add(protocol.MaxAcceptQueueSize)
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
go func() {
defer GinkgoRecover()
defer wg.Done()
serv.handlePacket(getInitialWithRandomDestConnID())
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
}()
// Initiate the maximum number of allowed connection attempts.
for i := 0; i < limit; i++ {
conn := NewMockQUICConn(mockCtrl)
connChan <- conn
serv.handlePacket(getInitialWithRandomDestConnID())
}
wg.Wait()
// Now initiate another connection attempt.
p := getInitialWithRandomDestConnID()
hdr, _, _, err := wire.ParsePacket(p.data)
Expect(err).ToNot(HaveOccurred())
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), p.remoteAddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
defer GinkgoRecover()
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
})
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer GinkgoRecover()
defer close(done)
rejectHdr := parseHeader(b)
Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
Expect(rejectHdr.Version).To(Equal(hdr.Version))
Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
hdr, _, _, err := wire.ParsePacket(b)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry))
return len(b), nil
})
serv.handlePacket(p)
Eventually(done).Should(BeClosed())
})
It("doesn't accept new connections if they were closed in the mean time", func() {
p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
ctx, cancel := context.WithCancel(context.Background())
connCreated := make(chan struct{})
conn := NewMockQUICConn(mockCtrl)
for i := 0; i < limit; i++ {
_, err := serv.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
}
wg.Wait()
})
})
Context("token validation", func() {
It("decodes the token from the token field", func() {
serv.newConn = func(
_ sendConn,
runner connRunner,
_ connRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
@ -717,67 +739,17 @@ var _ = Describe("Server", func() {
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
_ protocol.Version,
) quicConn {
conn.EXPECT().handlePacket(p)
conn.EXPECT().run()
conn.EXPECT().Context().Return(ctx)
c := make(chan struct{})
close(c)
conn.EXPECT().HandshakeComplete().Return(c)
close(connCreated)
return conn
c := NewMockQUICConn(mockCtrl)
c.EXPECT().handlePacket(gomock.Any())
c.EXPECT().run()
c.EXPECT().HandshakeComplete()
ctx, cancel := context.WithCancel(context.Background())
cancel()
c.EXPECT().Context().Return(ctx)
return c
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
_, ok := fn()
return ok
})
serv.handlePacket(p)
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
Eventually(connCreated).Should(BeClosed())
cancel()
time.Sleep(scaleDuration(200 * time.Millisecond))
done := make(chan struct{})
go func() {
defer GinkgoRecover()
serv.Accept(context.Background())
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID
Expect(serv.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})
})
Context("token validation", func() {
checkInvalidToken := func(b []byte, origHdr *wire.Header) {
replyHdr := parseHeader(b)
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID))
_, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version)
extHdr, err := unpackLongHeader(opener, replyHdr, b, origHdr.Version)
Expect(err).ToNot(HaveOccurred())
data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()])
Expect(err).ToNot(HaveOccurred())
_, f, err := wire.NewFrameParser(false).ParseNext(data, protocol.EncryptionInitial, origHdr.Version)
Expect(err).ToNot(HaveOccurred())
Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
ccf := f.(*wire.ConnectionCloseFrame)
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
Expect(ccf.ReasonPhrase).To(BeEmpty())
}
It("decodes the token from the token field", func() {
raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337}
token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{})
Expect(err).ToNot(HaveOccurred())
@ -792,13 +764,18 @@ var _ = Describe("Server", func() {
done := make(chan struct{})
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() (packetHandler, bool)) { close(done) })
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, _ packetHandler) bool {
close(done)
return true
})
phm.EXPECT().Remove(gomock.Any()).AnyTimes()
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
serv.verifySourceAddress = func(net.Addr) bool { return true }
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
@ -825,7 +802,7 @@ var _ = Describe("Server", func() {
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer close(done)
checkInvalidToken(b, hdr)
checkConnectionCloseError(b, hdr, qerr.InvalidToken)
return len(b), nil
})
phm.EXPECT().Get(gomock.Any())
@ -834,7 +811,7 @@ var _ = Describe("Server", func() {
})
It("sends an INVALID_TOKEN error, if an expired retry token is received", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
serv.verifySourceAddress = func(net.Addr) bool { return true }
serv.config.HandshakeIdleTimeout = time.Millisecond / 2 // the maximum retry token age is equivalent to the handshake timeout
Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond))
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
@ -863,7 +840,7 @@ var _ = Describe("Server", func() {
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer close(done)
checkInvalidToken(b, hdr)
checkConnectionCloseError(b, hdr, qerr.InvalidToken)
return len(b), nil
})
phm.EXPECT().Get(gomock.Any())
@ -872,7 +849,7 @@ var _ = Describe("Server", func() {
})
It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
serv.verifySourceAddress = func(net.Addr) bool { return true }
token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337})
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
@ -901,7 +878,7 @@ var _ = Describe("Server", func() {
})
It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
serv.verifySourceAddress = func(net.Addr) bool { return true }
serv.maxTokenAge = time.Millisecond
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
token, err := serv.tokenGenerator.NewToken(raddr)
@ -930,7 +907,6 @@ var _ = Describe("Server", func() {
})
It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
@ -954,30 +930,66 @@ var _ = Describe("Server", func() {
})
Context("accepting connections", func() {
It("returns Accept when an error occurs", func() {
testErr := errors.New("test err")
It("returns Accept when closed", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := serv.Accept(context.Background())
Expect(err).To(MatchError(testErr))
Expect(err).To(MatchError(ErrServerClosed))
close(done)
}()
serv.setCloseError(testErr)
serv.Close()
Eventually(done).Should(BeClosed())
serv.onClose() // shutdown
})
It("returns immediately, if an error occurred before", func() {
testErr := errors.New("test err")
serv.setCloseError(testErr)
serv.Close()
for i := 0; i < 3; i++ {
_, err := serv.Accept(context.Background())
Expect(err).To(MatchError(testErr))
Expect(err).To(MatchError(ErrServerClosed))
}
serv.onClose() // shutdown
})
PIt("closes connection that are still handshaking after Close", func() {
serv.Close()
destroyed := make(chan struct{})
serv.newConn = func(
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
conf *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.Version,
) quicConn {
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().handlePacket(gomock.Any())
conn.EXPECT().closeWithTransportError(ConnectionRefused).Do(func(TransportErrorCode) { close(destroyed) })
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
conn.EXPECT().run().MaxTimes(1)
conn.EXPECT().Context().Return(context.Background())
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
serv.handleInitialImpl(
receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
)
Eventually(destroyed).Should(BeClosed())
})
It("returns when the context is canceled", func() {
@ -999,7 +1011,7 @@ var _ = Describe("Server", func() {
conn := NewMockQUICConn(mockCtrl)
conf := &Config{MaxIncomingStreams: 1234}
serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }})
serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }})
done := make(chan struct{})
go func() {
defer GinkgoRecover()
@ -1027,21 +1039,18 @@ var _ = Describe("Server", func() {
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
_ protocol.Version,
) quicConn {
Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(1234))
conn.EXPECT().handlePacket(gomock.Any())
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
conn.EXPECT().run().Do(func() {})
conn.EXPECT().run()
conn.EXPECT().Context().Return(context.Background())
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
_, ok := fn()
return ok
})
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
serv.handleInitialImpl(
receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
@ -1052,13 +1061,9 @@ var _ = Describe("Server", func() {
})
It("rejects a connection attempt when GetConfigClient returns an error", func() {
serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }})
serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }})
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
_, ok := fn()
return ok
})
done := make(chan struct{})
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
@ -1104,20 +1109,17 @@ var _ = Describe("Server", func() {
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
_ protocol.Version,
) quicConn {
conn.EXPECT().handlePacket(gomock.Any())
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
conn.EXPECT().run().Do(func() {})
conn.EXPECT().run()
conn.EXPECT().Context().Return(context.Background())
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
_, ok := fn()
return ok
})
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
serv.handleInitialImpl(
receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
@ -1177,20 +1179,17 @@ var _ = Describe("Server", func() {
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
_ protocol.Version,
) quicConn {
conn.EXPECT().handlePacket(gomock.Any())
conn.EXPECT().run().Do(func() {})
conn.EXPECT().run()
conn.EXPECT().earlyConnReady().Return(ready)
conn.EXPECT().Context().Return(context.Background())
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
_, ok := fn()
return ok
})
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
serv.baseServer.handleInitialImpl(
receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
@ -1201,8 +1200,9 @@ var _ = Describe("Server", func() {
})
It("rejects new connection attempts if the accept queue is full", func() {
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
connChan := make(chan *MockQUICConn, 1)
var wg sync.WaitGroup // to make sure the test fully completes
wg.Add(protocol.MaxAcceptQueueSize + 1)
serv.baseServer.newConn = func(
_ sendConn,
runner connRunner,
@ -1220,11 +1220,12 @@ var _ = Describe("Server", func() {
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
_ protocol.Version,
) quicConn {
defer wg.Done()
ready := make(chan struct{})
close(ready)
conn := NewMockQUICConn(mockCtrl)
conn := <-connChan
conn.EXPECT().handlePacket(gomock.Any())
conn.EXPECT().run()
conn.EXPECT().earlyConnReady().Return(ready)
@ -1233,33 +1234,23 @@ var _ = Describe("Server", func() {
}
phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
_, ok := fn()
return ok
}).Times(protocol.MaxAcceptQueueSize)
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).Times(protocol.MaxAcceptQueueSize)
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
conn := NewMockQUICConn(mockCtrl)
connChan <- conn
serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
}
Eventually(func() int32 { return atomic.LoadInt32(&serv.baseServer.connQueueLen) }).Should(BeEquivalentTo(protocol.MaxAcceptQueueSize))
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
Eventually(serv.baseServer.connQueue).Should(HaveLen(protocol.MaxAcceptQueueSize))
p := getInitialWithRandomDestConnID()
hdr := parseHeader(p.data)
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), senderAddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer close(done)
rejectHdr := parseHeader(b)
Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
Expect(rejectHdr.Version).To(Equal(hdr.Version))
Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
return len(b), nil
})
serv.baseServer.handlePacket(p)
Eventually(done).Should(BeClosed())
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().closeWithTransportError(ConnectionRefused)
connChan <- conn
serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
wg.Wait()
})
It("doesn't accept new connections if they were closed in the mean time", func() {
@ -1284,7 +1275,7 @@ var _ = Describe("Server", func() {
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
_ protocol.Version,
) quicConn {
conn.EXPECT().handlePacket(p)
conn.EXPECT().run()
@ -1295,11 +1286,8 @@ var _ = Describe("Server", func() {
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
_, ok := fn()
return ok
})
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
serv.baseServer.handlePacket(p)
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
@ -1316,7 +1304,6 @@ var _ = Describe("Server", func() {
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID
Expect(serv.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})
@ -1342,8 +1329,8 @@ var _ = Describe("Server", func() {
})
AfterEach(func() {
phm.EXPECT().CloseServer().MaxTimes(1)
tr.Close()
tracer.EXPECT().Close()
Expect(tr.Close()).To(Succeed())
})
It("passes packets to existing connections", func() {
@ -1410,7 +1397,7 @@ var _ = Describe("Server", func() {
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
_ protocol.Version,
) quicConn {
conn := NewMockQUICConn(mockCtrl)
var calls []any
@ -1423,15 +1410,14 @@ var _ = Describe("Server", func() {
conn.EXPECT().earlyConnReady()
conn.EXPECT().Context().Return(context.Background())
close(called)
// shutdown
conn.EXPECT().closeWithTransportError(gomock.Any())
return conn
}
phm.EXPECT().Get(connID)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
_, ok := fn()
return ok
})
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
serv.handlePacket(initial)
Eventually(called).Should(BeClosed())
})