diff --git a/common/bufio/nat.go b/common/bufio/nat.go index cafeb06..6e5ab64 100644 --- a/common/bufio/nat.go +++ b/common/bufio/nat.go @@ -30,6 +30,14 @@ func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.So } } +func NewDestinationNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn { + return &destinationNATPacketConn{ + NetPacketConn: conn, + origin: origin, + destination: destination, + } +} + type unidirectionalNATPacketConn struct { N.NetPacketConn origin M.Socksaddr @@ -144,6 +152,60 @@ func (c *bidirectionalNATPacketConn) RemoteAddr() net.Addr { return c.destination.UDPAddr() } +type destinationNATPacketConn struct { + N.NetPacketConn + origin M.Socksaddr + destination M.Socksaddr +} + +func (c *destinationNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, addr, err = c.NetPacketConn.ReadFrom(p) + if err != nil { + return + } + if M.SocksaddrFromNet(addr) == c.origin { + addr = c.destination.UDPAddr() + } + return +} + +func (c *destinationNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if M.SocksaddrFromNet(addr) == c.destination { + addr = c.origin.UDPAddr() + } + return c.NetPacketConn.WriteTo(p, addr) +} + +func (c *destinationNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + destination, err = c.NetPacketConn.ReadPacket(buffer) + if err != nil { + return + } + if destination == c.origin { + destination = c.destination + } + return +} + +func (c *destinationNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + if destination == c.destination { + destination = c.origin + } + return c.NetPacketConn.WritePacket(buffer, destination) +} + +func (c *destinationNATPacketConn) UpdateDestination(destinationAddress netip.Addr) { + c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port) +} + +func (c *destinationNATPacketConn) Upstream() any { + return c.NetPacketConn +} + +func (c *destinationNATPacketConn) RemoteAddr() net.Addr { + return c.destination.UDPAddr() +} + func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr { destination.Port = 0 return destination