use typed atomics in integration tests (#4120)

* use typed atomic in integration tests

* use an atomic.Bool in hotswap test
This commit is contained in:
Marten Seemann 2023-10-25 11:46:29 +07:00 committed by GitHub
parent 6239effc7a
commit 30f9c0139f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 118 additions and 120 deletions

View file

@ -31,7 +31,7 @@ var _ = Describe("Stream Cancellations", func() {
server, err = quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil)) server, err = quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
var canceledCounter int32 var canceledCounter atomic.Int32
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
var wg sync.WaitGroup var wg sync.WaitGroup
@ -50,18 +50,18 @@ var _ = Describe("Stream Cancellations", func() {
ErrorCode: quic.StreamErrorCode(str.StreamID()), ErrorCode: quic.StreamErrorCode(str.StreamID()),
Remote: true, Remote: true,
})) }))
atomic.AddInt32(&canceledCounter, 1) canceledCounter.Add(1)
return return
} }
if err := str.Close(); err != nil { if err := str.Close(); err != nil {
Expect(err).To(MatchError(fmt.Sprintf("close called for canceled stream %d", str.StreamID()))) Expect(err).To(MatchError(fmt.Sprintf("close called for canceled stream %d", str.StreamID())))
atomic.AddInt32(&canceledCounter, 1) canceledCounter.Add(1)
return return
} }
}() }()
} }
wg.Wait() wg.Wait()
numCanceledStreamsChan <- atomic.LoadInt32(&canceledCounter) numCanceledStreamsChan <- canceledCounter.Load()
}() }()
return numCanceledStreamsChan return numCanceledStreamsChan
} }
@ -80,7 +80,7 @@ var _ = Describe("Stream Cancellations", func() {
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
var canceledCounter int32 var canceledCounter atomic.Int32
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(numStreams) wg.Add(numStreams)
for i := 0; i < numStreams; i++ { for i := 0; i < numStreams; i++ {
@ -91,7 +91,7 @@ var _ = Describe("Stream Cancellations", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
// cancel around 2/3 of the streams // cancel around 2/3 of the streams
if rand.Int31()%3 != 0 { if rand.Int31()%3 != 0 {
atomic.AddInt32(&canceledCounter, 1) canceledCounter.Add(1)
resetErr := quic.StreamErrorCode(str.StreamID()) resetErr := quic.StreamErrorCode(str.StreamID())
str.CancelRead(resetErr) str.CancelRead(resetErr)
_, err := str.Read([]byte{0}) _, err := str.Read([]byte{0})
@ -113,7 +113,7 @@ var _ = Describe("Stream Cancellations", func() {
Eventually(serverCanceledCounterChan).Should(Receive(&serverCanceledCounter)) Eventually(serverCanceledCounterChan).Should(Receive(&serverCanceledCounter))
Expect(conn.CloseWithError(0, "")).To(Succeed()) Expect(conn.CloseWithError(0, "")).To(Succeed())
clientCanceledCounter := atomic.LoadInt32(&canceledCounter) clientCanceledCounter := canceledCounter.Load()
// The server will only count a stream as being reset if learns about the cancelation before it finished writing all data. // The server will only count a stream as being reset if learns about the cancelation before it finished writing all data.
Expect(clientCanceledCounter).To(BeNumerically(">=", serverCanceledCounter)) Expect(clientCanceledCounter).To(BeNumerically(">=", serverCanceledCounter))
fmt.Fprintf(GinkgoWriter, "Canceled reading on %d of %d streams.\n", clientCanceledCounter, numStreams) fmt.Fprintf(GinkgoWriter, "Canceled reading on %d of %d streams.\n", clientCanceledCounter, numStreams)
@ -132,7 +132,7 @@ var _ = Describe("Stream Cancellations", func() {
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
var canceledCounter int32 var canceledCounter atomic.Int32
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(numStreams) wg.Add(numStreams)
for i := 0; i < numStreams; i++ { for i := 0; i < numStreams; i++ {
@ -148,7 +148,7 @@ var _ = Describe("Stream Cancellations", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str.CancelRead(quic.StreamErrorCode(str.StreamID())) str.CancelRead(quic.StreamErrorCode(str.StreamID()))
Expect(data).To(Equal(PRData[:length])) Expect(data).To(Equal(PRData[:length]))
atomic.AddInt32(&canceledCounter, 1) canceledCounter.Add(1)
return return
} }
data, err := io.ReadAll(str) data, err := io.ReadAll(str)
@ -162,7 +162,7 @@ var _ = Describe("Stream Cancellations", func() {
Eventually(serverCanceledCounterChan).Should(Receive(&serverCanceledCounter)) Eventually(serverCanceledCounterChan).Should(Receive(&serverCanceledCounter))
Expect(conn.CloseWithError(0, "")).To(Succeed()) Expect(conn.CloseWithError(0, "")).To(Succeed())
clientCanceledCounter := atomic.LoadInt32(&canceledCounter) clientCanceledCounter := canceledCounter.Load()
// The server will only count a stream as being reset if learns about the cancelation before it finished writing all data. // The server will only count a stream as being reset if learns about the cancelation before it finished writing all data.
Expect(clientCanceledCounter).To(BeNumerically(">=", serverCanceledCounter)) Expect(clientCanceledCounter).To(BeNumerically(">=", serverCanceledCounter))
fmt.Fprintf(GinkgoWriter, "Canceled reading on %d of %d streams.\n", clientCanceledCounter, numStreams) fmt.Fprintf(GinkgoWriter, "Canceled reading on %d of %d streams.\n", clientCanceledCounter, numStreams)
@ -185,7 +185,7 @@ var _ = Describe("Stream Cancellations", func() {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(numStreams) wg.Add(numStreams)
var counter int32 var counter atomic.Int32
for i := 0; i < numStreams; i++ { for i := 0; i < numStreams; i++ {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -199,7 +199,7 @@ var _ = Describe("Stream Cancellations", func() {
defer close(done) defer close(done)
b := make([]byte, 32) b := make([]byte, 32)
if _, err := str.Read(b); err != nil { if _, err := str.Read(b); err != nil {
atomic.AddInt32(&counter, 1) counter.Add(1)
Expect(err).To(Equal(&quic.StreamError{ Expect(err).To(Equal(&quic.StreamError{
StreamID: str.StreamID(), StreamID: str.StreamID(),
ErrorCode: 1234, ErrorCode: 1234,
@ -214,7 +214,7 @@ var _ = Describe("Stream Cancellations", func() {
} }
wg.Wait() wg.Wait()
Expect(conn.CloseWithError(0, "")).To(Succeed()) Expect(conn.CloseWithError(0, "")).To(Succeed())
numCanceled := atomic.LoadInt32(&counter) numCanceled := counter.Load()
fmt.Fprintf(GinkgoWriter, "canceled %d out of %d streams", numCanceled, numStreams) fmt.Fprintf(GinkgoWriter, "canceled %d out of %d streams", numCanceled, numStreams)
Expect(numCanceled).ToNot(BeZero()) Expect(numCanceled).ToNot(BeZero())
Eventually(serverCanceledCounterChan).Should(Receive()) Eventually(serverCanceledCounterChan).Should(Receive())
@ -232,7 +232,7 @@ var _ = Describe("Stream Cancellations", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
var wg sync.WaitGroup var wg sync.WaitGroup
var counter int32 var counter atomic.Int32
wg.Add(numStreams) wg.Add(numStreams)
for i := 0; i < numStreams; i++ { for i := 0; i < numStreams; i++ {
go func() { go func() {
@ -242,7 +242,7 @@ var _ = Describe("Stream Cancellations", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(str) data, err := io.ReadAll(str)
if err != nil { if err != nil {
atomic.AddInt32(&counter, 1) counter.Add(1)
Expect(err).To(MatchError(&quic.StreamError{ Expect(err).To(MatchError(&quic.StreamError{
StreamID: str.StreamID(), StreamID: str.StreamID(),
ErrorCode: quic.StreamErrorCode(str.StreamID()), ErrorCode: quic.StreamErrorCode(str.StreamID()),
@ -254,7 +254,7 @@ var _ = Describe("Stream Cancellations", func() {
} }
wg.Wait() wg.Wait()
streamCount := atomic.LoadInt32(&counter) streamCount := counter.Load()
fmt.Fprintf(GinkgoWriter, "Canceled writing on %d of %d streams\n", streamCount, numStreams) fmt.Fprintf(GinkgoWriter, "Canceled writing on %d of %d streams\n", streamCount, numStreams)
Expect(streamCount).To(BeNumerically(">", numStreams/10)) Expect(streamCount).To(BeNumerically(">", numStreams/10))
Expect(numStreams - streamCount).To(BeNumerically(">", numStreams/10)) Expect(numStreams - streamCount).To(BeNumerically(">", numStreams/10))
@ -267,7 +267,7 @@ var _ = Describe("Stream Cancellations", func() {
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil) server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
var canceledCounter int32 var canceledCounter atomic.Int32
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
conn, err := server.Accept(context.Background()) conn, err := server.Accept(context.Background())
@ -280,7 +280,7 @@ var _ = Describe("Stream Cancellations", func() {
// cancel about 2/3 of the streams // cancel about 2/3 of the streams
if rand.Int31()%3 != 0 { if rand.Int31()%3 != 0 {
str.CancelWrite(quic.StreamErrorCode(str.StreamID())) str.CancelWrite(quic.StreamErrorCode(str.StreamID()))
atomic.AddInt32(&canceledCounter, 1) canceledCounter.Add(1)
return return
} }
_, err = str.Write(PRData) _, err = str.Write(PRData)
@ -291,14 +291,14 @@ var _ = Describe("Stream Cancellations", func() {
}() }()
clientCanceledStreams := runClient(server) clientCanceledStreams := runClient(server)
Expect(clientCanceledStreams).To(Equal(atomic.LoadInt32(&canceledCounter))) Expect(clientCanceledStreams).To(Equal(canceledCounter.Load()))
}) })
It("downloads when the server cancels some streams after sending some data", func() { It("downloads when the server cancels some streams after sending some data", func() {
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil) server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
var canceledCounter int32 var canceledCounter atomic.Int32
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
conn, err := server.Accept(context.Background()) conn, err := server.Accept(context.Background())
@ -314,7 +314,7 @@ var _ = Describe("Stream Cancellations", func() {
_, err = str.Write(PRData[:length]) _, err = str.Write(PRData[:length])
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str.CancelWrite(quic.StreamErrorCode(str.StreamID())) str.CancelWrite(quic.StreamErrorCode(str.StreamID()))
atomic.AddInt32(&canceledCounter, 1) canceledCounter.Add(1)
return return
} }
_, err = str.Write(PRData) _, err = str.Write(PRData)
@ -325,7 +325,7 @@ var _ = Describe("Stream Cancellations", func() {
}() }()
clientCanceledStreams := runClient(server) clientCanceledStreams := runClient(server)
Expect(clientCanceledStreams).To(Equal(atomic.LoadInt32(&canceledCounter))) Expect(clientCanceledStreams).To(Equal(canceledCounter.Load()))
}) })
}) })
@ -378,7 +378,7 @@ var _ = Describe("Stream Cancellations", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
var wg sync.WaitGroup var wg sync.WaitGroup
var counter int32 var counter atomic.Int32
wg.Add(numStreams) wg.Add(numStreams)
for i := 0; i < numStreams; i++ { for i := 0; i < numStreams; i++ {
go func() { go func() {
@ -399,13 +399,13 @@ var _ = Describe("Stream Cancellations", func() {
})) }))
return return
} }
atomic.AddInt32(&counter, 1) counter.Add(1)
Expect(data).To(Equal(PRData)) Expect(data).To(Equal(PRData))
}() }()
} }
wg.Wait() wg.Wait()
count := atomic.LoadInt32(&counter) count := counter.Load()
Expect(count).To(BeNumerically(">", numStreams/15)) Expect(count).To(BeNumerically(">", numStreams/15))
fmt.Fprintf(GinkgoWriter, "Successfully read from %d of %d streams.\n", count, numStreams) fmt.Fprintf(GinkgoWriter, "Successfully read from %d of %d streams.\n", count, numStreams)
@ -464,7 +464,7 @@ var _ = Describe("Stream Cancellations", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
var wg sync.WaitGroup var wg sync.WaitGroup
var counter int32 var counter atomic.Int32
wg.Add(numStreams) wg.Add(numStreams)
for i := 0; i < numStreams; i++ { for i := 0; i < numStreams; i++ {
go func() { go func() {
@ -495,14 +495,14 @@ var _ = Describe("Stream Cancellations", func() {
return return
} }
atomic.AddInt32(&counter, 1) counter.Add(1)
Expect(data).To(Equal(PRData)) Expect(data).To(Equal(PRData))
}() }()
} }
wg.Wait() wg.Wait()
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
count := atomic.LoadInt32(&counter) count := counter.Load()
Expect(count).To(BeNumerically(">", numStreams/15)) Expect(count).To(BeNumerically(">", numStreams/15))
fmt.Fprintf(GinkgoWriter, "Successfully read from %d of %d streams.\n", count, numStreams) fmt.Fprintf(GinkgoWriter, "Successfully read from %d of %d streams.\n", count, numStreams)
@ -543,7 +543,7 @@ var _ = Describe("Stream Cancellations", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
var numToAccept int var numToAccept int
var counter int32 var counter atomic.Int32
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(numStreams) wg.Add(numStreams)
for numToAccept < numStreams { for numToAccept < numStreams {
@ -561,7 +561,7 @@ var _ = Describe("Stream Cancellations", func() {
str, err := conn.AcceptUniStream(ctx) str, err := conn.AcceptUniStream(ctx)
if err != nil { if err != nil {
if err.Error() == "context canceled" { if err.Error() == "context canceled" {
atomic.AddInt32(&counter, 1) counter.Add(1)
} }
return return
} }
@ -573,7 +573,7 @@ var _ = Describe("Stream Cancellations", func() {
} }
wg.Wait() wg.Wait()
count := atomic.LoadInt32(&counter) count := counter.Load()
fmt.Fprintf(GinkgoWriter, "Canceled AcceptStream %d times\n", count) fmt.Fprintf(GinkgoWriter, "Canceled AcceptStream %d times\n", count)
Expect(count).To(BeNumerically(">", numStreams/2)) Expect(count).To(BeNumerically(">", numStreams/2))
Expect(conn.CloseWithError(0, "")).To(Succeed()) Expect(conn.CloseWithError(0, "")).To(Succeed())
@ -589,7 +589,7 @@ var _ = Describe("Stream Cancellations", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
msg := make(chan struct{}, 1) msg := make(chan struct{}, 1)
var numCanceled int32 var numCanceled atomic.Int32
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
defer close(msg) defer close(msg)
@ -603,7 +603,7 @@ var _ = Describe("Stream Cancellations", func() {
str, err := conn.OpenUniStreamSync(ctx) str, err := conn.OpenUniStreamSync(ctx)
if err != nil { if err != nil {
Expect(err).To(MatchError(context.DeadlineExceeded)) Expect(err).To(MatchError(context.DeadlineExceeded))
atomic.AddInt32(&numCanceled, 1) numCanceled.Add(1)
select { select {
case msg <- struct{}{}: case msg <- struct{}{}:
default: default:
@ -644,7 +644,7 @@ var _ = Describe("Stream Cancellations", func() {
} }
wg.Wait() wg.Wait()
count := atomic.LoadInt32(&numCanceled) count := numCanceled.Load()
fmt.Fprintf(GinkgoWriter, "Canceled OpenStreamSync %d times\n", count) fmt.Fprintf(GinkgoWriter, "Canceled OpenStreamSync %d times\n", count)
Expect(count).To(BeNumerically(">=", numStreams-maxIncomingStreams)) Expect(count).To(BeNumerically(">=", numStreams-maxIncomingStreams))
Expect(conn.CloseWithError(0, "")).To(Succeed()) Expect(conn.CloseWithError(0, "")).To(Succeed())

View file

@ -23,7 +23,7 @@ var _ = Describe("Datagram test", func() {
var ( var (
serverConn, clientConn *net.UDPConn serverConn, clientConn *net.UDPConn
dropped, total int32 dropped, total atomic.Int32
) )
startServerAndProxy := func(enableDatagram, expectDatagramSupport bool) (port int, closeFn func()) { startServerAndProxy := func(enableDatagram, expectDatagramSupport bool) (port int, closeFn func()) {
@ -81,9 +81,9 @@ var _ = Describe("Datagram test", func() {
} }
drop := mrand.Int()%10 == 0 drop := mrand.Int()%10 == 0
if drop { if drop {
atomic.AddInt32(&dropped, 1) dropped.Add(1)
} }
atomic.AddInt32(&total, 1) total.Add(1)
return drop return drop
}, },
}) })
@ -127,9 +127,9 @@ var _ = Describe("Datagram test", func() {
counter++ counter++
} }
numDropped := int(atomic.LoadInt32(&dropped)) numDropped := int(dropped.Load())
expVal := num - numDropped expVal := num - numDropped
fmt.Fprintf(GinkgoWriter, "Dropped %d out of %d packets.\n", numDropped, atomic.LoadInt32(&total)) fmt.Fprintf(GinkgoWriter, "Dropped %d out of %d packets.\n", numDropped, total.Load())
fmt.Fprintf(GinkgoWriter, "Received %d out of %d sent datagrams.\n", counter, num) fmt.Fprintf(GinkgoWriter, "Received %d out of %d sent datagrams.\n", counter, num)
Expect(counter).To(And( Expect(counter).To(And(
BeNumerically(">", expVal*9/10), BeNumerically(">", expVal*9/10),

View file

@ -67,14 +67,14 @@ var _ = Describe("Drop Tests", func() {
fmt.Fprintf(GinkgoWriter, "Dropping packets for %s, after a delay of %s\n", dropDuration, dropDelay) fmt.Fprintf(GinkgoWriter, "Dropping packets for %s, after a delay of %s\n", dropDuration, dropDelay)
startTime := time.Now() startTime := time.Now()
var numDroppedPackets int32 var numDroppedPackets atomic.Int32
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
if !d.Is(direction) { if !d.Is(direction) {
return false return false
} }
drop := time.Now().After(startTime.Add(dropDelay)) && time.Now().Before(startTime.Add(dropDelay).Add(dropDuration)) drop := time.Now().After(startTime.Add(dropDelay)) && time.Now().Before(startTime.Add(dropDelay).Add(dropDuration))
if drop { if drop {
atomic.AddInt32(&numDroppedPackets, 1) numDroppedPackets.Add(1)
} }
return drop return drop
}) })
@ -114,7 +114,7 @@ var _ = Describe("Drop Tests", func() {
Expect(b[0]).To(Equal(i)) Expect(b[0]).To(Equal(i))
} }
close(done) close(done)
numDropped := atomic.LoadInt32(&numDroppedPackets) numDropped := numDroppedPackets.Load()
fmt.Fprintf(GinkgoWriter, "Dropped %d packets.\n", numDropped) fmt.Fprintf(GinkgoWriter, "Dropped %d packets.\n", numDropped)
Expect(numDropped).To(BeNumerically(">", 0)) Expect(numDropped).To(BeNumerically(">", 0))
}) })

View file

@ -194,15 +194,15 @@ var _ = Describe("Handshake drop tests", func() {
Context(app.name, func() { Context(app.name, func() {
It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() { It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() {
var incoming, outgoing int32 var incoming, outgoing atomic.Int32
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
var p int32 var p int32
//nolint:exhaustive //nolint:exhaustive
switch d { switch d {
case quicproxy.DirectionIncoming: case quicproxy.DirectionIncoming:
p = atomic.AddInt32(&incoming, 1) p = incoming.Add(1)
case quicproxy.DirectionOutgoing: case quicproxy.DirectionOutgoing:
p = atomic.AddInt32(&outgoing, 1) p = outgoing.Add(1)
} }
return p == 1 && d.Is(direction) return p == 1 && d.Is(direction)
}, doRetry, longCertChain) }, doRetry, longCertChain)
@ -210,15 +210,15 @@ var _ = Describe("Handshake drop tests", func() {
}) })
It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() { It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() {
var incoming, outgoing int32 var incoming, outgoing atomic.Int32
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
var p int32 var p int32
//nolint:exhaustive //nolint:exhaustive
switch d { switch d {
case quicproxy.DirectionIncoming: case quicproxy.DirectionIncoming:
p = atomic.AddInt32(&incoming, 1) p = incoming.Add(1)
case quicproxy.DirectionOutgoing: case quicproxy.DirectionOutgoing:
p = atomic.AddInt32(&outgoing, 1) p = outgoing.Add(1)
} }
return p == 2 && d.Is(direction) return p == 2 && d.Is(direction)
}, doRetry, longCertChain) }, doRetry, longCertChain)

View file

@ -20,7 +20,7 @@ import (
type listenerWrapper struct { type listenerWrapper struct {
http3.QUICEarlyListener http3.QUICEarlyListener
listenerClosed bool listenerClosed bool
count int32 count atomic.Int32
} }
func (ln *listenerWrapper) Close() error { func (ln *listenerWrapper) Close() error {
@ -29,14 +29,18 @@ func (ln *listenerWrapper) Close() error {
} }
func (ln *listenerWrapper) Faker() *fakeClosingListener { func (ln *listenerWrapper) Faker() *fakeClosingListener {
atomic.AddInt32(&ln.count, 1) ln.count.Add(1)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &fakeClosingListener{ln, 0, ctx, cancel} return &fakeClosingListener{
listenerWrapper: ln,
ctx: ctx,
cancel: cancel,
}
} }
type fakeClosingListener struct { type fakeClosingListener struct {
*listenerWrapper *listenerWrapper
closed int32 closed atomic.Bool
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
} }
@ -47,9 +51,9 @@ func (ln *fakeClosingListener) Accept(ctx context.Context) (quic.EarlyConnection
} }
func (ln *fakeClosingListener) Close() error { func (ln *fakeClosingListener) Close() error {
if atomic.CompareAndSwapInt32(&ln.closed, 0, 1) { if ln.closed.CompareAndSwap(false, true) {
ln.cancel() ln.cancel()
if atomic.AddInt32(&ln.listenerWrapper.count, -1) == 0 { if ln.listenerWrapper.count.Add(-1) == 0 {
ln.listenerWrapper.Close() ln.listenerWrapper.Close()
} }
} }
@ -145,8 +149,8 @@ var _ = Describe("HTTP3 Server hotswap test", func() {
// and only the fake listener should be closed // and only the fake listener should be closed
Expect(server1.Close()).NotTo(HaveOccurred()) Expect(server1.Close()).NotTo(HaveOccurred())
Eventually(stoppedServing1).Should(BeClosed()) Eventually(stoppedServing1).Should(BeClosed())
Expect(fake1.closed).To(Equal(int32(1))) Expect(fake1.closed.Load()).To(BeTrue())
Expect(fake2.closed).To(Equal(int32(0))) Expect(fake2.closed.Load()).To(BeFalse())
Expect(ln.listenerClosed).ToNot(BeTrue()) Expect(ln.listenerClosed).ToNot(BeTrue())
Expect(client.Transport.(*http3.RoundTripper).Close()).NotTo(HaveOccurred()) Expect(client.Transport.(*http3.RoundTripper).Close()).NotTo(HaveOccurred())
@ -161,7 +165,7 @@ var _ = Describe("HTTP3 Server hotswap test", func() {
// close the other server - both the fake and the actual listeners must close now // close the other server - both the fake and the actual listeners must close now
Expect(server2.Close()).NotTo(HaveOccurred()) Expect(server2.Close()).NotTo(HaveOccurred())
Eventually(stoppedServing2).Should(BeClosed()) Eventually(stoppedServing2).Should(BeClosed())
Expect(fake2.closed).To(Equal(int32(1))) Expect(fake2.closed.Load()).To(BeTrue())
Expect(ln.listenerClosed).To(BeTrue()) Expect(ln.listenerClosed).To(BeTrue())
}) })
}) })

View file

@ -244,17 +244,17 @@ var _ = Describe("MITM test", func() {
Context("corrupting packets", func() { Context("corrupting packets", func() {
const idleTimeout = time.Second const idleTimeout = time.Second
var numCorrupted, numPackets int32 var numCorrupted, numPackets atomic.Int32
BeforeEach(func() { BeforeEach(func() {
numCorrupted = 0 numCorrupted.Store(0)
numPackets = 0 numPackets.Store(0)
serverConfig.MaxIdleTimeout = idleTimeout serverConfig.MaxIdleTimeout = idleTimeout
}) })
AfterEach(func() { AfterEach(func() {
num := atomic.LoadInt32(&numCorrupted) num := numCorrupted.Load()
fmt.Fprintf(GinkgoWriter, "Corrupted %d of %d packets.", num, atomic.LoadInt32(&numPackets)) fmt.Fprintf(GinkgoWriter, "Corrupted %d of %d packets.", num, numPackets.Load())
Expect(num).To(BeNumerically(">=", 1)) Expect(num).To(BeNumerically(">=", 1))
// If the packet containing the CONNECTION_CLOSE is corrupted, // If the packet containing the CONNECTION_CLOSE is corrupted,
// we have to wait for the connection to time out. // we have to wait for the connection to time out.
@ -266,13 +266,13 @@ var _ = Describe("MITM test", func() {
dropCb := func(dir quicproxy.Direction, raw []byte) bool { dropCb := func(dir quicproxy.Direction, raw []byte) bool {
defer GinkgoRecover() defer GinkgoRecover()
if dir == quicproxy.DirectionIncoming { if dir == quicproxy.DirectionIncoming {
atomic.AddInt32(&numPackets, 1) numPackets.Add(1)
if rand.Intn(interval) == 0 { if rand.Intn(interval) == 0 {
pos := rand.Intn(len(raw)) pos := rand.Intn(len(raw))
raw[pos] = byte(rand.Intn(256)) raw[pos] = byte(rand.Intn(256))
_, err := clientTransport.WriteTo(raw, serverTransport.Conn.LocalAddr()) _, err := clientTransport.WriteTo(raw, serverTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
atomic.AddInt32(&numCorrupted, 1) numCorrupted.Add(1)
return true return true
} }
} }
@ -286,13 +286,13 @@ var _ = Describe("MITM test", func() {
dropCb := func(dir quicproxy.Direction, raw []byte) bool { dropCb := func(dir quicproxy.Direction, raw []byte) bool {
defer GinkgoRecover() defer GinkgoRecover()
if dir == quicproxy.DirectionOutgoing { if dir == quicproxy.DirectionOutgoing {
atomic.AddInt32(&numPackets, 1) numPackets.Add(1)
if rand.Intn(interval) == 0 { if rand.Intn(interval) == 0 {
pos := rand.Intn(len(raw)) pos := rand.Intn(len(raw))
raw[pos] = byte(rand.Intn(256)) raw[pos] = byte(rand.Intn(256))
_, err := serverTransport.WriteTo(raw, clientTransport.Conn.LocalAddr()) _, err := serverTransport.WriteTo(raw, clientTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
atomic.AddInt32(&numCorrupted, 1) numCorrupted.Add(1)
return true return true
} }
} }

View file

@ -22,12 +22,12 @@ type faultyConn struct {
net.PacketConn net.PacketConn
MaxPackets int32 MaxPackets int32
counter int32 counter atomic.Int32
} }
func (c *faultyConn) ReadFrom(p []byte) (int, net.Addr, error) { func (c *faultyConn) ReadFrom(p []byte) (int, net.Addr, error) {
n, addr, err := c.PacketConn.ReadFrom(p) n, addr, err := c.PacketConn.ReadFrom(p)
counter := atomic.AddInt32(&c.counter, 1) counter := c.counter.Add(1)
if counter <= c.MaxPackets { if counter <= c.MaxPackets {
return n, addr, err return n, addr, err
} }
@ -35,7 +35,7 @@ func (c *faultyConn) ReadFrom(p []byte) (int, net.Addr, error) {
} }
func (c *faultyConn) WriteTo(p []byte, addr net.Addr) (int, error) { func (c *faultyConn) WriteTo(p []byte, addr net.Addr) (int, error) {
counter := atomic.AddInt32(&c.counter, 1) counter := c.counter.Add(1)
if counter <= c.MaxPackets { if counter <= c.MaxPackets {
return c.PacketConn.WriteTo(p, addr) return c.PacketConn.WriteTo(p, addr)
} }

View file

@ -26,8 +26,8 @@ import (
var _ = Describe("0-RTT", func() { var _ = Describe("0-RTT", func() {
rtt := scaleDuration(5 * time.Millisecond) rtt := scaleDuration(5 * time.Millisecond)
runCountingProxy := func(serverPort int) (*quicproxy.QuicProxy, *uint32) { runCountingProxy := func(serverPort int) (*quicproxy.QuicProxy, *atomic.Uint32) {
var num0RTTPackets uint32 // to be used as an atomic var num0RTTPackets atomic.Uint32
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration {
@ -38,7 +38,7 @@ var _ = Describe("0-RTT", func() {
hdr, _, rest, err := wire.ParsePacket(data) hdr, _, rest, err := wire.ParsePacket(data)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
if hdr.Type == protocol.PacketType0RTT { if hdr.Type == protocol.PacketType0RTT {
atomic.AddUint32(&num0RTTPackets, 1) num0RTTPackets.Add(1)
break break
} }
data = rest data = rest
@ -257,7 +257,7 @@ var _ = Describe("0-RTT", func() {
Expect(numNewConnIDs).ToNot(BeZero()) Expect(numNewConnIDs).ToNot(BeZero())
} }
num0RTT := atomic.LoadUint32(num0RTTPackets) num0RTT := num0RTTPackets.Load()
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero()) Expect(num0RTT).ToNot(BeZero())
zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
@ -348,10 +348,7 @@ var _ = Describe("0-RTT", func() {
}) })
It("transfers 0-RTT data, when 0-RTT packets are lost", func() { It("transfers 0-RTT data, when 0-RTT packets are lost", func() {
var ( var num0RTTPackets, num0RTTDropped atomic.Uint32
num0RTTPackets uint32 // to be used as an atomic
num0RTTDropped uint32
)
tlsConf, clientConf := dialAndReceiveSessionTicket(nil) tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
@ -374,7 +371,7 @@ var _ = Describe("0-RTT", func() {
hdr, _, _, err := wire.ParsePacket(data) hdr, _, _, err := wire.ParsePacket(data)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
if hdr.Type == protocol.PacketType0RTT { if hdr.Type == protocol.PacketType0RTT {
atomic.AddUint32(&num0RTTPackets, 1) num0RTTPackets.Add(1)
} }
} }
return rtt / 2 return rtt / 2
@ -389,7 +386,7 @@ var _ = Describe("0-RTT", func() {
// drop 25% of the 0-RTT packets // drop 25% of the 0-RTT packets
drop := mrand.Intn(4) == 0 drop := mrand.Intn(4) == 0
if drop { if drop {
atomic.AddUint32(&num0RTTDropped, 1) num0RTTDropped.Add(1)
} }
return drop return drop
} }
@ -401,8 +398,8 @@ var _ = Describe("0-RTT", func() {
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData) transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData)
num0RTT := atomic.LoadUint32(&num0RTTPackets) num0RTT := num0RTTPackets.Load()
numDropped := atomic.LoadUint32(&num0RTTDropped) numDropped := num0RTTDropped.Load()
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped)
Expect(numDropped).ToNot(BeZero()) Expect(numDropped).ToNot(BeZero())
Expect(num0RTT).ToNot(BeZero()) Expect(num0RTT).ToNot(BeZero())
@ -551,7 +548,7 @@ var _ = Describe("0-RTT", func() {
check0RTTRejected(ln, proxy.LocalPort(), clientConf) check0RTTRejected(ln, proxy.LocalPort(), clientConf)
// The client should send 0-RTT packets, but the server doesn't process them. // The client should send 0-RTT packets, but the server doesn't process them.
num0RTT := atomic.LoadUint32(num0RTTPackets) num0RTT := num0RTTPackets.Load()
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero()) Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
@ -580,7 +577,7 @@ var _ = Describe("0-RTT", func() {
check0RTTRejected(ln, proxy.LocalPort(), clientConf) check0RTTRejected(ln, proxy.LocalPort(), clientConf)
// The client should send 0-RTT packets, but the server doesn't process them. // The client should send 0-RTT packets, but the server doesn't process them.
num0RTT := atomic.LoadUint32(num0RTTPackets) num0RTT := num0RTTPackets.Load()
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero()) Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
@ -607,7 +604,7 @@ var _ = Describe("0-RTT", func() {
check0RTTRejected(ln, proxy.LocalPort(), clientConf) check0RTTRejected(ln, proxy.LocalPort(), clientConf)
// The client should send 0-RTT packets, but the server doesn't process them. // The client should send 0-RTT packets, but the server doesn't process them.
num0RTT := atomic.LoadUint32(num0RTTPackets) num0RTT := num0RTTPackets.Load()
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero()) Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
@ -759,7 +756,7 @@ var _ = Describe("0-RTT", func() {
Expect(conn.CloseWithError(0, "")).To(Succeed()) Expect(conn.CloseWithError(0, "")).To(Succeed())
// The client should send 0-RTT packets, but the server doesn't process them. // The client should send 0-RTT packets, but the server doesn't process them.
num0RTT := atomic.LoadUint32(num0RTTPackets) num0RTT := num0RTTPackets.Load()
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero()) Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
@ -851,7 +848,7 @@ var _ = Describe("0-RTT", func() {
Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) Expect(conn.ConnectionState().Used0RTT).To(BeTrue())
Expect(receivedMessage).To(Equal(sentMessage)) Expect(receivedMessage).To(Equal(sentMessage))
Expect(conn.CloseWithError(0, "")).To(Succeed()) Expect(conn.CloseWithError(0, "")).To(Succeed())
num0RTT := atomic.LoadUint32(num0RTTPackets) num0RTT := num0RTTPackets.Load()
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero()) Expect(num0RTT).ToNot(BeZero())
zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
@ -906,7 +903,7 @@ var _ = Describe("0-RTT", func() {
Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse())
Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) Expect(conn.ConnectionState().Used0RTT).To(BeFalse())
Expect(conn.CloseWithError(0, "")).To(Succeed()) Expect(conn.CloseWithError(0, "")).To(Succeed())
num0RTT := atomic.LoadUint32(num0RTTPackets) num0RTT := num0RTTPackets.Load()
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero()) Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())

View file

@ -56,8 +56,8 @@ func (m metadataClientSessionCache) Put(key string, session *tls.ClientSessionSt
var _ = Describe("0-RTT", func() { var _ = Describe("0-RTT", func() {
rtt := scaleDuration(5 * time.Millisecond) rtt := scaleDuration(5 * time.Millisecond)
runCountingProxy := func(serverPort int) (*quicproxy.QuicProxy, *uint32) { runCountingProxy := func(serverPort int) (*quicproxy.QuicProxy, *atomic.Uint32) {
var num0RTTPackets uint32 // to be used as an atomic var num0RTTPackets atomic.Uint32
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration {
@ -68,7 +68,7 @@ var _ = Describe("0-RTT", func() {
hdr, _, rest, err := wire.ParsePacket(data) hdr, _, rest, err := wire.ParsePacket(data)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
if hdr.Type == protocol.PacketType0RTT { if hdr.Type == protocol.PacketType0RTT {
atomic.AddUint32(&num0RTTPackets, 1) num0RTTPackets.Add(1)
break break
} }
data = rest data = rest
@ -289,7 +289,7 @@ var _ = Describe("0-RTT", func() {
Expect(numNewConnIDs).ToNot(BeZero()) Expect(numNewConnIDs).ToNot(BeZero())
} }
num0RTT := atomic.LoadUint32(num0RTTPackets) num0RTT := num0RTTPackets.Load()
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero()) Expect(num0RTT).ToNot(BeZero())
zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
@ -382,10 +382,7 @@ var _ = Describe("0-RTT", func() {
}) })
It("transfers 0-RTT data, when 0-RTT packets are lost", func() { It("transfers 0-RTT data, when 0-RTT packets are lost", func() {
var ( var num0RTTPackets, num0RTTDropped atomic.Uint32
num0RTTPackets uint32 // to be used as an atomic
num0RTTDropped uint32
)
tlsConf := getTLSConfig() tlsConf := getTLSConfig()
clientConf := getTLSClientConfig() clientConf := getTLSClientConfig()
@ -410,7 +407,7 @@ var _ = Describe("0-RTT", func() {
hdr, _, _, err := wire.ParsePacket(data) hdr, _, _, err := wire.ParsePacket(data)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
if hdr.Type == protocol.PacketType0RTT { if hdr.Type == protocol.PacketType0RTT {
atomic.AddUint32(&num0RTTPackets, 1) num0RTTPackets.Add(1)
} }
} }
return rtt / 2 return rtt / 2
@ -425,7 +422,7 @@ var _ = Describe("0-RTT", func() {
// drop 25% of the 0-RTT packets // drop 25% of the 0-RTT packets
drop := mrand.Intn(4) == 0 drop := mrand.Intn(4) == 0
if drop { if drop {
atomic.AddUint32(&num0RTTDropped, 1) num0RTTDropped.Add(1)
} }
return drop return drop
} }
@ -437,8 +434,8 @@ var _ = Describe("0-RTT", func() {
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData) transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData)
num0RTT := atomic.LoadUint32(&num0RTTPackets) num0RTT := num0RTTPackets.Load()
numDropped := atomic.LoadUint32(&num0RTTDropped) numDropped := num0RTTDropped.Load()
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped)
Expect(numDropped).ToNot(BeZero()) Expect(numDropped).ToNot(BeZero())
Expect(num0RTT).ToNot(BeZero()) Expect(num0RTT).ToNot(BeZero())
@ -594,7 +591,7 @@ var _ = Describe("0-RTT", func() {
check0RTTRejected(ln, proxy.LocalPort(), clientConf) check0RTTRejected(ln, proxy.LocalPort(), clientConf)
// The client should send 0-RTT packets, but the server doesn't process them. // The client should send 0-RTT packets, but the server doesn't process them.
num0RTT := atomic.LoadUint32(num0RTTPackets) num0RTT := num0RTTPackets.Load()
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero()) Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
@ -627,7 +624,7 @@ var _ = Describe("0-RTT", func() {
check0RTTRejected(ln, proxy.LocalPort(), clientConf) check0RTTRejected(ln, proxy.LocalPort(), clientConf)
// The client should send 0-RTT packets, but the server doesn't process them. // The client should send 0-RTT packets, but the server doesn't process them.
num0RTT := atomic.LoadUint32(num0RTTPackets) num0RTT := num0RTTPackets.Load()
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero()) Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
@ -656,7 +653,7 @@ var _ = Describe("0-RTT", func() {
check0RTTRejected(ln, proxy.LocalPort(), clientConf) check0RTTRejected(ln, proxy.LocalPort(), clientConf)
// The client should send 0-RTT packets, but the server doesn't process them. // The client should send 0-RTT packets, but the server doesn't process them.
num0RTT := atomic.LoadUint32(num0RTTPackets) num0RTT := num0RTTPackets.Load()
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero()) Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
@ -812,7 +809,7 @@ var _ = Describe("0-RTT", func() {
Expect(conn.CloseWithError(0, "")).To(Succeed()) Expect(conn.CloseWithError(0, "")).To(Succeed())
// The client should send 0-RTT packets, but the server doesn't process them. // The client should send 0-RTT packets, but the server doesn't process them.
num0RTT := atomic.LoadUint32(num0RTTPackets) num0RTT := num0RTTPackets.Load()
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero()) Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
@ -981,7 +978,7 @@ var _ = Describe("0-RTT", func() {
Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) Expect(conn.ConnectionState().Used0RTT).To(BeTrue())
Expect(conn.CloseWithError(0, "")).To(Succeed()) Expect(conn.CloseWithError(0, "")).To(Succeed())
Expect(receivedMessage).To(Equal(sentMessage)) Expect(receivedMessage).To(Equal(sentMessage))
num0RTT := atomic.LoadUint32(num0RTTPackets) num0RTT := num0RTTPackets.Load()
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero()) Expect(num0RTT).ToNot(BeZero())
zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
@ -1038,7 +1035,7 @@ var _ = Describe("0-RTT", func() {
Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse())
Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) Expect(conn.ConnectionState().Used0RTT).To(BeFalse())
Expect(conn.CloseWithError(0, "")).To(Succeed()) Expect(conn.CloseWithError(0, "")).To(Succeed())
num0RTT := atomic.LoadUint32(num0RTTPackets) num0RTT := num0RTTPackets.Load()
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero()) Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())

View file

@ -141,7 +141,7 @@ var _ = Describe("QUIC Proxy", func() {
Context("Proxy tests", func() { Context("Proxy tests", func() {
var ( var (
serverConn *net.UDPConn serverConn *net.UDPConn
serverNumPacketsSent int32 serverNumPacketsSent atomic.Int32
serverReceivedPackets chan packetData serverReceivedPackets chan packetData
clientConn *net.UDPConn clientConn *net.UDPConn
proxy *QuicProxy proxy *QuicProxy
@ -159,9 +159,9 @@ var _ = Describe("QUIC Proxy", func() {
BeforeEach(func() { BeforeEach(func() {
stoppedReading = make(chan struct{}) stoppedReading = make(chan struct{})
serverReceivedPackets = make(chan packetData, 100) serverReceivedPackets = make(chan packetData, 100)
atomic.StoreInt32(&serverNumPacketsSent, 0) serverNumPacketsSent.Store(0)
// setup a dump UDP server // set up a dump UDP server
// in production this would be a QUIC server // in production this would be a QUIC server
raddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") raddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -181,7 +181,7 @@ var _ = Describe("QUIC Proxy", func() {
data := buf[0:n] data := buf[0:n]
serverReceivedPackets <- packetData(data) serverReceivedPackets <- packetData(data)
// echo the packet // echo the packet
atomic.AddInt32(&serverNumPacketsSent, 1) serverNumPacketsSent.Add(1)
serverConn.WriteToUDP(data, addr) serverConn.WriteToUDP(data, addr)
} }
}() }()
@ -236,7 +236,7 @@ var _ = Describe("QUIC Proxy", func() {
}() }()
Eventually(serverReceivedPackets).Should(HaveLen(2)) Eventually(serverReceivedPackets).Should(HaveLen(2))
Expect(atomic.LoadInt32(&serverNumPacketsSent)).To(BeEquivalentTo(2)) Expect(serverNumPacketsSent.Load()).To(BeEquivalentTo(2))
Eventually(clientReceivedPackets).Should(HaveLen(2)) Eventually(clientReceivedPackets).Should(HaveLen(2))
Expect(string(<-clientReceivedPackets)).To(ContainSubstring("foobar")) Expect(string(<-clientReceivedPackets)).To(ContainSubstring("foobar"))
Expect(string(<-clientReceivedPackets)).To(ContainSubstring("decafbad")) Expect(string(<-clientReceivedPackets)).To(ContainSubstring("decafbad"))
@ -245,14 +245,14 @@ var _ = Describe("QUIC Proxy", func() {
Context("Drop Callbacks", func() { Context("Drop Callbacks", func() {
It("drops incoming packets", func() { It("drops incoming packets", func() {
var counter int32 var counter atomic.Int32
opts := &Opts{ opts := &Opts{
RemoteAddr: serverConn.LocalAddr().String(), RemoteAddr: serverConn.LocalAddr().String(),
DropPacket: func(d Direction, _ []byte) bool { DropPacket: func(d Direction, _ []byte) bool {
if d != DirectionIncoming { if d != DirectionIncoming {
return false return false
} }
return atomic.AddInt32(&counter, 1)%2 == 1 return counter.Add(1)%2 == 1
}, },
} }
startProxy(opts) startProxy(opts)
@ -267,14 +267,14 @@ var _ = Describe("QUIC Proxy", func() {
It("drops outgoing packets", func() { It("drops outgoing packets", func() {
const numPackets = 6 const numPackets = 6
var counter int32 var counter atomic.Int32
opts := &Opts{ opts := &Opts{
RemoteAddr: serverConn.LocalAddr().String(), RemoteAddr: serverConn.LocalAddr().String(),
DropPacket: func(d Direction, _ []byte) bool { DropPacket: func(d Direction, _ []byte) bool {
if d != DirectionOutgoing { if d != DirectionOutgoing {
return false return false
} }
return atomic.AddInt32(&counter, 1)%2 == 1 return counter.Add(1)%2 == 1
}, },
} }
startProxy(opts) startProxy(opts)
@ -315,7 +315,7 @@ var _ = Describe("QUIC Proxy", func() {
} }
It("delays incoming packets", func() { It("delays incoming packets", func() {
var counter int32 var counter atomic.Int32
opts := &Opts{ opts := &Opts{
RemoteAddr: serverConn.LocalAddr().String(), RemoteAddr: serverConn.LocalAddr().String(),
// delay packet 1 by 200 ms // delay packet 1 by 200 ms
@ -325,7 +325,7 @@ var _ = Describe("QUIC Proxy", func() {
if d == DirectionOutgoing { if d == DirectionOutgoing {
return 0 return 0
} }
p := atomic.AddInt32(&counter, 1) p := counter.Add(1)
return time.Duration(p) * delay return time.Duration(p) * delay
}, },
} }
@ -349,7 +349,7 @@ var _ = Describe("QUIC Proxy", func() {
}) })
It("handles reordered packets", func() { It("handles reordered packets", func() {
var counter int32 var counter atomic.Int32
opts := &Opts{ opts := &Opts{
RemoteAddr: serverConn.LocalAddr().String(), RemoteAddr: serverConn.LocalAddr().String(),
// delay packet 1 by 600 ms // delay packet 1 by 600 ms
@ -359,7 +359,7 @@ var _ = Describe("QUIC Proxy", func() {
if d == DirectionOutgoing { if d == DirectionOutgoing {
return 0 return 0
} }
p := atomic.AddInt32(&counter, 1) p := counter.Add(1)
return 600*time.Millisecond - time.Duration(p-1)*delay return 600*time.Millisecond - time.Duration(p-1)*delay
}, },
} }
@ -407,7 +407,7 @@ var _ = Describe("QUIC Proxy", func() {
It("delays outgoing packets", func() { It("delays outgoing packets", func() {
const numPackets = 3 const numPackets = 3
var counter int32 var counter atomic.Int32
opts := &Opts{ opts := &Opts{
RemoteAddr: serverConn.LocalAddr().String(), RemoteAddr: serverConn.LocalAddr().String(),
// delay packet 1 by 200 ms // delay packet 1 by 200 ms
@ -417,7 +417,7 @@ var _ = Describe("QUIC Proxy", func() {
if d == DirectionIncoming { if d == DirectionIncoming {
return 0 return 0
} }
p := atomic.AddInt32(&counter, 1) p := counter.Add(1)
return time.Duration(p) * delay return time.Duration(p) * delay
}, },
} }