feat: resolve & listen func for udp hop

This commit is contained in:
Toby 2022-11-22 18:29:35 -08:00
parent 6f89dc34b0
commit 64a6fb2edd
4 changed files with 32 additions and 10 deletions

View file

@ -41,7 +41,7 @@ func NewClientUDPConnFunc(obfsPassword string, hopInterval time.Duration) Client
if obfsPassword == "" { if obfsPassword == "" {
return func(server string) (net.PacketConn, net.Addr, error) { return func(server string) (net.PacketConn, net.Addr, error) {
if isMultiPortAddr(server) { if isMultiPortAddr(server) {
return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, nil) return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, nil, net.ResolveIPAddr, net.ListenUDP)
} }
sAddr, err := net.ResolveUDPAddr("udp", server) sAddr, err := net.ResolveUDPAddr("udp", server)
if err != nil { if err != nil {
@ -54,7 +54,7 @@ func NewClientUDPConnFunc(obfsPassword string, hopInterval time.Duration) Client
return func(server string) (net.PacketConn, net.Addr, error) { return func(server string) (net.PacketConn, net.Addr, error) {
if isMultiPortAddr(server) { if isMultiPortAddr(server) {
ob := obfs.NewXPlusObfuscator([]byte(obfsPassword)) ob := obfs.NewXPlusObfuscator([]byte(obfsPassword))
return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, ob) return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, ob, net.ResolveIPAddr, net.ListenUDP)
} }
sAddr, err := net.ResolveUDPAddr("udp", server) sAddr, err := net.ResolveUDPAddr("udp", server)
if err != nil { if err != nil {

View file

@ -24,7 +24,8 @@ type ObfsUDPHopClientPacketConn struct {
serverAddrs []net.Addr serverAddrs []net.Addr
hopInterval time.Duration hopInterval time.Duration
obfs obfs.Obfuscator obfs obfs.Obfuscator
listenFunc func(network string, laddr *net.UDPAddr) (*net.UDPConn, error)
connMutex sync.RWMutex connMutex sync.RWMutex
prevConn net.PacketConn prevConn net.PacketConn
@ -57,13 +58,16 @@ type udpPacket struct {
addr net.Addr addr net.Addr
} }
func NewObfsUDPHopClientPacketConn(server string, hopInterval time.Duration, obfs obfs.Obfuscator) (*ObfsUDPHopClientPacketConn, net.Addr, error) { func NewObfsUDPHopClientPacketConn(server string, hopInterval time.Duration, obfs obfs.Obfuscator,
resolveFunc func(network string, address string) (*net.IPAddr, error),
listenFunc func(network string, laddr *net.UDPAddr) (*net.UDPConn, error),
) (*ObfsUDPHopClientPacketConn, net.Addr, error) {
host, ports, err := parseAddr(server) host, ports, err := parseAddr(server)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
// Resolve the server IP address, then attach the ports to UDP addresses // Resolve the server IP address, then attach the ports to UDP addresses
ip, err := net.ResolveIPAddr("ip", host) ip, err := resolveFunc("ip", host)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -80,6 +84,7 @@ func NewObfsUDPHopClientPacketConn(server string, hopInterval time.Duration, obf
serverAddrs: serverAddrs, serverAddrs: serverAddrs,
hopInterval: hopInterval, hopInterval: hopInterval,
obfs: obfs, obfs: obfs,
listenFunc: listenFunc,
addrIndex: rand.Intn(len(serverAddrs)), addrIndex: rand.Intn(len(serverAddrs)),
recvQueue: make(chan *udpPacket, packetQueueSize), recvQueue: make(chan *udpPacket, packetQueueSize),
closeChan: make(chan struct{}), closeChan: make(chan struct{}),
@ -89,7 +94,7 @@ func NewObfsUDPHopClientPacketConn(server string, hopInterval time.Duration, obf
}, },
}, },
} }
curConn, err := net.ListenUDP("udp", nil) curConn, err := listenFunc("udp", nil)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -138,7 +143,7 @@ func (c *ObfsUDPHopClientPacketConn) hop() {
if c.closed { if c.closed {
return return
} }
newConn, err := net.ListenUDP("udp", nil) newConn, err := c.listenFunc("udp", nil)
if err != nil { if err != nil {
// Skip this hop if failed to listen // Skip this hop if failed to listen
return return

View file

@ -156,7 +156,7 @@ type ClientConfig struct {
QUICConfig *quic.Config QUICConfig *quic.Config
} }
// fill fills in the default values (if not set) for the configuration. // fill in the default values (if not set) for the configuration.
func (c *ClientConfig) fill() { func (c *ClientConfig) fill() {
if c.ResolveFunc == nil { if c.ResolveFunc == nil {
c.ResolveFunc = func(network string, address string) (net.Addr, error) { c.ResolveFunc = func(network string, address string) (net.Addr, error) {
@ -165,6 +165,8 @@ func (c *ClientConfig) fill() {
return net.ResolveTCPAddr(network, address) return net.ResolveTCPAddr(network, address)
case "udp", "udp4", "udp6": case "udp", "udp4", "udp6":
return net.ResolveUDPAddr(network, address) return net.ResolveUDPAddr(network, address)
case "ip", "ip4", "ip6":
return net.ResolveIPAddr(network, address)
default: default:
return nil, errors.New("unsupported network type") return nil, errors.New("unsupported network type")
} }

View file

@ -23,13 +23,27 @@ var clientPacketConnFuncFactoryMap = map[Protocol]clientPacketConnFuncFactory{
ProtocolFakeTCP: newClientFakeTCPConnFunc, ProtocolFakeTCP: newClientFakeTCPConnFunc,
} }
func resolveFuncToIPResolveFunc(resolveFunc ResolveFunc) func(network string, address string) (*net.IPAddr, error) {
return func(network string, address string) (*net.IPAddr, error) {
addr, err := resolveFunc(network, address)
if err != nil {
return nil, err
}
if ipAddr, ok := addr.(*net.IPAddr); ok {
return ipAddr, nil
}
return nil, net.InvalidAddrError("not an IP address")
}
}
func newClientUDPConnFunc(obfsPassword string, hopInterval time.Duration, func newClientUDPConnFunc(obfsPassword string, hopInterval time.Duration,
resolveFunc ResolveFunc, listenUDPFunc ListenUDPFunc, resolveFunc ResolveFunc, listenUDPFunc ListenUDPFunc,
) pktconns.ClientPacketConnFunc { ) pktconns.ClientPacketConnFunc {
if obfsPassword == "" { if obfsPassword == "" {
return func(server string) (net.PacketConn, net.Addr, error) { return func(server string) (net.PacketConn, net.Addr, error) {
if isMultiPortAddr(server) { if isMultiPortAddr(server) {
return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, nil) return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, nil,
resolveFuncToIPResolveFunc(resolveFunc), listenUDPFunc)
} }
sAddr, err := resolveFunc("udp", server) sAddr, err := resolveFunc("udp", server)
if err != nil { if err != nil {
@ -42,7 +56,8 @@ func newClientUDPConnFunc(obfsPassword string, hopInterval time.Duration,
return func(server string) (net.PacketConn, net.Addr, error) { return func(server string) (net.PacketConn, net.Addr, error) {
if isMultiPortAddr(server) { if isMultiPortAddr(server) {
ob := obfs.NewXPlusObfuscator([]byte(obfsPassword)) ob := obfs.NewXPlusObfuscator([]byte(obfsPassword))
return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, ob) return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, ob,
resolveFuncToIPResolveFunc(resolveFunc), listenUDPFunc)
} }
sAddr, err := resolveFunc("udp", server) sAddr, err := resolveFunc("udp", server)
if err != nil { if err != nil {