stop delay timers in the proxy when it is closed

This commit is contained in:
Marten Seemann 2019-07-01 15:41:13 +07:00
parent 7827cd61bc
commit 5479837a01

View file

@ -87,6 +87,9 @@ type QuicProxy struct {
dropPacket DropCallback
delayPacket DelayCallback
timerID uint64
timers map[uint64]*time.Timer
// Mapping from client addresses (as host:port) to connection
clientDict map[string]*connection
@ -127,6 +130,7 @@ func NewQuicProxy(local string, opts *Opts) (*QuicProxy, error) {
serverAddr: raddr,
dropPacket: packetDropper,
delayPacket: packetDelayer,
timers: make(map[uint64]*time.Timer),
logger: utils.DefaultLogger.WithPrefix("proxy"),
}
@ -144,6 +148,9 @@ func (p *QuicProxy) Close() error {
return err
}
}
for _, t := range p.timers {
t.Stop()
}
return p.conn.Close()
}
@ -206,10 +213,19 @@ func (p *QuicProxy) runProxy() error {
if p.logger.Debug() {
p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", n, conn.ServerConn.RemoteAddr(), delay)
}
time.AfterFunc(delay, func() {
// TODO: handle error
_, _ = conn.ServerConn.Write(raw)
p.mutex.Lock()
p.timerID++
id := p.timerID
timer := time.AfterFunc(delay, func() {
_, _ = conn.ServerConn.Write(raw) // TODO: handle error
p.mutex.Lock()
delete(p.timers, id)
p.mutex.Unlock()
})
p.timers[id] = timer
p.mutex.Unlock()
} else {
if p.logger.Debug() {
p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", n, conn.ServerConn.RemoteAddr())
@ -243,10 +259,18 @@ func (p *QuicProxy) runConnection(conn *connection) error {
if p.logger.Debug() {
p.logger.Debugf("delaying outgoing packet (%d bytes) to %s by %s", n, conn.ClientAddr, delay)
}
time.AfterFunc(delay, func() {
// TODO: handle error
_, _ = p.conn.WriteToUDP(raw, conn.ClientAddr)
p.mutex.Lock()
p.timerID++
id := p.timerID
timer := time.AfterFunc(delay, func() {
_, _ = p.conn.WriteToUDP(raw, conn.ClientAddr) // TODO: handle error
p.mutex.Lock()
delete(p.timers, id)
p.mutex.Unlock()
})
p.timers[id] = timer
p.mutex.Unlock()
} else {
if p.logger.Debug() {
p.logger.Debugf("forwarding outgoing packet (%d bytes) to %s", n, conn.ClientAddr)