diff --git a/integrationtests/self/mitm_test.go b/integrationtests/self/mitm_test.go index 2025c2a3..8f941564 100644 --- a/integrationtests/self/mitm_test.go +++ b/integrationtests/self/mitm_test.go @@ -418,18 +418,15 @@ var _ = Describe("MITM test", func() { // client connection closes immediately on receiving ack for unsent packet It("fails when a forged initial packet with ack for unsent packet is sent to client", func() { + clientAddr := clientConn.LocalAddr() delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { if dir == quicproxy.DirectionIncoming { - defer GinkgoRecover() - hdr, _, _, err := wire.ParsePacket(raw, connIDLen) Expect(err).ToNot(HaveOccurred()) - if hdr.Type != protocol.PacketTypeInitial { return 0 } - - sendForgedInitialPacketWithAck(serverConn, clientConn.LocalAddr(), hdr) + sendForgedInitialPacketWithAck(serverConn, clientAddr, hdr) } return rtt } @@ -439,7 +436,6 @@ var _ = Describe("MITM test", func() { Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.ProtocolViolation)) Expect(err.Error()).To(ContainSubstring("Received ACK for an unsent packet")) }) - }) }) } diff --git a/integrationtests/tools/proxy/proxy.go b/integrationtests/tools/proxy/proxy.go index 1c1d224a..1e5c90ec 100644 --- a/integrationtests/tools/proxy/proxy.go +++ b/integrationtests/tools/proxy/proxy.go @@ -2,6 +2,7 @@ package quicproxy import ( "net" + "sort" "sync" "time" @@ -13,6 +14,15 @@ import ( type connection struct { ClientAddr *net.UDPAddr // Address of the client ServerConn *net.UDPConn // UDP connection to server + + incomingPackets chan packetEntry + + Incoming *queue + Outgoing *queue +} + +func (c *connection) queuePacket(t time.Time, b []byte) { + c.incomingPackets <- packetEntry{Time: t, Raw: b} } // Direction is the direction a packet is sent. @@ -27,12 +37,63 @@ const ( DirectionBoth ) +type packetEntry struct { + Time time.Time + Raw []byte +} + +type packetEntries []packetEntry + +func (e packetEntries) Len() int { return len(e) } +func (e packetEntries) Less(i, j int) bool { return e[i].Time.Before(e[j].Time) } +func (e packetEntries) Swap(i, j int) { e[i], e[j] = e[j], e[i] } + +type queue struct { + sync.Mutex + + timer *utils.Timer + Packets packetEntries +} + +func newQueue() *queue { + return &queue{timer: utils.NewTimer()} +} + +func (q *queue) Add(e packetEntry) { + q.Lock() + q.Packets = append(q.Packets, e) + if len(q.Packets) > 1 { + lastIndex := len(q.Packets) - 1 + if q.Packets[lastIndex].Time.Before(q.Packets[lastIndex-1].Time) { + sort.Stable(q.Packets) + } + } + q.timer.Reset(q.Packets[0].Time) + q.Unlock() +} + +func (q *queue) Get() []byte { + q.Lock() + raw := q.Packets[0].Raw + q.Packets = q.Packets[1:] + if len(q.Packets) > 0 { + q.timer.Reset(q.Packets[0].Time) + } + q.Unlock() + return raw +} + +func (q *queue) Timer() <-chan time.Time { return q.timer.Chan() } +func (q *queue) SetTimerRead() { q.timer.SetRead() } + +func (q *queue) Close() { q.timer.Stop() } + func (d Direction) String() string { switch d { case DirectionIncoming: - return "incoming" + return "Incoming" case DirectionOutgoing: - return "outgoing" + return "Outgoing" case DirectionBoth: return "both" default: @@ -81,15 +142,14 @@ type Opts struct { type QuicProxy struct { mutex sync.Mutex + closeChan chan struct{} + conn *net.UDPConn serverAddr *net.UDPAddr dropPacket DropCallback delayPacket DelayCallback - timerID uint64 - timers map[uint64]*time.Timer - // Mapping from client addresses (as host:port) to connection clientDict map[string]*connection @@ -127,10 +187,10 @@ func NewQuicProxy(local string, opts *Opts) (*QuicProxy, error) { p := QuicProxy{ clientDict: make(map[string]*connection), conn: conn, + closeChan: make(chan struct{}), serverAddr: raddr, dropPacket: packetDropper, delayPacket: packetDelayer, - timers: make(map[uint64]*time.Timer), logger: utils.DefaultLogger.WithPrefix("proxy"), } @@ -143,13 +203,13 @@ func NewQuicProxy(local string, opts *Opts) (*QuicProxy, error) { func (p *QuicProxy) Close() error { p.mutex.Lock() defer p.mutex.Unlock() + close(p.closeChan) for _, c := range p.clientDict { if err := c.ServerConn.Close(); err != nil { return err } - } - for _, t := range p.timers { - t.Stop() + c.Incoming.Close() + c.Outgoing.Close() } return p.conn.Close() } @@ -170,8 +230,11 @@ func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) { return nil, err } return &connection{ - ClientAddr: cliAddr, - ServerConn: srvudp, + ClientAddr: cliAddr, + ServerConn: srvudp, + incomingPackets: make(chan packetEntry, 10), + Incoming: newQueue(), + Outgoing: newQueue(), }, nil } @@ -196,7 +259,8 @@ func (p *QuicProxy) runProxy() error { return err } p.clientDict[saddr] = conn - go p.runConnection(conn) + go p.runIncomingConnection(conn) + go p.runOutgoingConnection(conn) } p.mutex.Unlock() @@ -207,75 +271,87 @@ func (p *QuicProxy) runProxy() error { continue } - // Send the packet to the server delay := p.delayPacket(DirectionIncoming, raw) - if delay != 0 { + if delay == 0 { if p.logger.Debug() { - p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", n, conn.ServerConn.RemoteAddr(), delay) - } - - p.mutex.Lock() - p.timerID++ - id := p.timerID - timer := time.AfterFunc(delay, func() { - _, _ = conn.ServerConn.Write(raw) // TODO: handle error - - p.mutex.Lock() - delete(p.timers, id) - p.mutex.Unlock() - }) - p.timers[id] = timer - p.mutex.Unlock() - } else { - if p.logger.Debug() { - p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", n, conn.ServerConn.RemoteAddr()) + p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", len(raw), conn.ServerConn.RemoteAddr()) } if _, err := conn.ServerConn.Write(raw); err != nil { return err } + } else { + now := time.Now() + if p.logger.Debug() { + p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", len(raw), conn.ServerConn.RemoteAddr(), delay) + } + conn.queuePacket(now.Add(delay), raw) } } } // runConnection handles packets from server to a single client -func (p *QuicProxy) runConnection(conn *connection) error { +func (p *QuicProxy) runOutgoingConnection(conn *connection) error { + outgoingPackets := make(chan packetEntry, 10) + go func() { + for { + buffer := make([]byte, protocol.MaxReceivePacketSize) + n, err := conn.ServerConn.Read(buffer) + if err != nil { + return + } + raw := buffer[0:n] + + if p.dropPacket(DirectionOutgoing, raw) { + if p.logger.Debug() { + p.logger.Debugf("dropping outgoing packet(%d bytes)", n) + } + continue + } + + delay := p.delayPacket(DirectionOutgoing, raw) + if delay == 0 { + if p.logger.Debug() { + p.logger.Debugf("forwarding outgoing packet (%d bytes) to %s", len(raw), conn.ClientAddr) + } + if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil { + return + } + } else { + now := time.Now() + if p.logger.Debug() { + p.logger.Debugf("delaying outgoing packet (%d bytes) to %s by %s", len(raw), conn.ClientAddr, delay) + } + outgoingPackets <- packetEntry{Time: now.Add(delay), Raw: raw} + } + } + }() + for { - buffer := make([]byte, protocol.MaxReceivePacketSize) - n, err := conn.ServerConn.Read(buffer) - if err != nil { - return err - } - raw := buffer[0:n] - - if p.dropPacket(DirectionOutgoing, raw) { - if p.logger.Debug() { - p.logger.Debugf("dropping outgoing packet(%d bytes)", n) - } - continue - } - - delay := p.delayPacket(DirectionOutgoing, raw) - if delay != 0 { - if p.logger.Debug() { - p.logger.Debugf("delaying outgoing packet (%d bytes) to %s by %s", n, conn.ClientAddr, delay) - } - - p.mutex.Lock() - p.timerID++ - id := p.timerID - timer := time.AfterFunc(delay, func() { - _, _ = p.conn.WriteToUDP(raw, conn.ClientAddr) // TODO: handle error - p.mutex.Lock() - delete(p.timers, id) - p.mutex.Unlock() - }) - p.timers[id] = timer - p.mutex.Unlock() - } else { - if p.logger.Debug() { - p.logger.Debugf("forwarding outgoing packet (%d bytes) to %s", n, conn.ClientAddr) - } - if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil { + select { + case <-p.closeChan: + return nil + case e := <-outgoingPackets: + conn.Outgoing.Add(e) + case <-conn.Outgoing.Timer(): + conn.Outgoing.SetTimerRead() + if _, err := p.conn.WriteTo(conn.Outgoing.Get(), conn.ClientAddr); err != nil { + return err + } + } + } +} + +func (p *QuicProxy) runIncomingConnection(conn *connection) error { + for { + select { + case <-p.closeChan: + return nil + case e := <-conn.incomingPackets: + // Send the packet to the server + conn.Incoming.Add(e) + case <-conn.Incoming.Timer(): + conn.Incoming.SetTimerRead() + if _, err := conn.ServerConn.Write(conn.Incoming.Get()); err != nil { return err } } diff --git a/integrationtests/tools/proxy/proxy_test.go b/integrationtests/tools/proxy/proxy_test.go index 1b1403c6..ca6f2a08 100644 --- a/integrationtests/tools/proxy/proxy_test.go +++ b/integrationtests/tools/proxy/proxy_test.go @@ -19,11 +19,22 @@ import ( type packetData []byte +func isProxyRunning() bool { + var b bytes.Buffer + pprof.Lookup("goroutine").WriteTo(&b, 1) + return strings.Contains(b.String(), "proxy.(*QuicProxy).runIncomingConnection") || + strings.Contains(b.String(), "proxy.(*QuicProxy).runOutgoingConnection") +} + var _ = Describe("QUIC Proxy", func() { makePacket := func(p protocol.PacketNumber, payload []byte) []byte { b := &bytes.Buffer{} hdr := wire.ExtendedHeader{ Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + Version: protocol.VersionTLS, + Length: 4 + protocol.ByteCount(len(payload)), DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37}, SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37}, }, @@ -36,6 +47,19 @@ var _ = Describe("QUIC Proxy", func() { return raw } + readPacketNumber := func(b []byte) protocol.PacketNumber { + hdr, data, _, err := wire.ParsePacket(b, 0) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial)) + extHdr, err := hdr.ParseExtended(bytes.NewReader(data), protocol.VersionTLS) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + return extHdr.PacketNumber + } + + AfterEach(func() { + Eventually(isProxyRunning).Should(BeFalse()) + }) + Context("Proxy setup and teardown", func() { It("sets up the UDPProxy", func() { proxy, err := NewQuicProxy("localhost:0", nil) @@ -80,12 +104,6 @@ var _ = Describe("QUIC Proxy", func() { }) It("stops listening for proxied connections", func() { - isConnRunning := func() bool { - var b bytes.Buffer - pprof.Lookup("goroutine").WriteTo(&b, 1) - return strings.Contains(b.String(), "proxy.(*QuicProxy).runConnection") - } - serverAddr, err := net.ResolveUDPAddr("udp", "localhost:0") Expect(err).ToNot(HaveOccurred()) serverConn, err := net.ListenUDP("udp", serverAddr) @@ -94,16 +112,16 @@ var _ = Describe("QUIC Proxy", func() { proxy, err := NewQuicProxy("localhost:0", &Opts{RemoteAddr: serverConn.LocalAddr().String()}) Expect(err).ToNot(HaveOccurred()) - Expect(isConnRunning()).To(BeFalse()) + Expect(isProxyRunning()).To(BeFalse()) // check that the proxy port is not in use anymore conn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr)) Expect(err).ToNot(HaveOccurred()) _, err = conn.Write(makePacket(1, []byte("foobar"))) Expect(err).ToNot(HaveOccurred()) - Eventually(isConnRunning).Should(BeTrue()) + Eventually(isProxyRunning).Should(BeTrue()) Expect(proxy.Close()).To(Succeed()) - Eventually(isConnRunning).Should(BeFalse()) + Eventually(isProxyRunning).Should(BeFalse()) }) It("has the correct LocalAddr and LocalPort", func() { @@ -284,16 +302,16 @@ var _ = Describe("QUIC Proxy", func() { }) Context("Delay Callback", func() { - expectDelay := func(startTime time.Time, rtt time.Duration, numRTTs int) { - expectedReceiveTime := startTime.Add(time.Duration(numRTTs) * rtt) + const delay = 200 * time.Millisecond + expectDelay := func(startTime time.Time, numRTTs int) { + expectedReceiveTime := startTime.Add(time.Duration(numRTTs) * delay) Expect(time.Now()).To(SatisfyAll( BeTemporally(">=", expectedReceiveTime), - BeTemporally("<", expectedReceiveTime.Add(rtt/2)), + BeTemporally("<", expectedReceiveTime.Add(delay/2)), )) } It("delays incoming packets", func() { - delay := 300 * time.Millisecond var counter int32 opts := &Opts{ RemoteAddr: serverConn.LocalAddr().String(), @@ -317,16 +335,75 @@ var _ = Describe("QUIC Proxy", func() { Expect(err).ToNot(HaveOccurred()) } Eventually(serverReceivedPackets).Should(HaveLen(1)) - expectDelay(start, delay, 1) + expectDelay(start, 1) Eventually(serverReceivedPackets).Should(HaveLen(2)) - expectDelay(start, delay, 2) + expectDelay(start, 2) Eventually(serverReceivedPackets).Should(HaveLen(3)) - expectDelay(start, delay, 3) + expectDelay(start, 3) + Expect(readPacketNumber(<-serverReceivedPackets)).To(Equal(protocol.PacketNumber(1))) + Expect(readPacketNumber(<-serverReceivedPackets)).To(Equal(protocol.PacketNumber(2))) + Expect(readPacketNumber(<-serverReceivedPackets)).To(Equal(protocol.PacketNumber(3))) + }) + + It("handles reordered packets", func() { + var counter int32 + opts := &Opts{ + RemoteAddr: serverConn.LocalAddr().String(), + // delay packet 1 by 600 ms + // delay packet 2 by 400 ms + // delay packet 3 by 200 ms + DelayPacket: func(d Direction, _ []byte) time.Duration { + if d == DirectionOutgoing { + return 0 + } + p := atomic.AddInt32(&counter, 1) + return 600*time.Millisecond - time.Duration(p-1)*delay + }, + } + startProxy(opts) + + // send 3 packets + start := time.Now() + for i := 1; i <= 3; i++ { + _, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i)))) + Expect(err).ToNot(HaveOccurred()) + } + Eventually(serverReceivedPackets).Should(HaveLen(1)) + expectDelay(start, 1) + Eventually(serverReceivedPackets).Should(HaveLen(2)) + expectDelay(start, 2) + Eventually(serverReceivedPackets).Should(HaveLen(3)) + expectDelay(start, 3) + Expect(readPacketNumber(<-serverReceivedPackets)).To(Equal(protocol.PacketNumber(3))) + Expect(readPacketNumber(<-serverReceivedPackets)).To(Equal(protocol.PacketNumber(2))) + Expect(readPacketNumber(<-serverReceivedPackets)).To(Equal(protocol.PacketNumber(1))) + }) + + It("doesn't reorder packets when a constant delay is used", func() { + opts := &Opts{ + RemoteAddr: serverConn.LocalAddr().String(), + DelayPacket: func(d Direction, _ []byte) time.Duration { + if d == DirectionOutgoing { + return 0 + } + return 100 * time.Millisecond + }, + } + startProxy(opts) + + // send 100 packets + for i := 0; i < 100; i++ { + _, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i)))) + Expect(err).ToNot(HaveOccurred()) + } + Eventually(serverReceivedPackets).Should(HaveLen(100)) + for i := 0; i < 100; i++ { + Expect(readPacketNumber(<-serverReceivedPackets)).To(Equal(protocol.PacketNumber(i))) + } }) It("delays outgoing packets", func() { const numPackets = 3 - delay := 300 * time.Millisecond var counter int32 opts := &Opts{ RemoteAddr: serverConn.LocalAddr().String(), @@ -365,13 +442,16 @@ var _ = Describe("QUIC Proxy", func() { } // the packets should have arrived immediately at the server Eventually(serverReceivedPackets).Should(HaveLen(3)) - expectDelay(start, delay, 0) + expectDelay(start, 0) Eventually(clientReceivedPackets).Should(HaveLen(1)) - expectDelay(start, delay, 1) + expectDelay(start, 1) Eventually(clientReceivedPackets).Should(HaveLen(2)) - expectDelay(start, delay, 2) + expectDelay(start, 2) Eventually(clientReceivedPackets).Should(HaveLen(3)) - expectDelay(start, delay, 3) + expectDelay(start, 3) + Expect(readPacketNumber(<-clientReceivedPackets)).To(Equal(protocol.PacketNumber(1))) + Expect(readPacketNumber(<-clientReceivedPackets)).To(Equal(protocol.PacketNumber(2))) + Expect(readPacketNumber(<-clientReceivedPackets)).To(Equal(protocol.PacketNumber(3))) }) }) })