From 64a6fb2edd3fbbbfb85d63e826968f791d19985c Mon Sep 17 00:00:00 2001 From: Toby Date: Tue, 22 Nov 2022 18:29:35 -0800 Subject: [PATCH] feat: resolve & listen func for udp hop --- pkg/pktconns/funcs.go | 4 ++-- pkg/pktconns/udp/hop.go | 15 ++++++++++----- sdk/client.go | 4 +++- sdk/pktconns.go | 19 +++++++++++++++++-- 4 files changed, 32 insertions(+), 10 deletions(-) diff --git a/pkg/pktconns/funcs.go b/pkg/pktconns/funcs.go index 7bdc409..3071329 100644 --- a/pkg/pktconns/funcs.go +++ b/pkg/pktconns/funcs.go @@ -41,7 +41,7 @@ func NewClientUDPConnFunc(obfsPassword string, hopInterval time.Duration) Client if obfsPassword == "" { return func(server string) (net.PacketConn, net.Addr, error) { if isMultiPortAddr(server) { - return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, nil) + return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, nil, net.ResolveIPAddr, net.ListenUDP) } sAddr, err := net.ResolveUDPAddr("udp", server) if err != nil { @@ -54,7 +54,7 @@ func NewClientUDPConnFunc(obfsPassword string, hopInterval time.Duration) Client return func(server string) (net.PacketConn, net.Addr, error) { if isMultiPortAddr(server) { ob := obfs.NewXPlusObfuscator([]byte(obfsPassword)) - return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, ob) + return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, ob, net.ResolveIPAddr, net.ListenUDP) } sAddr, err := net.ResolveUDPAddr("udp", server) if err != nil { diff --git a/pkg/pktconns/udp/hop.go b/pkg/pktconns/udp/hop.go index e1e4f9e..47b944c 100644 --- a/pkg/pktconns/udp/hop.go +++ b/pkg/pktconns/udp/hop.go @@ -24,7 +24,8 @@ type ObfsUDPHopClientPacketConn struct { serverAddrs []net.Addr hopInterval time.Duration - obfs obfs.Obfuscator + obfs obfs.Obfuscator + listenFunc func(network string, laddr *net.UDPAddr) (*net.UDPConn, error) connMutex sync.RWMutex prevConn net.PacketConn @@ -57,13 +58,16 @@ type udpPacket struct { addr net.Addr } -func NewObfsUDPHopClientPacketConn(server string, hopInterval time.Duration, obfs obfs.Obfuscator) (*ObfsUDPHopClientPacketConn, net.Addr, error) { +func NewObfsUDPHopClientPacketConn(server string, hopInterval time.Duration, obfs obfs.Obfuscator, + resolveFunc func(network string, address string) (*net.IPAddr, error), + listenFunc func(network string, laddr *net.UDPAddr) (*net.UDPConn, error), +) (*ObfsUDPHopClientPacketConn, net.Addr, error) { host, ports, err := parseAddr(server) if err != nil { return nil, nil, err } // Resolve the server IP address, then attach the ports to UDP addresses - ip, err := net.ResolveIPAddr("ip", host) + ip, err := resolveFunc("ip", host) if err != nil { return nil, nil, err } @@ -80,6 +84,7 @@ func NewObfsUDPHopClientPacketConn(server string, hopInterval time.Duration, obf serverAddrs: serverAddrs, hopInterval: hopInterval, obfs: obfs, + listenFunc: listenFunc, addrIndex: rand.Intn(len(serverAddrs)), recvQueue: make(chan *udpPacket, packetQueueSize), closeChan: make(chan struct{}), @@ -89,7 +94,7 @@ func NewObfsUDPHopClientPacketConn(server string, hopInterval time.Duration, obf }, }, } - curConn, err := net.ListenUDP("udp", nil) + curConn, err := listenFunc("udp", nil) if err != nil { return nil, nil, err } @@ -138,7 +143,7 @@ func (c *ObfsUDPHopClientPacketConn) hop() { if c.closed { return } - newConn, err := net.ListenUDP("udp", nil) + newConn, err := c.listenFunc("udp", nil) if err != nil { // Skip this hop if failed to listen return diff --git a/sdk/client.go b/sdk/client.go index 3b2ced4..0144f11 100644 --- a/sdk/client.go +++ b/sdk/client.go @@ -156,7 +156,7 @@ type ClientConfig struct { QUICConfig *quic.Config } -// fill fills in the default values (if not set) for the configuration. +// fill in the default values (if not set) for the configuration. func (c *ClientConfig) fill() { if c.ResolveFunc == nil { c.ResolveFunc = func(network string, address string) (net.Addr, error) { @@ -165,6 +165,8 @@ func (c *ClientConfig) fill() { return net.ResolveTCPAddr(network, address) case "udp", "udp4", "udp6": return net.ResolveUDPAddr(network, address) + case "ip", "ip4", "ip6": + return net.ResolveIPAddr(network, address) default: return nil, errors.New("unsupported network type") } diff --git a/sdk/pktconns.go b/sdk/pktconns.go index 040195a..a75dc9e 100644 --- a/sdk/pktconns.go +++ b/sdk/pktconns.go @@ -23,13 +23,27 @@ var clientPacketConnFuncFactoryMap = map[Protocol]clientPacketConnFuncFactory{ ProtocolFakeTCP: newClientFakeTCPConnFunc, } +func resolveFuncToIPResolveFunc(resolveFunc ResolveFunc) func(network string, address string) (*net.IPAddr, error) { + return func(network string, address string) (*net.IPAddr, error) { + addr, err := resolveFunc(network, address) + if err != nil { + return nil, err + } + if ipAddr, ok := addr.(*net.IPAddr); ok { + return ipAddr, nil + } + return nil, net.InvalidAddrError("not an IP address") + } +} + func newClientUDPConnFunc(obfsPassword string, hopInterval time.Duration, resolveFunc ResolveFunc, listenUDPFunc ListenUDPFunc, ) pktconns.ClientPacketConnFunc { if obfsPassword == "" { return func(server string) (net.PacketConn, net.Addr, error) { if isMultiPortAddr(server) { - return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, nil) + return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, nil, + resolveFuncToIPResolveFunc(resolveFunc), listenUDPFunc) } sAddr, err := resolveFunc("udp", server) if err != nil { @@ -42,7 +56,8 @@ func newClientUDPConnFunc(obfsPassword string, hopInterval time.Duration, return func(server string) (net.PacketConn, net.Addr, error) { if isMultiPortAddr(server) { ob := obfs.NewXPlusObfuscator([]byte(obfsPassword)) - return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, ob) + return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, ob, + resolveFuncToIPResolveFunc(resolveFunc), listenUDPFunc) } sAddr, err := resolveFunc("udp", server) if err != nil {