diff --git a/common/canceler/instance.go b/common/canceler/instance.go index 05faa91..c47270d 100644 --- a/common/canceler/instance.go +++ b/common/canceler/instance.go @@ -41,9 +41,9 @@ func (i *Instance) Timeout() time.Duration { return i.timeout } -func (i *Instance) SetTimeout(timeout time.Duration) { +func (i *Instance) SetTimeout(timeout time.Duration) bool { i.timeout = timeout - i.Update() + return i.Update() } func (i *Instance) wait() { diff --git a/common/canceler/packet.go b/common/canceler/packet.go index fb4ad84..46cf9a0 100644 --- a/common/canceler/packet.go +++ b/common/canceler/packet.go @@ -13,7 +13,7 @@ import ( type PacketConn interface { N.PacketConn Timeout() time.Duration - SetTimeout(timeout time.Duration) + SetTimeout(timeout time.Duration) bool } type TimerPacketConn struct { @@ -24,10 +24,12 @@ type TimerPacketConn struct { func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) { if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn { oldTimeout := timeoutConn.Timeout() - if timeout < oldTimeout { - timeoutConn.SetTimeout(timeout) + if timeout >= oldTimeout { + return ctx, conn + } + if timeoutConn.SetTimeout(timeout) { + return ctx, conn } - return ctx, conn } err := conn.SetReadDeadline(time.Time{}) if err == nil { @@ -58,8 +60,8 @@ func (c *TimerPacketConn) Timeout() time.Duration { return c.instance.Timeout() } -func (c *TimerPacketConn) SetTimeout(timeout time.Duration) { - c.instance.SetTimeout(timeout) +func (c *TimerPacketConn) SetTimeout(timeout time.Duration) bool { + return c.instance.SetTimeout(timeout) } func (c *TimerPacketConn) Close() error { diff --git a/common/canceler/packet_timeout.go b/common/canceler/packet_timeout.go index ab5c760..a679567 100644 --- a/common/canceler/packet_timeout.go +++ b/common/canceler/packet_timeout.go @@ -61,9 +61,9 @@ func (c *TimeoutPacketConn) Timeout() time.Duration { return c.timeout } -func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) { +func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) bool { c.timeout = timeout - c.PacketConn.SetReadDeadline(time.Now()) + return c.PacketConn.SetReadDeadline(time.Now()) == nil } func (c *TimeoutPacketConn) Close() error { diff --git a/common/udpnat2/conn.go b/common/udpnat2/conn.go index 9d5bfa9..8ae4557 100644 --- a/common/udpnat2/conn.go +++ b/common/udpnat2/conn.go @@ -4,9 +4,11 @@ import ( "io" "net" "os" + "sync" "time" "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/canceler" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/pipe" @@ -15,15 +17,18 @@ import ( type Conn interface { N.PacketConn SetHandler(handler N.UDPHandlerEx) + canceler.PacketConn } var _ Conn = (*natConn)(nil) type natConn struct { + service *Service writer N.PacketWriter localAddr M.Socksaddr handler N.UDPHandlerEx packetChan chan *N.PacketBuffer + closeOnce sync.Once doneChan chan struct{} readDeadline pipe.Deadline readWaitOptions N.ReadWaitOptions @@ -48,6 +53,25 @@ func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error return c.writer.WritePacket(buffer, destination) } +func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + c.readWaitOptions = options + return false +} + +func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + select { + case packet := <-c.packetChan: + buffer = c.readWaitOptions.Copy(packet.Buffer) + destination = packet.Destination + N.PutPacketBuffer(packet) + return + case <-c.doneChan: + return nil, M.Socksaddr{}, io.ErrClosedPipe + case <-c.readDeadline.Wait(): + return nil, M.Socksaddr{}, os.ErrDeadlineExceeded + } +} + func (c *natConn) SetHandler(handler N.UDPHandlerEx) { select { case <-c.doneChan: @@ -68,31 +92,22 @@ fetch: } } -func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { - c.readWaitOptions = options - return false +func (c *natConn) Timeout() time.Duration { + rawConn, lifetime, loaded := c.service.cache.PeekWithLifetime(c.localAddr.AddrPort()) + if !loaded || rawConn != c { + return 0 + } + return time.Until(lifetime) } -func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { - select { - case packet := <-c.packetChan: - buffer = c.readWaitOptions.Copy(packet.Buffer) - destination = packet.Destination - N.PutPacketBuffer(packet) - return - case <-c.doneChan: - return nil, M.Socksaddr{}, io.ErrClosedPipe - case <-c.readDeadline.Wait(): - return nil, M.Socksaddr{}, os.ErrDeadlineExceeded - } +func (c *natConn) SetTimeout(timeout time.Duration) bool { + return c.service.cache.UpdateLifetime(c.localAddr.AddrPort(), c, timeout) } func (c *natConn) Close() error { - select { - case <-c.doneChan: - default: + c.closeOnce.Do(func() { close(c.doneChan) - } + }) return nil } diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index ac5da1d..f5485af 100644 --- a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -65,6 +65,7 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati return } conn = &natConn{ + service: s, writer: writer, localAddr: source, packetChan: make(chan *N.PacketBuffer, 64),