diff --git a/stack.go b/stack.go index 3f7f357..f6affa0 100644 --- a/stack.go +++ b/stack.go @@ -23,7 +23,6 @@ type StackOptions struct { Context context.Context Tun Tun TunOptions Options - EndpointIndependentNat bool UDPTimeout time.Duration Handler Handler Logger logger.Logger diff --git a/stack_gvisor.go b/stack_gvisor.go index 6c5c27f..60af865 100644 --- a/stack_gvisor.go +++ b/stack_gvisor.go @@ -17,9 +17,6 @@ import ( "github.com/sagernet/gvisor/pkg/tcpip/transport/icmp" "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" - "github.com/sagernet/gvisor/pkg/waiter" - "github.com/sagernet/sing/common/bufio" - "github.com/sagernet/sing/common/canceler" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" @@ -31,15 +28,14 @@ const WithGVisor = true const defaultNIC tcpip.NICID = 1 type GVisor struct { - ctx context.Context - tun GVisorTun - endpointIndependentNat bool - udpTimeout time.Duration - broadcastAddr netip.Addr - handler Handler - logger logger.Logger - stack *stack.Stack - endpoint stack.LinkEndpoint + ctx context.Context + tun GVisorTun + udpTimeout time.Duration + broadcastAddr netip.Addr + handler Handler + logger logger.Logger + stack *stack.Stack + endpoint stack.LinkEndpoint } type GVisorTun interface { @@ -56,13 +52,12 @@ func NewGVisor( } gStack := &GVisor{ - ctx: options.Context, - tun: gTun, - endpointIndependentNat: options.EndpointIndependentNat, - udpTimeout: options.UDPTimeout, - broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address), - handler: options.Handler, - logger: options.Logger, + ctx: options.Context, + tun: gTun, + udpTimeout: options.UDPTimeout, + broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address), + handler: options.Handler, + logger: options.Logger, } return gStack, nil } @@ -95,31 +90,7 @@ func (t *GVisor) Start() error { go t.handler.NewConnectionEx(t.ctx, conn, source, destination, nil) }) ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) - if !t.endpointIndependentNat { - udpForwarder := udp.NewForwarder(ipStack, func(r *udp.ForwarderRequest) { - source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort) - destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort) - pErr := t.handler.PrepareConnection(N.NetworkUDP, source, destination) - if pErr != nil { - gWriteUnreachable(t.stack, r.Packet(), err) - r.Packet().DecRef() - return - } - var wq waiter.Queue - endpoint, err := r.CreateEndpoint(&wq) - if err != nil { - return - } - go func() { - ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewUnbindPacketConnWithAddr(gonet.NewUDPConn(&wq, endpoint), destination), t.udpTimeout) - t.handler.NewPacketConnectionEx(ctx, conn, source, destination, nil) - }() - }) - ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) - } else { - ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket) - } - + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket) t.stack = ipStack t.endpoint = linkEndpoint return nil diff --git a/stack_mixed.go b/stack_mixed.go index 8388cb9..3b7314e 100644 --- a/stack_mixed.go +++ b/stack_mixed.go @@ -5,25 +5,19 @@ package tun import ( "github.com/sagernet/gvisor/pkg/buffer" "github.com/sagernet/gvisor/pkg/tcpip" - "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" gHdr "github.com/sagernet/gvisor/pkg/tcpip/header" "github.com/sagernet/gvisor/pkg/tcpip/link/channel" "github.com/sagernet/gvisor/pkg/tcpip/stack" "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" - "github.com/sagernet/gvisor/pkg/waiter" "github.com/sagernet/sing-tun/internal/gtcpip/header" "github.com/sagernet/sing/common/bufio" - "github.com/sagernet/sing/common/canceler" E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" ) type Mixed struct { *System - endpointIndependentNat bool - stack *stack.Stack - endpoint *channel.Endpoint + stack *stack.Stack + endpoint *channel.Endpoint } func NewMixed( @@ -34,8 +28,7 @@ func NewMixed( return nil, err } return &Mixed{ - System: system.(*System), - endpointIndependentNat: options.EndpointIndependentNat, + System: system.(*System), }, nil } @@ -49,30 +42,7 @@ func (m *Mixed) Start() error { if err != nil { return err } - if !m.endpointIndependentNat { - udpForwarder := udp.NewForwarder(ipStack, func(r *udp.ForwarderRequest) { - source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort) - destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort) - pErr := m.handler.PrepareConnection(N.NetworkUDP, source, destination) - if pErr != nil { - gWriteUnreachable(m.stack, r.Packet(), err) - r.Packet().DecRef() - return - } - var wq waiter.Queue - endpoint, err := r.CreateEndpoint(&wq) - if err != nil { - return - } - go func() { - ctx, conn := canceler.NewPacketConn(m.ctx, bufio.NewUnbindPacketConnWithAddr(gonet.NewUDPConn(&wq, endpoint), destination), m.udpTimeout) - m.handler.NewPacketConnectionEx(ctx, conn, source, destination, nil) - }() - }) - ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) - } else { - ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket) - } + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket) m.stack = ipStack m.endpoint = endpoint go m.tunLoop() diff --git a/stack_system.go b/stack_system.go index e3812ed..2baa0c6 100644 --- a/stack_system.go +++ b/stack_system.go @@ -731,7 +731,7 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S newPacket.Write(buffer.Bytes()) ipHdr := header.IPv4(newPacket.Bytes()) ipHdr.SetTotalLength(uint16(newPacket.Len())) - ipHdr.SetSourceAddress(ipHdr.SourceAddress()) + ipHdr.SetDestinationAddress(ipHdr.SourceAddress()) ipHdr.SetSourceAddr(destination.Addr) udpHdr := header.UDP(ipHdr.Payload()) udpHdr.SetDestinationPort(udpHdr.SourcePort())