Merge pull request #953 from apernet/wip-udphop-listenudpfunc

feat: allow set ListenUDP impl for udphop conn
This commit is contained in:
Toby 2024-02-29 16:17:40 -08:00 committed by GitHub
commit 982be5498b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 24 additions and 15 deletions

View file

@ -179,7 +179,7 @@ func (c *clientConfig) fillConnFactory(hyConfig *client.Config) error {
if hyConfig.ServerAddr.Network() == "udphop" {
hopAddr := hyConfig.ServerAddr.(*udphop.UDPHopAddr)
newFunc = func(addr net.Addr) (net.PacketConn, error) {
return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval)
return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval, nil)
}
} else {
newFunc = func(addr net.Addr) (net.PacketConn, error) {

View file

@ -20,6 +20,7 @@ type udpHopPacketConn struct {
Addr net.Addr
Addrs []net.Addr
HopInterval time.Duration
ListenUDPFunc ListenUDPFunc
connMutex sync.RWMutex
prevConn net.PacketConn
@ -43,17 +44,24 @@ type udpPacket struct {
Err error
}
func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration) (net.PacketConn, error) {
type ListenUDPFunc func() (net.PacketConn, error)
func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration, listenUDPFunc ListenUDPFunc) (net.PacketConn, error) {
if hopInterval == 0 {
hopInterval = defaultHopInterval
} else if hopInterval < 5*time.Second {
return nil, errors.New("hop interval must be at least 5 seconds")
}
if listenUDPFunc == nil {
listenUDPFunc = func() (net.PacketConn, error) {
return net.ListenUDP("udp", nil)
}
}
addrs, err := addr.addrs()
if err != nil {
return nil, err
}
curConn, err := net.ListenUDP("udp", nil)
curConn, err := listenUDPFunc()
if err != nil {
return nil, err
}
@ -61,6 +69,7 @@ func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration) (net.Packe
Addr: addr,
Addrs: addrs,
HopInterval: hopInterval,
ListenUDPFunc: listenUDPFunc,
prevConn: nil,
currentConn: curConn,
addrIndex: rand.Intn(len(addrs)),
@ -121,7 +130,7 @@ func (u *udpHopPacketConn) hop() {
if u.closed {
return
}
newConn, err := net.ListenUDP("udp", nil)
newConn, err := u.ListenUDPFunc()
if err != nil {
// Could be temporary, just skip this hop
return