Improve bufio.NATPacketConn

This commit is contained in:
dyhkwong 2024-04-10 20:15:01 +08:00 committed by 世界
parent 8b68fc4d7a
commit f1b992c06f
No known key found for this signature in database
GPG key ID: CD109927C34A63C4

View file

@ -17,16 +17,16 @@ type NATPacketConn interface {
func NewUnidirectionalNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn { func NewUnidirectionalNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
return &unidirectionalNATPacketConn{ return &unidirectionalNATPacketConn{
NetPacketConn: conn, NetPacketConn: conn,
origin: origin, origin: socksaddrWithoutPort(origin),
destination: destination, destination: socksaddrWithoutPort(destination),
} }
} }
func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn { func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
return &bidirectionalNATPacketConn{ return &bidirectionalNATPacketConn{
NetPacketConn: conn, NetPacketConn: conn,
origin: origin, origin: socksaddrWithoutPort(origin),
destination: destination, destination: socksaddrWithoutPort(destination),
} }
} }
@ -37,15 +37,24 @@ type unidirectionalNATPacketConn struct {
} }
func (c *unidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { func (c *unidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if M.SocksaddrFromNet(addr) == c.destination { destination := M.SocksaddrFromNet(addr)
addr = c.origin.UDPAddr() if socksaddrWithoutPort(destination) == c.destination {
destination = M.Socksaddr{
Addr: c.origin.Addr,
Fqdn: c.origin.Fqdn,
Port: destination.Port,
}
} }
return c.NetPacketConn.WriteTo(p, addr) return c.NetPacketConn.WriteTo(p, destination.UDPAddr())
} }
func (c *unidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { func (c *unidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if destination == c.destination { if socksaddrWithoutPort(destination) == c.destination {
destination = c.origin destination = M.Socksaddr{
Addr: c.origin.Addr,
Fqdn: c.origin.Fqdn,
Port: destination.Port,
}
} }
return c.NetPacketConn.WritePacket(buffer, destination) return c.NetPacketConn.WritePacket(buffer, destination)
} }
@ -66,30 +75,55 @@ type bidirectionalNATPacketConn struct {
func (c *bidirectionalNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { func (c *bidirectionalNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = c.NetPacketConn.ReadFrom(p) n, addr, err = c.NetPacketConn.ReadFrom(p)
if err == nil && M.SocksaddrFromNet(addr) == c.origin { if err != nil {
addr = c.destination.UDPAddr() return
} }
destination := M.SocksaddrFromNet(addr)
if socksaddrWithoutPort(destination) == c.origin {
destination = M.Socksaddr{
Addr: c.destination.Addr,
Fqdn: c.destination.Fqdn,
Port: destination.Port,
}
}
addr = destination.UDPAddr()
return return
} }
func (c *bidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { func (c *bidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if M.SocksaddrFromNet(addr) == c.destination { destination := M.SocksaddrFromNet(addr)
addr = c.origin.UDPAddr() if socksaddrWithoutPort(destination) == c.destination {
destination = M.Socksaddr{
Addr: c.origin.Addr,
Fqdn: c.origin.Fqdn,
Port: destination.Port,
}
} }
return c.NetPacketConn.WriteTo(p, addr) return c.NetPacketConn.WriteTo(p, destination.UDPAddr())
} }
func (c *bidirectionalNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { func (c *bidirectionalNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
destination, err = c.NetPacketConn.ReadPacket(buffer) destination, err = c.NetPacketConn.ReadPacket(buffer)
if destination == c.origin { if err != nil {
destination = c.destination return
}
if socksaddrWithoutPort(destination) == c.origin {
destination = M.Socksaddr{
Addr: c.destination.Addr,
Fqdn: c.destination.Fqdn,
Port: destination.Port,
}
} }
return return
} }
func (c *bidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { func (c *bidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if destination == c.destination { if socksaddrWithoutPort(destination) == c.destination {
destination = c.origin destination = M.Socksaddr{
Addr: c.origin.Addr,
Fqdn: c.origin.Fqdn,
Port: destination.Port,
}
} }
return c.NetPacketConn.WritePacket(buffer, destination) return c.NetPacketConn.WritePacket(buffer, destination)
} }
@ -101,3 +135,8 @@ func (c *bidirectionalNATPacketConn) UpdateDestination(destinationAddress netip.
func (c *bidirectionalNATPacketConn) Upstream() any { func (c *bidirectionalNATPacketConn) Upstream() any {
return c.NetPacketConn return c.NetPacketConn
} }
func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr {
destination.Port = 0
return destination
}