mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-04 20:37:40 +03:00
udpnat2: Implement set timeout for nat conn
This commit is contained in:
parent
3613ead480
commit
a8285e06a5
5 changed files with 47 additions and 29 deletions
|
@ -41,9 +41,9 @@ func (i *Instance) Timeout() time.Duration {
|
||||||
return i.timeout
|
return i.timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Instance) SetTimeout(timeout time.Duration) {
|
func (i *Instance) SetTimeout(timeout time.Duration) bool {
|
||||||
i.timeout = timeout
|
i.timeout = timeout
|
||||||
i.Update()
|
return i.Update()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Instance) wait() {
|
func (i *Instance) wait() {
|
||||||
|
|
|
@ -13,7 +13,7 @@ import (
|
||||||
type PacketConn interface {
|
type PacketConn interface {
|
||||||
N.PacketConn
|
N.PacketConn
|
||||||
Timeout() time.Duration
|
Timeout() time.Duration
|
||||||
SetTimeout(timeout time.Duration)
|
SetTimeout(timeout time.Duration) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type TimerPacketConn struct {
|
type TimerPacketConn struct {
|
||||||
|
@ -24,11 +24,13 @@ type TimerPacketConn struct {
|
||||||
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) {
|
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) {
|
||||||
if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn {
|
if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn {
|
||||||
oldTimeout := timeoutConn.Timeout()
|
oldTimeout := timeoutConn.Timeout()
|
||||||
if timeout < oldTimeout {
|
if timeout >= oldTimeout {
|
||||||
timeoutConn.SetTimeout(timeout)
|
|
||||||
}
|
|
||||||
return ctx, conn
|
return ctx, conn
|
||||||
}
|
}
|
||||||
|
if timeoutConn.SetTimeout(timeout) {
|
||||||
|
return ctx, conn
|
||||||
|
}
|
||||||
|
}
|
||||||
err := conn.SetReadDeadline(time.Time{})
|
err := conn.SetReadDeadline(time.Time{})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return NewTimeoutPacketConn(ctx, conn, timeout)
|
return NewTimeoutPacketConn(ctx, conn, timeout)
|
||||||
|
@ -58,8 +60,8 @@ func (c *TimerPacketConn) Timeout() time.Duration {
|
||||||
return c.instance.Timeout()
|
return c.instance.Timeout()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) {
|
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) bool {
|
||||||
c.instance.SetTimeout(timeout)
|
return c.instance.SetTimeout(timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TimerPacketConn) Close() error {
|
func (c *TimerPacketConn) Close() error {
|
||||||
|
|
|
@ -61,9 +61,9 @@ func (c *TimeoutPacketConn) Timeout() time.Duration {
|
||||||
return c.timeout
|
return c.timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) {
|
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) bool {
|
||||||
c.timeout = timeout
|
c.timeout = timeout
|
||||||
c.PacketConn.SetReadDeadline(time.Now())
|
return c.PacketConn.SetReadDeadline(time.Now()) == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TimeoutPacketConn) Close() error {
|
func (c *TimeoutPacketConn) Close() error {
|
||||||
|
|
|
@ -4,9 +4,11 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sagernet/sing/common/buf"
|
"github.com/sagernet/sing/common/buf"
|
||||||
|
"github.com/sagernet/sing/common/canceler"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
"github.com/sagernet/sing/common/pipe"
|
"github.com/sagernet/sing/common/pipe"
|
||||||
|
@ -15,15 +17,18 @@ import (
|
||||||
type Conn interface {
|
type Conn interface {
|
||||||
N.PacketConn
|
N.PacketConn
|
||||||
SetHandler(handler N.UDPHandlerEx)
|
SetHandler(handler N.UDPHandlerEx)
|
||||||
|
canceler.PacketConn
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Conn = (*natConn)(nil)
|
var _ Conn = (*natConn)(nil)
|
||||||
|
|
||||||
type natConn struct {
|
type natConn struct {
|
||||||
|
service *Service
|
||||||
writer N.PacketWriter
|
writer N.PacketWriter
|
||||||
localAddr M.Socksaddr
|
localAddr M.Socksaddr
|
||||||
handler N.UDPHandlerEx
|
handler N.UDPHandlerEx
|
||||||
packetChan chan *N.PacketBuffer
|
packetChan chan *N.PacketBuffer
|
||||||
|
closeOnce sync.Once
|
||||||
doneChan chan struct{}
|
doneChan chan struct{}
|
||||||
readDeadline pipe.Deadline
|
readDeadline pipe.Deadline
|
||||||
readWaitOptions N.ReadWaitOptions
|
readWaitOptions N.ReadWaitOptions
|
||||||
|
@ -48,6 +53,25 @@ func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error
|
||||||
return c.writer.WritePacket(buffer, destination)
|
return c.writer.WritePacket(buffer, destination)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||||
|
c.readWaitOptions = options
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||||
|
select {
|
||||||
|
case packet := <-c.packetChan:
|
||||||
|
buffer = c.readWaitOptions.Copy(packet.Buffer)
|
||||||
|
destination = packet.Destination
|
||||||
|
N.PutPacketBuffer(packet)
|
||||||
|
return
|
||||||
|
case <-c.doneChan:
|
||||||
|
return nil, M.Socksaddr{}, io.ErrClosedPipe
|
||||||
|
case <-c.readDeadline.Wait():
|
||||||
|
return nil, M.Socksaddr{}, os.ErrDeadlineExceeded
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *natConn) SetHandler(handler N.UDPHandlerEx) {
|
func (c *natConn) SetHandler(handler N.UDPHandlerEx) {
|
||||||
select {
|
select {
|
||||||
case <-c.doneChan:
|
case <-c.doneChan:
|
||||||
|
@ -68,31 +92,22 @@ fetch:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
func (c *natConn) Timeout() time.Duration {
|
||||||
c.readWaitOptions = options
|
rawConn, lifetime, loaded := c.service.cache.PeekWithLifetime(c.localAddr.AddrPort())
|
||||||
return false
|
if !loaded || rawConn != c {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return time.Until(lifetime)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
func (c *natConn) SetTimeout(timeout time.Duration) bool {
|
||||||
select {
|
return c.service.cache.UpdateLifetime(c.localAddr.AddrPort(), c, timeout)
|
||||||
case packet := <-c.packetChan:
|
|
||||||
buffer = c.readWaitOptions.Copy(packet.Buffer)
|
|
||||||
destination = packet.Destination
|
|
||||||
N.PutPacketBuffer(packet)
|
|
||||||
return
|
|
||||||
case <-c.doneChan:
|
|
||||||
return nil, M.Socksaddr{}, io.ErrClosedPipe
|
|
||||||
case <-c.readDeadline.Wait():
|
|
||||||
return nil, M.Socksaddr{}, os.ErrDeadlineExceeded
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *natConn) Close() error {
|
func (c *natConn) Close() error {
|
||||||
select {
|
c.closeOnce.Do(func() {
|
||||||
case <-c.doneChan:
|
|
||||||
default:
|
|
||||||
close(c.doneChan)
|
close(c.doneChan)
|
||||||
}
|
})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -65,6 +65,7 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
conn = &natConn{
|
conn = &natConn{
|
||||||
|
service: s,
|
||||||
writer: writer,
|
writer: writer,
|
||||||
localAddr: source,
|
localAddr: source,
|
||||||
packetChan: make(chan *N.PacketBuffer, 64),
|
packetChan: make(chan *N.PacketBuffer, 64),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue