rewrite the proxy to avoid packet reordering

This commit is contained in:
Marten Seemann 2020-06-19 22:57:07 +07:00
parent c956ca4447
commit 0baf16ea4e
3 changed files with 248 additions and 96 deletions

View file

@ -2,6 +2,7 @@ package quicproxy
import (
"net"
"sort"
"sync"
"time"
@ -13,6 +14,15 @@ import (
type connection struct {
ClientAddr *net.UDPAddr // Address of the client
ServerConn *net.UDPConn // UDP connection to server
incomingPackets chan packetEntry
Incoming *queue
Outgoing *queue
}
func (c *connection) queuePacket(t time.Time, b []byte) {
c.incomingPackets <- packetEntry{Time: t, Raw: b}
}
// Direction is the direction a packet is sent.
@ -27,12 +37,63 @@ const (
DirectionBoth
)
type packetEntry struct {
Time time.Time
Raw []byte
}
type packetEntries []packetEntry
func (e packetEntries) Len() int { return len(e) }
func (e packetEntries) Less(i, j int) bool { return e[i].Time.Before(e[j].Time) }
func (e packetEntries) Swap(i, j int) { e[i], e[j] = e[j], e[i] }
type queue struct {
sync.Mutex
timer *utils.Timer
Packets packetEntries
}
func newQueue() *queue {
return &queue{timer: utils.NewTimer()}
}
func (q *queue) Add(e packetEntry) {
q.Lock()
q.Packets = append(q.Packets, e)
if len(q.Packets) > 1 {
lastIndex := len(q.Packets) - 1
if q.Packets[lastIndex].Time.Before(q.Packets[lastIndex-1].Time) {
sort.Stable(q.Packets)
}
}
q.timer.Reset(q.Packets[0].Time)
q.Unlock()
}
func (q *queue) Get() []byte {
q.Lock()
raw := q.Packets[0].Raw
q.Packets = q.Packets[1:]
if len(q.Packets) > 0 {
q.timer.Reset(q.Packets[0].Time)
}
q.Unlock()
return raw
}
func (q *queue) Timer() <-chan time.Time { return q.timer.Chan() }
func (q *queue) SetTimerRead() { q.timer.SetRead() }
func (q *queue) Close() { q.timer.Stop() }
func (d Direction) String() string {
switch d {
case DirectionIncoming:
return "incoming"
return "Incoming"
case DirectionOutgoing:
return "outgoing"
return "Outgoing"
case DirectionBoth:
return "both"
default:
@ -81,15 +142,14 @@ type Opts struct {
type QuicProxy struct {
mutex sync.Mutex
closeChan chan struct{}
conn *net.UDPConn
serverAddr *net.UDPAddr
dropPacket DropCallback
delayPacket DelayCallback
timerID uint64
timers map[uint64]*time.Timer
// Mapping from client addresses (as host:port) to connection
clientDict map[string]*connection
@ -127,10 +187,10 @@ func NewQuicProxy(local string, opts *Opts) (*QuicProxy, error) {
p := QuicProxy{
clientDict: make(map[string]*connection),
conn: conn,
closeChan: make(chan struct{}),
serverAddr: raddr,
dropPacket: packetDropper,
delayPacket: packetDelayer,
timers: make(map[uint64]*time.Timer),
logger: utils.DefaultLogger.WithPrefix("proxy"),
}
@ -143,13 +203,13 @@ func NewQuicProxy(local string, opts *Opts) (*QuicProxy, error) {
func (p *QuicProxy) Close() error {
p.mutex.Lock()
defer p.mutex.Unlock()
close(p.closeChan)
for _, c := range p.clientDict {
if err := c.ServerConn.Close(); err != nil {
return err
}
}
for _, t := range p.timers {
t.Stop()
c.Incoming.Close()
c.Outgoing.Close()
}
return p.conn.Close()
}
@ -170,8 +230,11 @@ func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
return nil, err
}
return &connection{
ClientAddr: cliAddr,
ServerConn: srvudp,
ClientAddr: cliAddr,
ServerConn: srvudp,
incomingPackets: make(chan packetEntry, 10),
Incoming: newQueue(),
Outgoing: newQueue(),
}, nil
}
@ -196,7 +259,8 @@ func (p *QuicProxy) runProxy() error {
return err
}
p.clientDict[saddr] = conn
go p.runConnection(conn)
go p.runIncomingConnection(conn)
go p.runOutgoingConnection(conn)
}
p.mutex.Unlock()
@ -207,75 +271,87 @@ func (p *QuicProxy) runProxy() error {
continue
}
// Send the packet to the server
delay := p.delayPacket(DirectionIncoming, raw)
if delay != 0 {
if delay == 0 {
if p.logger.Debug() {
p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", n, conn.ServerConn.RemoteAddr(), delay)
}
p.mutex.Lock()
p.timerID++
id := p.timerID
timer := time.AfterFunc(delay, func() {
_, _ = conn.ServerConn.Write(raw) // TODO: handle error
p.mutex.Lock()
delete(p.timers, id)
p.mutex.Unlock()
})
p.timers[id] = timer
p.mutex.Unlock()
} else {
if p.logger.Debug() {
p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", n, conn.ServerConn.RemoteAddr())
p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", len(raw), conn.ServerConn.RemoteAddr())
}
if _, err := conn.ServerConn.Write(raw); err != nil {
return err
}
} else {
now := time.Now()
if p.logger.Debug() {
p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", len(raw), conn.ServerConn.RemoteAddr(), delay)
}
conn.queuePacket(now.Add(delay), raw)
}
}
}
// runConnection handles packets from server to a single client
func (p *QuicProxy) runConnection(conn *connection) error {
func (p *QuicProxy) runOutgoingConnection(conn *connection) error {
outgoingPackets := make(chan packetEntry, 10)
go func() {
for {
buffer := make([]byte, protocol.MaxReceivePacketSize)
n, err := conn.ServerConn.Read(buffer)
if err != nil {
return
}
raw := buffer[0:n]
if p.dropPacket(DirectionOutgoing, raw) {
if p.logger.Debug() {
p.logger.Debugf("dropping outgoing packet(%d bytes)", n)
}
continue
}
delay := p.delayPacket(DirectionOutgoing, raw)
if delay == 0 {
if p.logger.Debug() {
p.logger.Debugf("forwarding outgoing packet (%d bytes) to %s", len(raw), conn.ClientAddr)
}
if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil {
return
}
} else {
now := time.Now()
if p.logger.Debug() {
p.logger.Debugf("delaying outgoing packet (%d bytes) to %s by %s", len(raw), conn.ClientAddr, delay)
}
outgoingPackets <- packetEntry{Time: now.Add(delay), Raw: raw}
}
}
}()
for {
buffer := make([]byte, protocol.MaxReceivePacketSize)
n, err := conn.ServerConn.Read(buffer)
if err != nil {
return err
}
raw := buffer[0:n]
if p.dropPacket(DirectionOutgoing, raw) {
if p.logger.Debug() {
p.logger.Debugf("dropping outgoing packet(%d bytes)", n)
}
continue
}
delay := p.delayPacket(DirectionOutgoing, raw)
if delay != 0 {
if p.logger.Debug() {
p.logger.Debugf("delaying outgoing packet (%d bytes) to %s by %s", n, conn.ClientAddr, delay)
}
p.mutex.Lock()
p.timerID++
id := p.timerID
timer := time.AfterFunc(delay, func() {
_, _ = p.conn.WriteToUDP(raw, conn.ClientAddr) // TODO: handle error
p.mutex.Lock()
delete(p.timers, id)
p.mutex.Unlock()
})
p.timers[id] = timer
p.mutex.Unlock()
} else {
if p.logger.Debug() {
p.logger.Debugf("forwarding outgoing packet (%d bytes) to %s", n, conn.ClientAddr)
}
if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil {
select {
case <-p.closeChan:
return nil
case e := <-outgoingPackets:
conn.Outgoing.Add(e)
case <-conn.Outgoing.Timer():
conn.Outgoing.SetTimerRead()
if _, err := p.conn.WriteTo(conn.Outgoing.Get(), conn.ClientAddr); err != nil {
return err
}
}
}
}
func (p *QuicProxy) runIncomingConnection(conn *connection) error {
for {
select {
case <-p.closeChan:
return nil
case e := <-conn.incomingPackets:
// Send the packet to the server
conn.Incoming.Add(e)
case <-conn.Incoming.Timer():
conn.Incoming.SetTimerRead()
if _, err := conn.ServerConn.Write(conn.Incoming.Get()); err != nil {
return err
}
}

View file

@ -19,11 +19,22 @@ import (
type packetData []byte
func isProxyRunning() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "proxy.(*QuicProxy).runIncomingConnection") ||
strings.Contains(b.String(), "proxy.(*QuicProxy).runOutgoingConnection")
}
var _ = Describe("QUIC Proxy", func() {
makePacket := func(p protocol.PacketNumber, payload []byte) []byte {
b := &bytes.Buffer{}
hdr := wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
Version: protocol.VersionTLS,
Length: 4 + protocol.ByteCount(len(payload)),
DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37},
SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37},
},
@ -36,6 +47,19 @@ var _ = Describe("QUIC Proxy", func() {
return raw
}
readPacketNumber := func(b []byte) protocol.PacketNumber {
hdr, data, _, err := wire.ParsePacket(b, 0)
ExpectWithOffset(1, err).ToNot(HaveOccurred())
Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial))
extHdr, err := hdr.ParseExtended(bytes.NewReader(data), protocol.VersionTLS)
ExpectWithOffset(1, err).ToNot(HaveOccurred())
return extHdr.PacketNumber
}
AfterEach(func() {
Eventually(isProxyRunning).Should(BeFalse())
})
Context("Proxy setup and teardown", func() {
It("sets up the UDPProxy", func() {
proxy, err := NewQuicProxy("localhost:0", nil)
@ -80,12 +104,6 @@ var _ = Describe("QUIC Proxy", func() {
})
It("stops listening for proxied connections", func() {
isConnRunning := func() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "proxy.(*QuicProxy).runConnection")
}
serverAddr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
serverConn, err := net.ListenUDP("udp", serverAddr)
@ -94,16 +112,16 @@ var _ = Describe("QUIC Proxy", func() {
proxy, err := NewQuicProxy("localhost:0", &Opts{RemoteAddr: serverConn.LocalAddr().String()})
Expect(err).ToNot(HaveOccurred())
Expect(isConnRunning()).To(BeFalse())
Expect(isProxyRunning()).To(BeFalse())
// check that the proxy port is not in use anymore
conn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
Expect(err).ToNot(HaveOccurred())
_, err = conn.Write(makePacket(1, []byte("foobar")))
Expect(err).ToNot(HaveOccurred())
Eventually(isConnRunning).Should(BeTrue())
Eventually(isProxyRunning).Should(BeTrue())
Expect(proxy.Close()).To(Succeed())
Eventually(isConnRunning).Should(BeFalse())
Eventually(isProxyRunning).Should(BeFalse())
})
It("has the correct LocalAddr and LocalPort", func() {
@ -284,16 +302,16 @@ var _ = Describe("QUIC Proxy", func() {
})
Context("Delay Callback", func() {
expectDelay := func(startTime time.Time, rtt time.Duration, numRTTs int) {
expectedReceiveTime := startTime.Add(time.Duration(numRTTs) * rtt)
const delay = 200 * time.Millisecond
expectDelay := func(startTime time.Time, numRTTs int) {
expectedReceiveTime := startTime.Add(time.Duration(numRTTs) * delay)
Expect(time.Now()).To(SatisfyAll(
BeTemporally(">=", expectedReceiveTime),
BeTemporally("<", expectedReceiveTime.Add(rtt/2)),
BeTemporally("<", expectedReceiveTime.Add(delay/2)),
))
}
It("delays incoming packets", func() {
delay := 300 * time.Millisecond
var counter int32
opts := &Opts{
RemoteAddr: serverConn.LocalAddr().String(),
@ -317,16 +335,75 @@ var _ = Describe("QUIC Proxy", func() {
Expect(err).ToNot(HaveOccurred())
}
Eventually(serverReceivedPackets).Should(HaveLen(1))
expectDelay(start, delay, 1)
expectDelay(start, 1)
Eventually(serverReceivedPackets).Should(HaveLen(2))
expectDelay(start, delay, 2)
expectDelay(start, 2)
Eventually(serverReceivedPackets).Should(HaveLen(3))
expectDelay(start, delay, 3)
expectDelay(start, 3)
Expect(readPacketNumber(<-serverReceivedPackets)).To(Equal(protocol.PacketNumber(1)))
Expect(readPacketNumber(<-serverReceivedPackets)).To(Equal(protocol.PacketNumber(2)))
Expect(readPacketNumber(<-serverReceivedPackets)).To(Equal(protocol.PacketNumber(3)))
})
It("handles reordered packets", func() {
var counter int32
opts := &Opts{
RemoteAddr: serverConn.LocalAddr().String(),
// delay packet 1 by 600 ms
// delay packet 2 by 400 ms
// delay packet 3 by 200 ms
DelayPacket: func(d Direction, _ []byte) time.Duration {
if d == DirectionOutgoing {
return 0
}
p := atomic.AddInt32(&counter, 1)
return 600*time.Millisecond - time.Duration(p-1)*delay
},
}
startProxy(opts)
// send 3 packets
start := time.Now()
for i := 1; i <= 3; i++ {
_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
Expect(err).ToNot(HaveOccurred())
}
Eventually(serverReceivedPackets).Should(HaveLen(1))
expectDelay(start, 1)
Eventually(serverReceivedPackets).Should(HaveLen(2))
expectDelay(start, 2)
Eventually(serverReceivedPackets).Should(HaveLen(3))
expectDelay(start, 3)
Expect(readPacketNumber(<-serverReceivedPackets)).To(Equal(protocol.PacketNumber(3)))
Expect(readPacketNumber(<-serverReceivedPackets)).To(Equal(protocol.PacketNumber(2)))
Expect(readPacketNumber(<-serverReceivedPackets)).To(Equal(protocol.PacketNumber(1)))
})
It("doesn't reorder packets when a constant delay is used", func() {
opts := &Opts{
RemoteAddr: serverConn.LocalAddr().String(),
DelayPacket: func(d Direction, _ []byte) time.Duration {
if d == DirectionOutgoing {
return 0
}
return 100 * time.Millisecond
},
}
startProxy(opts)
// send 100 packets
for i := 0; i < 100; i++ {
_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
Expect(err).ToNot(HaveOccurred())
}
Eventually(serverReceivedPackets).Should(HaveLen(100))
for i := 0; i < 100; i++ {
Expect(readPacketNumber(<-serverReceivedPackets)).To(Equal(protocol.PacketNumber(i)))
}
})
It("delays outgoing packets", func() {
const numPackets = 3
delay := 300 * time.Millisecond
var counter int32
opts := &Opts{
RemoteAddr: serverConn.LocalAddr().String(),
@ -365,13 +442,16 @@ var _ = Describe("QUIC Proxy", func() {
}
// the packets should have arrived immediately at the server
Eventually(serverReceivedPackets).Should(HaveLen(3))
expectDelay(start, delay, 0)
expectDelay(start, 0)
Eventually(clientReceivedPackets).Should(HaveLen(1))
expectDelay(start, delay, 1)
expectDelay(start, 1)
Eventually(clientReceivedPackets).Should(HaveLen(2))
expectDelay(start, delay, 2)
expectDelay(start, 2)
Eventually(clientReceivedPackets).Should(HaveLen(3))
expectDelay(start, delay, 3)
expectDelay(start, 3)
Expect(readPacketNumber(<-clientReceivedPackets)).To(Equal(protocol.PacketNumber(1)))
Expect(readPacketNumber(<-clientReceivedPackets)).To(Equal(protocol.PacketNumber(2)))
Expect(readPacketNumber(<-clientReceivedPackets)).To(Equal(protocol.PacketNumber(3)))
})
})
})