wip: sdk listener hook

This commit is contained in:
Toby 2022-11-23 00:57:14 +00:00
parent c8c8aa61aa
commit 6f89dc34b0
2 changed files with 28 additions and 11 deletions

View file

@ -28,6 +28,7 @@ const (
type ( type (
Protocol string Protocol string
ResolveFunc func(network string, address string) (net.Addr, error) ResolveFunc func(network string, address string) (net.Addr, error)
ListenUDPFunc func(network string, laddr *net.UDPAddr) (*net.UDPConn, error)
) )
const ( const (
@ -75,6 +76,12 @@ type ClientConfig struct {
// If not set, the default resolver will be used. // If not set, the default resolver will be used.
ResolveFunc ResolveFunc ResolveFunc ResolveFunc
// ListenUDPFunc is the function used to listen on a UDP port.
// If not set, the default listener will be used.
// Please note that ProtocolFakeTCP does NOT use this function,
// as it is not a UDP-based protocol and has its own stack.
ListenUDPFunc ListenUDPFunc
// Protocol is the protocol to use. // Protocol is the protocol to use.
// It must be one of the following: // It must be one of the following:
// - ProtocolUDP // - ProtocolUDP
@ -163,6 +170,9 @@ func (c *ClientConfig) fill() {
} }
} }
} }
if c.ListenUDPFunc == nil {
c.ListenUDPFunc = net.ListenUDP
}
if c.Protocol == "" { if c.Protocol == "" {
c.Protocol = ProtocolUDP c.Protocol = ProtocolUDP
} }
@ -222,7 +232,7 @@ func NewClient(config ClientConfig) (Client, error) {
if pff == nil { if pff == nil {
return nil, errors.New("unsupported protocol") return nil, errors.New("unsupported protocol")
} }
pf := pff(config.Obfs, config.HopInterval, config.ResolveFunc) pf := pff(config.Obfs, config.HopInterval, config.ResolveFunc, config.ListenUDPFunc)
c, err := core.NewClient(config.ServerAddress, config.Auth, tlsConfig, quicConfig, pf, c, err := core.NewClient(config.ServerAddress, config.Auth, tlsConfig, quicConfig, pf,
config.SendBPS, config.RecvBPS, nil) config.SendBPS, config.RecvBPS, nil)
if err != nil { if err != nil {

View file

@ -13,7 +13,8 @@ import (
) )
type ( type (
clientPacketConnFuncFactory func(obfsPassword string, hopInterval time.Duration, resolveFunc ResolveFunc) pktconns.ClientPacketConnFunc clientPacketConnFuncFactory func(obfsPassword string, hopInterval time.Duration,
resolveFunc ResolveFunc, listenUDPFunc ListenUDPFunc) pktconns.ClientPacketConnFunc
) )
var clientPacketConnFuncFactoryMap = map[Protocol]clientPacketConnFuncFactory{ var clientPacketConnFuncFactoryMap = map[Protocol]clientPacketConnFuncFactory{
@ -22,7 +23,9 @@ var clientPacketConnFuncFactoryMap = map[Protocol]clientPacketConnFuncFactory{
ProtocolFakeTCP: newClientFakeTCPConnFunc, ProtocolFakeTCP: newClientFakeTCPConnFunc,
} }
func newClientUDPConnFunc(obfsPassword string, hopInterval time.Duration, resolveFunc ResolveFunc) pktconns.ClientPacketConnFunc { func newClientUDPConnFunc(obfsPassword string, hopInterval time.Duration,
resolveFunc ResolveFunc, listenUDPFunc ListenUDPFunc,
) pktconns.ClientPacketConnFunc {
if obfsPassword == "" { if obfsPassword == "" {
return func(server string) (net.PacketConn, net.Addr, error) { return func(server string) (net.PacketConn, net.Addr, error) {
if isMultiPortAddr(server) { if isMultiPortAddr(server) {
@ -32,7 +35,7 @@ func newClientUDPConnFunc(obfsPassword string, hopInterval time.Duration, resolv
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
udpConn, err := net.ListenUDP("udp", nil) udpConn, err := listenUDPFunc("udp", nil)
return udpConn, sAddr, err return udpConn, sAddr, err
} }
} else { } else {
@ -45,7 +48,7 @@ func newClientUDPConnFunc(obfsPassword string, hopInterval time.Duration, resolv
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
udpConn, err := net.ListenUDP("udp", nil) udpConn, err := listenUDPFunc("udp", nil)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -55,14 +58,16 @@ func newClientUDPConnFunc(obfsPassword string, hopInterval time.Duration, resolv
} }
} }
func newClientWeChatConnFunc(obfsPassword string, hopInterval time.Duration, resolveFunc ResolveFunc) pktconns.ClientPacketConnFunc { func newClientWeChatConnFunc(obfsPassword string, hopInterval time.Duration,
resolveFunc ResolveFunc, listenUDPFunc ListenUDPFunc,
) pktconns.ClientPacketConnFunc {
if obfsPassword == "" { if obfsPassword == "" {
return func(server string) (net.PacketConn, net.Addr, error) { return func(server string) (net.PacketConn, net.Addr, error) {
sAddr, err := resolveFunc("udp", server) sAddr, err := resolveFunc("udp", server)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
udpConn, err := net.ListenUDP("udp", nil) udpConn, err := listenUDPFunc("udp", nil)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -74,7 +79,7 @@ func newClientWeChatConnFunc(obfsPassword string, hopInterval time.Duration, res
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
udpConn, err := net.ListenUDP("udp", nil) udpConn, err := listenUDPFunc("udp", nil)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -84,7 +89,9 @@ func newClientWeChatConnFunc(obfsPassword string, hopInterval time.Duration, res
} }
} }
func newClientFakeTCPConnFunc(obfsPassword string, hopInterval time.Duration, resolveFunc ResolveFunc) pktconns.ClientPacketConnFunc { func newClientFakeTCPConnFunc(obfsPassword string, hopInterval time.Duration,
resolveFunc ResolveFunc, listenUDPFunc ListenUDPFunc,
) pktconns.ClientPacketConnFunc {
if obfsPassword == "" { if obfsPassword == "" {
return func(server string) (net.PacketConn, net.Addr, error) { return func(server string) (net.PacketConn, net.Addr, error) {
sAddr, err := resolveFunc("tcp", server) sAddr, err := resolveFunc("tcp", server)