From afcc9cb766c27d0b11c2f6324077ad551e520b4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 28 Oct 2023 20:59:48 +0800 Subject: [PATCH] Add unidirectional NATPacketConn --- common/bufio/nat.go | 63 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 52 insertions(+), 11 deletions(-) diff --git a/common/bufio/nat.go b/common/bufio/nat.go index d652094..43e8d40 100644 --- a/common/bufio/nat.go +++ b/common/bufio/nat.go @@ -9,21 +9,62 @@ import ( N "github.com/sagernet/sing/common/network" ) -type NATPacketConn struct { +type NATPacketConn interface { N.NetPacketConn - origin M.Socksaddr - destination M.Socksaddr + UpdateDestination(destinationAddress netip.Addr) } -func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) *NATPacketConn { - return &NATPacketConn{ +func NewUnidirectionalNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn { + return &unidirectionalNATPacketConn{ NetPacketConn: conn, origin: origin, destination: destination, } } -func (c *NATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { +func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn { + return &bidirectionalNATPacketConn{ + NetPacketConn: conn, + origin: origin, + destination: destination, + } +} + +type unidirectionalNATPacketConn struct { + N.NetPacketConn + origin M.Socksaddr + destination M.Socksaddr +} + +func (c *unidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if M.SocksaddrFromNet(addr) == c.destination { + addr = c.origin.UDPAddr() + } + return c.NetPacketConn.WriteTo(p, addr) +} + +func (c *unidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + if destination == c.destination { + destination = c.origin + } + return c.NetPacketConn.WritePacket(buffer, destination) +} + +func (c *unidirectionalNATPacketConn) UpdateDestination(destinationAddress netip.Addr) { + c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port) +} + +func (c *unidirectionalNATPacketConn) Upstream() any { + return c.NetPacketConn +} + +type bidirectionalNATPacketConn struct { + N.NetPacketConn + origin M.Socksaddr + destination M.Socksaddr +} + +func (c *bidirectionalNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { n, addr, err = c.NetPacketConn.ReadFrom(p) if err == nil && M.SocksaddrFromNet(addr) == c.origin { addr = c.destination.UDPAddr() @@ -31,14 +72,14 @@ func (c *NATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { return } -func (c *NATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { +func (c *bidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { if M.SocksaddrFromNet(addr) == c.destination { addr = c.origin.UDPAddr() } return c.NetPacketConn.WriteTo(p, addr) } -func (c *NATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { +func (c *bidirectionalNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { destination, err = c.NetPacketConn.ReadPacket(buffer) if destination == c.origin { destination = c.destination @@ -46,17 +87,17 @@ func (c *NATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, return } -func (c *NATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { +func (c *bidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { if destination == c.destination { destination = c.origin } return c.NetPacketConn.WritePacket(buffer, destination) } -func (c *NATPacketConn) UpdateDestination(destinationAddress netip.Addr) { +func (c *bidirectionalNATPacketConn) UpdateDestination(destinationAddress netip.Addr) { c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port) } -func (c *NATPacketConn) Upstream() any { +func (c *bidirectionalNATPacketConn) Upstream() any { return c.NetPacketConn }