diff --git a/integrationtests/self/drop_test.go b/integrationtests/self/drop_test.go index 5a20ccd4..dfd3971c 100644 --- a/integrationtests/self/drop_test.go +++ b/integrationtests/self/drop_test.go @@ -39,7 +39,7 @@ var _ = Describe("Drop Tests", func() { serverPort := ln.Addr().(*net.UDPAddr).Port proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), - DelayPacket: func(dir quicproxy.Direction, packetCount uint64) time.Duration { + DelayPacket: func(dir quicproxy.Direction, _ []byte) time.Duration { return 5 * time.Millisecond // 10ms RTT }, DropPacket: dropCallback, @@ -75,7 +75,7 @@ var _ = Describe("Drop Tests", func() { startTime := time.Now() var numDroppedPackets int32 - startListenerAndProxy(func(d quicproxy.Direction, p uint64) bool { + startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { if !d.Is(direction) { return false } diff --git a/integrationtests/self/handshake_drop_test.go b/integrationtests/self/handshake_drop_test.go index 724a72cc..008743d9 100644 --- a/integrationtests/self/handshake_drop_test.go +++ b/integrationtests/self/handshake_drop_test.go @@ -5,6 +5,7 @@ import ( "fmt" mrand "math/rand" "net" + "sync/atomic" "time" quic "github.com/lucas-clemente/quic-go" @@ -161,21 +162,37 @@ var _ = Describe("Handshake drop tests", func() { Context(app.name, func() { It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", d), func() { - startListenerAndProxy(func(d quicproxy.Direction, p uint64) bool { + var incoming, outgoing int32 + startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { + var p int32 + switch d { + case quicproxy.DirectionIncoming: + p = atomic.AddInt32(&incoming, 1) + case quicproxy.DirectionOutgoing: + p = atomic.AddInt32(&outgoing, 1) + } return p == 1 && d.Is(direction) }, version) app.run(version) }) It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", d), func() { - startListenerAndProxy(func(d quicproxy.Direction, p uint64) bool { + var incoming, outgoing int32 + startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { + var p int32 + switch d { + case quicproxy.DirectionIncoming: + p = atomic.AddInt32(&incoming, 1) + case quicproxy.DirectionOutgoing: + p = atomic.AddInt32(&outgoing, 1) + } return p == 2 && d.Is(direction) }, version) app.run(version) }) It(fmt.Sprintf("establishes a connection when 1/5 of the packets are lost in %s direction", d), func() { - startListenerAndProxy(func(d quicproxy.Direction, p uint64) bool { + startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { return d.Is(direction) && stochasticDropper(5) }, version) app.run(version) diff --git a/integrationtests/self/handshake_rtt_test.go b/integrationtests/self/handshake_rtt_test.go index 01053295..90a4c3ac 100644 --- a/integrationtests/self/handshake_rtt_test.go +++ b/integrationtests/self/handshake_rtt_test.go @@ -47,7 +47,7 @@ var _ = Describe("Handshake RTT tests", func() { // start the proxy proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ RemoteAddr: server.Addr().String(), - DelayPacket: func(_ quicproxy.Direction, _ uint64) time.Duration { return rtt / 2 }, + DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 }, }) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/rtt_test.go b/integrationtests/self/rtt_test.go index eff2108a..a77a8aa3 100644 --- a/integrationtests/self/rtt_test.go +++ b/integrationtests/self/rtt_test.go @@ -55,7 +55,7 @@ var _ = Describe("non-zero RTT", func() { serverPort := ln.Addr().(*net.UDPAddr).Port proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), - DelayPacket: func(d quicproxy.Direction, p uint64) time.Duration { + DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 }, }) diff --git a/integrationtests/self/stateless_reset_test.go b/integrationtests/self/stateless_reset_test.go index aae2ed74..a8e1fe43 100644 --- a/integrationtests/self/stateless_reset_test.go +++ b/integrationtests/self/stateless_reset_test.go @@ -48,7 +48,7 @@ var _ = Describe("Stateless Resets", func() { proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), - DropPacket: func(d quicproxy.Direction, p uint64) bool { + DropPacket: func(quicproxy.Direction, []byte) bool { return drop.Get() }, }) diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go index 8a752e57..b4d7cf78 100644 --- a/integrationtests/self/timeout_test.go +++ b/integrationtests/self/timeout_test.go @@ -82,7 +82,7 @@ var _ = Describe("Timeout tests", func() { proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), - DropPacket: func(d quicproxy.Direction, p uint64) bool { + DropPacket: func(quicproxy.Direction, []byte) bool { return drop.Get() }, }) diff --git a/integrationtests/tools/proxy/proxy.go b/integrationtests/tools/proxy/proxy.go index c661df80..bd4b7f9f 100644 --- a/integrationtests/tools/proxy/proxy.go +++ b/integrationtests/tools/proxy/proxy.go @@ -3,7 +3,6 @@ package quicproxy import ( "net" "sync" - "sync/atomic" "time" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -14,9 +13,6 @@ import ( type connection struct { ClientAddr *net.UDPAddr // Address of the client ServerConn *net.UDPConn // UDP connection to server - - incomingPacketCounter uint64 - outgoingPacketCounter uint64 } // Direction is the direction a packet is sent. @@ -54,18 +50,18 @@ func (d Direction) Is(dir Direction) bool { } // DropCallback is a callback that determines which packet gets dropped. -type DropCallback func(dir Direction, packetCount uint64) bool +type DropCallback func(dir Direction, packet []byte) bool // NoDropper doesn't drop packets. -var NoDropper DropCallback = func(Direction, uint64) bool { +var NoDropper DropCallback = func(Direction, []byte) bool { return false } // DelayCallback is a callback that determines how much delay to apply to a packet. -type DelayCallback func(dir Direction, packetCount uint64) time.Duration +type DelayCallback func(dir Direction, packet []byte) time.Duration // NoDelay doesn't apply a delay. -var NoDelay DelayCallback = func(Direction, uint64) time.Duration { +var NoDelay DelayCallback = func(Direction, []byte) time.Duration { return 0 } @@ -197,20 +193,18 @@ func (p *QuicProxy) runProxy() error { } p.mutex.Unlock() - packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1) - - if p.dropPacket(DirectionIncoming, packetCount) { + if p.dropPacket(DirectionIncoming, raw) { if p.logger.Debug() { - p.logger.Debugf("dropping incoming packet %d (%d bytes)", packetCount, n) + p.logger.Debugf("dropping incoming packet(%d bytes)", n) } continue } // Send the packet to the server - delay := p.delayPacket(DirectionIncoming, packetCount) + delay := p.delayPacket(DirectionIncoming, raw) if delay != 0 { if p.logger.Debug() { - p.logger.Debugf("delaying incoming packet %d (%d bytes) to %s by %s", packetCount, n, conn.ServerConn.RemoteAddr(), delay) + p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", n, conn.ServerConn.RemoteAddr(), delay) } time.AfterFunc(delay, func() { // TODO: handle error @@ -218,7 +212,7 @@ func (p *QuicProxy) runProxy() error { }) } else { if p.logger.Debug() { - p.logger.Debugf("forwarding incoming packet %d (%d bytes) to %s", packetCount, n, conn.ServerConn.RemoteAddr()) + p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", n, conn.ServerConn.RemoteAddr()) } if _, err := conn.ServerConn.Write(raw); err != nil { return err @@ -237,19 +231,17 @@ func (p *QuicProxy) runConnection(conn *connection) error { } raw := buffer[0:n] - packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1) - - if p.dropPacket(DirectionOutgoing, packetCount) { + if p.dropPacket(DirectionOutgoing, raw) { if p.logger.Debug() { - p.logger.Debugf("dropping outgoing packet %d (%d bytes)", packetCount, n) + p.logger.Debugf("dropping outgoing packet(%d bytes)", n) } continue } - delay := p.delayPacket(DirectionOutgoing, packetCount) + delay := p.delayPacket(DirectionOutgoing, raw) if delay != 0 { if p.logger.Debug() { - p.logger.Debugf("delaying outgoing packet %d (%d bytes) to %s by %s", packetCount, n, conn.ClientAddr, delay) + p.logger.Debugf("delaying outgoing packet (%d bytes) to %s by %s", n, conn.ClientAddr, delay) } time.AfterFunc(delay, func() { // TODO: handle error @@ -257,7 +249,7 @@ func (p *QuicProxy) runConnection(conn *connection) error { }) } else { if p.logger.Debug() { - p.logger.Debugf("forwarding outgoing packet %d (%d bytes) to %s", packetCount, n, conn.ClientAddr) + p.logger.Debugf("forwarding outgoing packet (%d bytes) to %s", n, conn.ClientAddr) } if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil { return err diff --git a/integrationtests/tools/proxy/proxy_test.go b/integrationtests/tools/proxy/proxy_test.go index e50563be..029fa718 100644 --- a/integrationtests/tools/proxy/proxy_test.go +++ b/integrationtests/tools/proxy/proxy_test.go @@ -135,17 +135,6 @@ var _ = Describe("QUIC Proxy", func() { Expect(err).ToNot(HaveOccurred()) } - // getClientDict returns a copy of the clientDict map - getClientDict := func() map[string]*connection { - d := make(map[string]*connection) - proxy.mutex.Lock() - defer proxy.mutex.Unlock() - for k, v := range proxy.clientDict { - d[k] = v - } - return d - } - BeforeEach(func() { stoppedReading = make(chan struct{}) serverReceivedPackets = make(chan packetData, 100) @@ -191,18 +180,11 @@ var _ = Describe("QUIC Proxy", func() { _, err := clientConn.Write(makePacket(1, []byte("foobar"))) Expect(err).ToNot(HaveOccurred()) - Eventually(getClientDict).Should(HaveLen(1)) - var conn *connection - for _, conn = range getClientDict() { - Eventually(func() uint64 { return atomic.LoadUint64(&conn.incomingPacketCounter) }).Should(Equal(uint64(1))) - } - // send the second packet _, err = clientConn.Write(makePacket(2, []byte("decafbad"))) Expect(err).ToNot(HaveOccurred()) Eventually(serverReceivedPackets).Should(HaveLen(2)) - Expect(getClientDict()).To(HaveLen(1)) Expect(string(<-serverReceivedPackets)).To(ContainSubstring("foobar")) Expect(string(<-serverReceivedPackets)).To(ContainSubstring("decafbad")) }) @@ -213,23 +195,10 @@ var _ = Describe("QUIC Proxy", func() { _, err := clientConn.Write(makePacket(1, []byte("foobar"))) Expect(err).ToNot(HaveOccurred()) - Eventually(getClientDict).Should(HaveLen(1)) - var key string - var conn *connection - for key, conn = range getClientDict() { - Eventually(func() uint64 { return atomic.LoadUint64(&conn.outgoingPacketCounter) }).Should(Equal(uint64(1))) - } - // send the second packet _, err = clientConn.Write(makePacket(2, []byte("decafbad"))) Expect(err).ToNot(HaveOccurred()) - Expect(getClientDict()).To(HaveLen(1)) - Eventually(func() uint64 { - conn := getClientDict()[key] - return atomic.LoadUint64(&conn.outgoingPacketCounter) - }).Should(BeEquivalentTo(2)) - clientReceivedPackets := make(chan packetData, 2) // receive the packets echoed by the server on client side go func() { @@ -255,10 +224,14 @@ var _ = Describe("QUIC Proxy", func() { Context("Drop Callbacks", func() { It("drops incoming packets", func() { + var counter int32 opts := &Opts{ RemoteAddr: serverConn.LocalAddr().String(), - DropPacket: func(d Direction, p uint64) bool { - return d == DirectionIncoming && p%2 == 0 + DropPacket: func(d Direction, _ []byte) bool { + if d != DirectionIncoming { + return false + } + return atomic.AddInt32(&counter, 1)%2 == 1 }, } startProxy(opts) @@ -273,10 +246,14 @@ var _ = Describe("QUIC Proxy", func() { It("drops outgoing packets", func() { const numPackets = 6 + var counter int32 opts := &Opts{ RemoteAddr: serverConn.LocalAddr().String(), - DropPacket: func(d Direction, p uint64) bool { - return d == DirectionOutgoing && p%2 == 0 + DropPacket: func(d Direction, _ []byte) bool { + if d != DirectionOutgoing { + return false + } + return atomic.AddInt32(&counter, 1)%2 == 1 }, } startProxy(opts) @@ -317,15 +294,17 @@ var _ = Describe("QUIC Proxy", func() { It("delays incoming packets", func() { delay := 300 * time.Millisecond + var counter int32 opts := &Opts{ RemoteAddr: serverConn.LocalAddr().String(), // delay packet 1 by 200 ms // delay packet 2 by 400 ms // ... - DelayPacket: func(d Direction, p uint64) time.Duration { + DelayPacket: func(d Direction, _ []byte) time.Duration { if d == DirectionOutgoing { return 0 } + p := atomic.AddInt32(&counter, 1) return time.Duration(p) * delay }, } @@ -348,15 +327,17 @@ var _ = Describe("QUIC Proxy", func() { It("delays outgoing packets", func() { const numPackets = 3 delay := 300 * time.Millisecond + var counter int32 opts := &Opts{ RemoteAddr: serverConn.LocalAddr().String(), // delay packet 1 by 200 ms // delay packet 2 by 400 ms // ... - DelayPacket: func(d Direction, p uint64) time.Duration { + DelayPacket: func(d Direction, _ []byte) time.Duration { if d == DirectionIncoming { return 0 } + p := atomic.AddInt32(&counter, 1) return time.Duration(p) * delay }, }