Merge pull request #953 from apernet/wip-udphop-listenudpfunc

feat: allow set ListenUDP impl for udphop conn
This commit is contained in:
Toby 2024-02-29 16:17:40 -08:00 committed by GitHub
commit 982be5498b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 24 additions and 15 deletions

View file

@ -179,7 +179,7 @@ func (c *clientConfig) fillConnFactory(hyConfig *client.Config) error {
if hyConfig.ServerAddr.Network() == "udphop" {
hopAddr := hyConfig.ServerAddr.(*udphop.UDPHopAddr)
newFunc = func(addr net.Addr) (net.PacketConn, error) {
return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval)
return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval, nil)
}
} else {
newFunc = func(addr net.Addr) (net.PacketConn, error) {

View file

@ -17,9 +17,10 @@ const (
)
type udpHopPacketConn struct {
Addr net.Addr
Addrs []net.Addr
HopInterval time.Duration
Addr net.Addr
Addrs []net.Addr
HopInterval time.Duration
ListenUDPFunc ListenUDPFunc
connMutex sync.RWMutex
prevConn net.PacketConn
@ -43,29 +44,37 @@ type udpPacket struct {
Err error
}
func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration) (net.PacketConn, error) {
type ListenUDPFunc func() (net.PacketConn, error)
func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration, listenUDPFunc ListenUDPFunc) (net.PacketConn, error) {
if hopInterval == 0 {
hopInterval = defaultHopInterval
} else if hopInterval < 5*time.Second {
return nil, errors.New("hop interval must be at least 5 seconds")
}
if listenUDPFunc == nil {
listenUDPFunc = func() (net.PacketConn, error) {
return net.ListenUDP("udp", nil)
}
}
addrs, err := addr.addrs()
if err != nil {
return nil, err
}
curConn, err := net.ListenUDP("udp", nil)
curConn, err := listenUDPFunc()
if err != nil {
return nil, err
}
hConn := &udpHopPacketConn{
Addr: addr,
Addrs: addrs,
HopInterval: hopInterval,
prevConn: nil,
currentConn: curConn,
addrIndex: rand.Intn(len(addrs)),
recvQueue: make(chan *udpPacket, packetQueueSize),
closeChan: make(chan struct{}),
Addr: addr,
Addrs: addrs,
HopInterval: hopInterval,
ListenUDPFunc: listenUDPFunc,
prevConn: nil,
currentConn: curConn,
addrIndex: rand.Intn(len(addrs)),
recvQueue: make(chan *udpPacket, packetQueueSize),
closeChan: make(chan struct{}),
bufPool: sync.Pool{
New: func() interface{} {
return make([]byte, udpBufferSize)
@ -121,7 +130,7 @@ func (u *udpHopPacketConn) hop() {
if u.closed {
return
}
newConn, err := net.ListenUDP("udp", nil)
newConn, err := u.ListenUDPFunc()
if err != nil {
// Could be temporary, just skip this hop
return