diff --git a/sdk/client.go b/sdk/client.go index dc356f5..3b2ced4 100644 --- a/sdk/client.go +++ b/sdk/client.go @@ -26,8 +26,9 @@ const ( ) type ( - Protocol string - ResolveFunc func(network string, address string) (net.Addr, error) + Protocol string + ResolveFunc func(network string, address string) (net.Addr, error) + ListenUDPFunc func(network string, laddr *net.UDPAddr) (*net.UDPConn, error) ) const ( @@ -75,6 +76,12 @@ type ClientConfig struct { // If not set, the default resolver will be used. ResolveFunc ResolveFunc + // ListenUDPFunc is the function used to listen on a UDP port. + // If not set, the default listener will be used. + // Please note that ProtocolFakeTCP does NOT use this function, + // as it is not a UDP-based protocol and has its own stack. + ListenUDPFunc ListenUDPFunc + // Protocol is the protocol to use. // It must be one of the following: // - ProtocolUDP @@ -163,6 +170,9 @@ func (c *ClientConfig) fill() { } } } + if c.ListenUDPFunc == nil { + c.ListenUDPFunc = net.ListenUDP + } if c.Protocol == "" { c.Protocol = ProtocolUDP } @@ -222,7 +232,7 @@ func NewClient(config ClientConfig) (Client, error) { if pff == nil { return nil, errors.New("unsupported protocol") } - pf := pff(config.Obfs, config.HopInterval, config.ResolveFunc) + pf := pff(config.Obfs, config.HopInterval, config.ResolveFunc, config.ListenUDPFunc) c, err := core.NewClient(config.ServerAddress, config.Auth, tlsConfig, quicConfig, pf, config.SendBPS, config.RecvBPS, nil) if err != nil { diff --git a/sdk/pktconns.go b/sdk/pktconns.go index 0254515..040195a 100644 --- a/sdk/pktconns.go +++ b/sdk/pktconns.go @@ -13,7 +13,8 @@ import ( ) type ( - clientPacketConnFuncFactory func(obfsPassword string, hopInterval time.Duration, resolveFunc ResolveFunc) pktconns.ClientPacketConnFunc + clientPacketConnFuncFactory func(obfsPassword string, hopInterval time.Duration, + resolveFunc ResolveFunc, listenUDPFunc ListenUDPFunc) pktconns.ClientPacketConnFunc ) var clientPacketConnFuncFactoryMap = map[Protocol]clientPacketConnFuncFactory{ @@ -22,7 +23,9 @@ var clientPacketConnFuncFactoryMap = map[Protocol]clientPacketConnFuncFactory{ ProtocolFakeTCP: newClientFakeTCPConnFunc, } -func newClientUDPConnFunc(obfsPassword string, hopInterval time.Duration, resolveFunc ResolveFunc) pktconns.ClientPacketConnFunc { +func newClientUDPConnFunc(obfsPassword string, hopInterval time.Duration, + resolveFunc ResolveFunc, listenUDPFunc ListenUDPFunc, +) pktconns.ClientPacketConnFunc { if obfsPassword == "" { return func(server string) (net.PacketConn, net.Addr, error) { if isMultiPortAddr(server) { @@ -32,7 +35,7 @@ func newClientUDPConnFunc(obfsPassword string, hopInterval time.Duration, resolv if err != nil { return nil, nil, err } - udpConn, err := net.ListenUDP("udp", nil) + udpConn, err := listenUDPFunc("udp", nil) return udpConn, sAddr, err } } else { @@ -45,7 +48,7 @@ func newClientUDPConnFunc(obfsPassword string, hopInterval time.Duration, resolv if err != nil { return nil, nil, err } - udpConn, err := net.ListenUDP("udp", nil) + udpConn, err := listenUDPFunc("udp", nil) if err != nil { return nil, nil, err } @@ -55,14 +58,16 @@ func newClientUDPConnFunc(obfsPassword string, hopInterval time.Duration, resolv } } -func newClientWeChatConnFunc(obfsPassword string, hopInterval time.Duration, resolveFunc ResolveFunc) pktconns.ClientPacketConnFunc { +func newClientWeChatConnFunc(obfsPassword string, hopInterval time.Duration, + resolveFunc ResolveFunc, listenUDPFunc ListenUDPFunc, +) pktconns.ClientPacketConnFunc { if obfsPassword == "" { return func(server string) (net.PacketConn, net.Addr, error) { sAddr, err := resolveFunc("udp", server) if err != nil { return nil, nil, err } - udpConn, err := net.ListenUDP("udp", nil) + udpConn, err := listenUDPFunc("udp", nil) if err != nil { return nil, nil, err } @@ -74,7 +79,7 @@ func newClientWeChatConnFunc(obfsPassword string, hopInterval time.Duration, res if err != nil { return nil, nil, err } - udpConn, err := net.ListenUDP("udp", nil) + udpConn, err := listenUDPFunc("udp", nil) if err != nil { return nil, nil, err } @@ -84,7 +89,9 @@ func newClientWeChatConnFunc(obfsPassword string, hopInterval time.Duration, res } } -func newClientFakeTCPConnFunc(obfsPassword string, hopInterval time.Duration, resolveFunc ResolveFunc) pktconns.ClientPacketConnFunc { +func newClientFakeTCPConnFunc(obfsPassword string, hopInterval time.Duration, + resolveFunc ResolveFunc, listenUDPFunc ListenUDPFunc, +) pktconns.ClientPacketConnFunc { if obfsPassword == "" { return func(server string) (net.PacketConn, net.Addr, error) { sAddr, err := resolveFunc("tcp", server)