stop delay timers in the proxy when it is closed

This commit is contained in:
Marten Seemann 2019-07-01 15:41:13 +07:00
parent 7827cd61bc
commit 5479837a01

View file

@ -87,6 +87,9 @@ type QuicProxy struct {
dropPacket DropCallback dropPacket DropCallback
delayPacket DelayCallback delayPacket DelayCallback
timerID uint64
timers map[uint64]*time.Timer
// Mapping from client addresses (as host:port) to connection // Mapping from client addresses (as host:port) to connection
clientDict map[string]*connection clientDict map[string]*connection
@ -127,6 +130,7 @@ func NewQuicProxy(local string, opts *Opts) (*QuicProxy, error) {
serverAddr: raddr, serverAddr: raddr,
dropPacket: packetDropper, dropPacket: packetDropper,
delayPacket: packetDelayer, delayPacket: packetDelayer,
timers: make(map[uint64]*time.Timer),
logger: utils.DefaultLogger.WithPrefix("proxy"), logger: utils.DefaultLogger.WithPrefix("proxy"),
} }
@ -144,6 +148,9 @@ func (p *QuicProxy) Close() error {
return err return err
} }
} }
for _, t := range p.timers {
t.Stop()
}
return p.conn.Close() return p.conn.Close()
} }
@ -206,10 +213,19 @@ func (p *QuicProxy) runProxy() error {
if p.logger.Debug() { if p.logger.Debug() {
p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", n, conn.ServerConn.RemoteAddr(), delay) p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", n, conn.ServerConn.RemoteAddr(), delay)
} }
time.AfterFunc(delay, func() {
// TODO: handle error p.mutex.Lock()
_, _ = conn.ServerConn.Write(raw) p.timerID++
id := p.timerID
timer := time.AfterFunc(delay, func() {
_, _ = conn.ServerConn.Write(raw) // TODO: handle error
p.mutex.Lock()
delete(p.timers, id)
p.mutex.Unlock()
}) })
p.timers[id] = timer
p.mutex.Unlock()
} else { } else {
if p.logger.Debug() { if p.logger.Debug() {
p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", n, conn.ServerConn.RemoteAddr()) p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", n, conn.ServerConn.RemoteAddr())
@ -243,10 +259,18 @@ func (p *QuicProxy) runConnection(conn *connection) error {
if p.logger.Debug() { if p.logger.Debug() {
p.logger.Debugf("delaying outgoing packet (%d bytes) to %s by %s", n, conn.ClientAddr, delay) p.logger.Debugf("delaying outgoing packet (%d bytes) to %s by %s", n, conn.ClientAddr, delay)
} }
time.AfterFunc(delay, func() {
// TODO: handle error p.mutex.Lock()
_, _ = p.conn.WriteToUDP(raw, conn.ClientAddr) p.timerID++
id := p.timerID
timer := time.AfterFunc(delay, func() {
_, _ = p.conn.WriteToUDP(raw, conn.ClientAddr) // TODO: handle error
p.mutex.Lock()
delete(p.timers, id)
p.mutex.Unlock()
}) })
p.timers[id] = timer
p.mutex.Unlock()
} else { } else {
if p.logger.Debug() { if p.logger.Debug() {
p.logger.Debugf("forwarding outgoing packet (%d bytes) to %s", n, conn.ClientAddr) p.logger.Debugf("forwarding outgoing packet (%d bytes) to %s", n, conn.ClientAddr)