diff --git a/common/canceler/instance.go b/common/canceler/instance.go new file mode 100644 index 0000000..2f80be4 --- /dev/null +++ b/common/canceler/instance.go @@ -0,0 +1,57 @@ +package canceler + +import ( + "context" + "time" +) + +type Instance struct { + ctx context.Context + cancelFunc context.CancelFunc + timer *time.Timer + timeout time.Duration +} + +func New(ctx context.Context, cancelFunc context.CancelFunc, timeout time.Duration) *Instance { + instance := &Instance{ + ctx, + cancelFunc, + time.NewTimer(timeout), + timeout, + } + go instance.wait() + return instance +} + +func (i *Instance) Update() bool { + if !i.timer.Stop() { + return false + } + if !i.timer.Reset(i.timeout) { + return false + } + return true +} + +func (i *Instance) Timeout() time.Duration { + return i.timeout +} + +func (i *Instance) SetTimeout(timeout time.Duration) { + i.timeout = timeout + i.Update() +} + +func (i *Instance) wait() { + select { + case <-i.timer.C: + case <-i.ctx.Done(): + } + i.Close() +} + +func (i *Instance) Close() error { + i.timer.Stop() + i.cancelFunc() + return nil +} diff --git a/common/canceler/packet.go b/common/canceler/packet.go new file mode 100644 index 0000000..7833fdc --- /dev/null +++ b/common/canceler/packet.go @@ -0,0 +1,74 @@ +package canceler + +import ( + "context" + "time" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +type PacketConn interface { + N.PacketConn + Timeout() time.Duration + SetTimeout(timeout time.Duration) +} + +type TimerPacketConn struct { + N.PacketConn + instance *Instance +} + +func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, PacketConn) { + if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn { + oldTimeout := timeoutConn.Timeout() + if timeout < oldTimeout { + timeoutConn.SetTimeout(timeout) + } + return ctx, timeoutConn + } + err := conn.SetReadDeadline(time.Time{}) + if err == nil { + return NewTimeoutPacketConn(ctx, conn, timeout) + } + ctx, cancel := context.WithCancel(ctx) + instance := New(ctx, cancel, timeout) + return ctx, &TimerPacketConn{conn, instance} +} + +func (c *TimerPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + destination, err = c.PacketConn.ReadPacket(buffer) + if err == nil { + c.instance.Update() + } + return +} + +func (c *TimerPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + err := c.PacketConn.WritePacket(buffer, destination) + if err == nil { + c.instance.Update() + } + return err +} + +func (c *TimerPacketConn) Timeout() time.Duration { + return c.instance.Timeout() +} + +func (c *TimerPacketConn) SetTimeout(timeout time.Duration) { + c.instance.SetTimeout(timeout) +} + +func (c *TimerPacketConn) Close() error { + return common.Close( + c.PacketConn, + c.instance, + ) +} + +func (c *TimerPacketConn) Upstream() any { + return c.PacketConn +} diff --git a/common/canceler/packet_timeout.go b/common/canceler/packet_timeout.go new file mode 100644 index 0000000..83a12bb --- /dev/null +++ b/common/canceler/packet_timeout.go @@ -0,0 +1,73 @@ +package canceler + +import ( + "context" + "time" + + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +type TimeoutPacketConn struct { + N.PacketConn + timeout time.Duration + cancel context.CancelFunc + active time.Time +} + +func NewTimeoutPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, PacketConn) { + ctx, cancel := context.WithCancel(ctx) + return ctx, &TimeoutPacketConn{ + PacketConn: conn, + timeout: timeout, + cancel: cancel, + } +} + +func (c *TimeoutPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + for { + err = c.PacketConn.SetReadDeadline(time.Now().Add(c.timeout)) + if err != nil { + return M.Socksaddr{}, err + } + destination, err = c.PacketConn.ReadPacket(buffer) + if err == nil { + c.active = time.Now() + return + } else if E.IsTimeout(err) { + if time.Since(c.active) > c.timeout { + c.cancel() + return + } + } else { + return M.Socksaddr{}, err + } + } +} + +func (c *TimeoutPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + err := c.PacketConn.WritePacket(buffer, destination) + if err == nil { + c.active = time.Now() + } + return err +} + +func (c *TimeoutPacketConn) Timeout() time.Duration { + return c.timeout +} + +func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) { + c.timeout = timeout + c.PacketConn.SetReadDeadline(time.Now()) +} + +func (c *TimeoutPacketConn) Close() error { + return c.PacketConn.Close() +} + +func (c *TimeoutPacketConn) Upstream() any { + return c.PacketConn +}