udpnat2: Implement set timeout for nat conn

This commit is contained in:
世界 2024-11-26 11:30:38 +08:00
parent 3613ead480
commit a8285e06a5
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
5 changed files with 47 additions and 29 deletions

View file

@ -41,9 +41,9 @@ func (i *Instance) Timeout() time.Duration {
return i.timeout
}
func (i *Instance) SetTimeout(timeout time.Duration) {
func (i *Instance) SetTimeout(timeout time.Duration) bool {
i.timeout = timeout
i.Update()
return i.Update()
}
func (i *Instance) wait() {

View file

@ -13,7 +13,7 @@ import (
type PacketConn interface {
N.PacketConn
Timeout() time.Duration
SetTimeout(timeout time.Duration)
SetTimeout(timeout time.Duration) bool
}
type TimerPacketConn struct {
@ -24,10 +24,12 @@ type TimerPacketConn struct {
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) {
if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn {
oldTimeout := timeoutConn.Timeout()
if timeout < oldTimeout {
timeoutConn.SetTimeout(timeout)
if timeout >= oldTimeout {
return ctx, conn
}
if timeoutConn.SetTimeout(timeout) {
return ctx, conn
}
return ctx, conn
}
err := conn.SetReadDeadline(time.Time{})
if err == nil {
@ -58,8 +60,8 @@ func (c *TimerPacketConn) Timeout() time.Duration {
return c.instance.Timeout()
}
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) {
c.instance.SetTimeout(timeout)
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) bool {
return c.instance.SetTimeout(timeout)
}
func (c *TimerPacketConn) Close() error {

View file

@ -61,9 +61,9 @@ func (c *TimeoutPacketConn) Timeout() time.Duration {
return c.timeout
}
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) {
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) bool {
c.timeout = timeout
c.PacketConn.SetReadDeadline(time.Now())
return c.PacketConn.SetReadDeadline(time.Now()) == nil
}
func (c *TimeoutPacketConn) Close() error {

View file

@ -4,9 +4,11 @@ import (
"io"
"net"
"os"
"sync"
"time"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/canceler"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/pipe"
@ -15,15 +17,18 @@ import (
type Conn interface {
N.PacketConn
SetHandler(handler N.UDPHandlerEx)
canceler.PacketConn
}
var _ Conn = (*natConn)(nil)
type natConn struct {
service *Service
writer N.PacketWriter
localAddr M.Socksaddr
handler N.UDPHandlerEx
packetChan chan *N.PacketBuffer
closeOnce sync.Once
doneChan chan struct{}
readDeadline pipe.Deadline
readWaitOptions N.ReadWaitOptions
@ -48,6 +53,25 @@ func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error
return c.writer.WritePacket(buffer, destination)
}
func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}
func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
select {
case packet := <-c.packetChan:
buffer = c.readWaitOptions.Copy(packet.Buffer)
destination = packet.Destination
N.PutPacketBuffer(packet)
return
case <-c.doneChan:
return nil, M.Socksaddr{}, io.ErrClosedPipe
case <-c.readDeadline.Wait():
return nil, M.Socksaddr{}, os.ErrDeadlineExceeded
}
}
func (c *natConn) SetHandler(handler N.UDPHandlerEx) {
select {
case <-c.doneChan:
@ -68,31 +92,22 @@ fetch:
}
}
func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
func (c *natConn) Timeout() time.Duration {
rawConn, lifetime, loaded := c.service.cache.PeekWithLifetime(c.localAddr.AddrPort())
if !loaded || rawConn != c {
return 0
}
return time.Until(lifetime)
}
func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
select {
case packet := <-c.packetChan:
buffer = c.readWaitOptions.Copy(packet.Buffer)
destination = packet.Destination
N.PutPacketBuffer(packet)
return
case <-c.doneChan:
return nil, M.Socksaddr{}, io.ErrClosedPipe
case <-c.readDeadline.Wait():
return nil, M.Socksaddr{}, os.ErrDeadlineExceeded
}
func (c *natConn) SetTimeout(timeout time.Duration) bool {
return c.service.cache.UpdateLifetime(c.localAddr.AddrPort(), c, timeout)
}
func (c *natConn) Close() error {
select {
case <-c.doneChan:
default:
c.closeOnce.Do(func() {
close(c.doneChan)
}
})
return nil
}

View file

@ -65,6 +65,7 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati
return
}
conn = &natConn{
service: s,
writer: writer,
localAddr: source,
packetChan: make(chan *N.PacketBuffer, 64),