mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-02 03:17:37 +03:00
udpnat2: Implement set timeout for nat conn
This commit is contained in:
parent
3613ead480
commit
a8285e06a5
5 changed files with 47 additions and 29 deletions
|
@ -41,9 +41,9 @@ func (i *Instance) Timeout() time.Duration {
|
|||
return i.timeout
|
||||
}
|
||||
|
||||
func (i *Instance) SetTimeout(timeout time.Duration) {
|
||||
func (i *Instance) SetTimeout(timeout time.Duration) bool {
|
||||
i.timeout = timeout
|
||||
i.Update()
|
||||
return i.Update()
|
||||
}
|
||||
|
||||
func (i *Instance) wait() {
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
type PacketConn interface {
|
||||
N.PacketConn
|
||||
Timeout() time.Duration
|
||||
SetTimeout(timeout time.Duration)
|
||||
SetTimeout(timeout time.Duration) bool
|
||||
}
|
||||
|
||||
type TimerPacketConn struct {
|
||||
|
@ -24,10 +24,12 @@ type TimerPacketConn struct {
|
|||
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) {
|
||||
if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn {
|
||||
oldTimeout := timeoutConn.Timeout()
|
||||
if timeout < oldTimeout {
|
||||
timeoutConn.SetTimeout(timeout)
|
||||
if timeout >= oldTimeout {
|
||||
return ctx, conn
|
||||
}
|
||||
if timeoutConn.SetTimeout(timeout) {
|
||||
return ctx, conn
|
||||
}
|
||||
return ctx, conn
|
||||
}
|
||||
err := conn.SetReadDeadline(time.Time{})
|
||||
if err == nil {
|
||||
|
@ -58,8 +60,8 @@ func (c *TimerPacketConn) Timeout() time.Duration {
|
|||
return c.instance.Timeout()
|
||||
}
|
||||
|
||||
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) {
|
||||
c.instance.SetTimeout(timeout)
|
||||
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) bool {
|
||||
return c.instance.SetTimeout(timeout)
|
||||
}
|
||||
|
||||
func (c *TimerPacketConn) Close() error {
|
||||
|
|
|
@ -61,9 +61,9 @@ func (c *TimeoutPacketConn) Timeout() time.Duration {
|
|||
return c.timeout
|
||||
}
|
||||
|
||||
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) {
|
||||
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) bool {
|
||||
c.timeout = timeout
|
||||
c.PacketConn.SetReadDeadline(time.Now())
|
||||
return c.PacketConn.SetReadDeadline(time.Now()) == nil
|
||||
}
|
||||
|
||||
func (c *TimeoutPacketConn) Close() error {
|
||||
|
|
|
@ -4,9 +4,11 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/canceler"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/pipe"
|
||||
|
@ -15,15 +17,18 @@ import (
|
|||
type Conn interface {
|
||||
N.PacketConn
|
||||
SetHandler(handler N.UDPHandlerEx)
|
||||
canceler.PacketConn
|
||||
}
|
||||
|
||||
var _ Conn = (*natConn)(nil)
|
||||
|
||||
type natConn struct {
|
||||
service *Service
|
||||
writer N.PacketWriter
|
||||
localAddr M.Socksaddr
|
||||
handler N.UDPHandlerEx
|
||||
packetChan chan *N.PacketBuffer
|
||||
closeOnce sync.Once
|
||||
doneChan chan struct{}
|
||||
readDeadline pipe.Deadline
|
||||
readWaitOptions N.ReadWaitOptions
|
||||
|
@ -48,6 +53,25 @@ func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error
|
|||
return c.writer.WritePacket(buffer, destination)
|
||||
}
|
||||
|
||||
func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
c.readWaitOptions = options
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
select {
|
||||
case packet := <-c.packetChan:
|
||||
buffer = c.readWaitOptions.Copy(packet.Buffer)
|
||||
destination = packet.Destination
|
||||
N.PutPacketBuffer(packet)
|
||||
return
|
||||
case <-c.doneChan:
|
||||
return nil, M.Socksaddr{}, io.ErrClosedPipe
|
||||
case <-c.readDeadline.Wait():
|
||||
return nil, M.Socksaddr{}, os.ErrDeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
func (c *natConn) SetHandler(handler N.UDPHandlerEx) {
|
||||
select {
|
||||
case <-c.doneChan:
|
||||
|
@ -68,31 +92,22 @@ fetch:
|
|||
}
|
||||
}
|
||||
|
||||
func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
c.readWaitOptions = options
|
||||
return false
|
||||
func (c *natConn) Timeout() time.Duration {
|
||||
rawConn, lifetime, loaded := c.service.cache.PeekWithLifetime(c.localAddr.AddrPort())
|
||||
if !loaded || rawConn != c {
|
||||
return 0
|
||||
}
|
||||
return time.Until(lifetime)
|
||||
}
|
||||
|
||||
func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
select {
|
||||
case packet := <-c.packetChan:
|
||||
buffer = c.readWaitOptions.Copy(packet.Buffer)
|
||||
destination = packet.Destination
|
||||
N.PutPacketBuffer(packet)
|
||||
return
|
||||
case <-c.doneChan:
|
||||
return nil, M.Socksaddr{}, io.ErrClosedPipe
|
||||
case <-c.readDeadline.Wait():
|
||||
return nil, M.Socksaddr{}, os.ErrDeadlineExceeded
|
||||
}
|
||||
func (c *natConn) SetTimeout(timeout time.Duration) bool {
|
||||
return c.service.cache.UpdateLifetime(c.localAddr.AddrPort(), c, timeout)
|
||||
}
|
||||
|
||||
func (c *natConn) Close() error {
|
||||
select {
|
||||
case <-c.doneChan:
|
||||
default:
|
||||
c.closeOnce.Do(func() {
|
||||
close(c.doneChan)
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -65,6 +65,7 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati
|
|||
return
|
||||
}
|
||||
conn = &natConn{
|
||||
service: s,
|
||||
writer: writer,
|
||||
localAddr: source,
|
||||
packetChan: make(chan *N.PacketBuffer, 64),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue