Compare commits

...

3 commits

Author SHA1 Message Date
世界
76d0c1d33a
Fix AppendClose 2024-11-26 12:01:46 +08:00
世界
3796d32db3
udpnat2: Add purge expire ticker 2024-11-26 11:30:51 +08:00
世界
717d9c3f2d
udpnat2: Implement set timeout for nat conn 2024-11-26 11:30:38 +08:00
6 changed files with 83 additions and 34 deletions

View file

@ -41,9 +41,9 @@ func (i *Instance) Timeout() time.Duration {
return i.timeout
}
func (i *Instance) SetTimeout(timeout time.Duration) {
func (i *Instance) SetTimeout(timeout time.Duration) bool {
i.timeout = timeout
i.Update()
return i.Update()
}
func (i *Instance) wait() {

View file

@ -13,7 +13,7 @@ import (
type PacketConn interface {
N.PacketConn
Timeout() time.Duration
SetTimeout(timeout time.Duration)
SetTimeout(timeout time.Duration) bool
}
type TimerPacketConn struct {
@ -24,10 +24,12 @@ type TimerPacketConn struct {
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) {
if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn {
oldTimeout := timeoutConn.Timeout()
if timeout < oldTimeout {
timeoutConn.SetTimeout(timeout)
if timeout >= oldTimeout {
return ctx, conn
}
if timeoutConn.SetTimeout(timeout) {
return ctx, conn
}
return ctx, conn
}
err := conn.SetReadDeadline(time.Time{})
if err == nil {
@ -58,8 +60,8 @@ func (c *TimerPacketConn) Timeout() time.Duration {
return c.instance.Timeout()
}
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) {
c.instance.SetTimeout(timeout)
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) bool {
return c.instance.SetTimeout(timeout)
}
func (c *TimerPacketConn) Close() error {

View file

@ -61,9 +61,9 @@ func (c *TimeoutPacketConn) Timeout() time.Duration {
return c.timeout
}
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) {
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) bool {
c.timeout = timeout
c.PacketConn.SetReadDeadline(time.Now())
return c.PacketConn.SetReadDeadline(time.Now()) == nil
}
func (c *TimeoutPacketConn) Close() error {

View file

@ -74,9 +74,10 @@ type ExtendedConn interface {
type CloseHandlerFunc = func(it error)
func AppendClose(parent CloseHandlerFunc, onClose CloseHandlerFunc) CloseHandlerFunc {
if onClose == nil {
panic("nil onClose")
}
if parent == nil {
return parent
} else if onClose == nil {
return onClose
}
return func(it error) {

View file

@ -4,9 +4,11 @@ import (
"io"
"net"
"os"
"sync"
"time"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/canceler"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/pipe"
@ -15,15 +17,18 @@ import (
type Conn interface {
N.PacketConn
SetHandler(handler N.UDPHandlerEx)
canceler.PacketConn
}
var _ Conn = (*natConn)(nil)
type natConn struct {
service *Service
writer N.PacketWriter
localAddr M.Socksaddr
handler N.UDPHandlerEx
packetChan chan *N.PacketBuffer
closeOnce sync.Once
doneChan chan struct{}
readDeadline pipe.Deadline
readWaitOptions N.ReadWaitOptions
@ -48,6 +53,25 @@ func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error
return c.writer.WritePacket(buffer, destination)
}
func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}
func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
select {
case packet := <-c.packetChan:
buffer = c.readWaitOptions.Copy(packet.Buffer)
destination = packet.Destination
N.PutPacketBuffer(packet)
return
case <-c.doneChan:
return nil, M.Socksaddr{}, io.ErrClosedPipe
case <-c.readDeadline.Wait():
return nil, M.Socksaddr{}, os.ErrDeadlineExceeded
}
}
func (c *natConn) SetHandler(handler N.UDPHandlerEx) {
select {
case <-c.doneChan:
@ -68,31 +92,22 @@ fetch:
}
}
func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
func (c *natConn) Timeout() time.Duration {
rawConn, lifetime, loaded := c.service.cache.PeekWithLifetime(c.localAddr.AddrPort())
if !loaded || rawConn != c {
return 0
}
return time.Until(lifetime)
}
func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
select {
case packet := <-c.packetChan:
buffer = c.readWaitOptions.Copy(packet.Buffer)
destination = packet.Destination
N.PutPacketBuffer(packet)
return
case <-c.doneChan:
return nil, M.Socksaddr{}, io.ErrClosedPipe
case <-c.readDeadline.Wait():
return nil, M.Socksaddr{}, os.ErrDeadlineExceeded
}
func (c *natConn) SetTimeout(timeout time.Duration) bool {
return c.service.cache.UpdateLifetime(c.localAddr.AddrPort(), c, timeout)
}
func (c *natConn) Close() error {
select {
case <-c.doneChan:
default:
c.closeOnce.Do(func() {
close(c.doneChan)
}
})
return nil
}

View file

@ -3,6 +3,7 @@ package udpnat
import (
"context"
"net/netip"
"sync"
"time"
"github.com/sagernet/sing/common"
@ -18,6 +19,10 @@ type Service struct {
handler N.UDPConnectionHandlerEx
prepare PrepareFunc
metrics Metrics
timeout time.Duration
closeOnce sync.Once
doneChan chan struct{}
}
type PrepareFunc func(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc)
@ -50,12 +55,38 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur
conn.Close()
})
return &Service{
cache: cache,
handler: handler,
prepare: prepare,
cache: cache,
handler: handler,
prepare: prepare,
timeout: timeout,
doneChan: make(chan struct{}),
}
}
func (s *Service) Start() error {
ticker := time.NewTicker(s.timeout)
go func() {
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.PurgeExpired()
case <-s.doneChan:
s.Purge()
return
}
}
}()
return nil
}
func (s *Service) Close() error {
s.closeOnce.Do(func() {
close(s.doneChan)
})
return nil
}
func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destination M.Socksaddr, userData any) {
conn, loaded := s.cache.Get(source.AddrPort())
if !loaded {