feat: fix outbound

This commit is contained in:
Toby 2023-07-27 19:24:43 -07:00
parent d4e3833641
commit e381c2eae8
3 changed files with 113 additions and 16 deletions

View file

@ -29,7 +29,6 @@ type directOutbound struct {
DeviceName string // For UDP binding DeviceName string // For UDP binding
} }
/*
// NewDirectOutboundSimple creates a new directOutbound with the given mode, // NewDirectOutboundSimple creates a new directOutbound with the given mode,
// without binding to a specific device. Works on all platforms. // without binding to a specific device. Works on all platforms.
func NewDirectOutboundSimple(mode DirectOutboundMode) PluggableOutbound { func NewDirectOutboundSimple(mode DirectOutboundMode) PluggableOutbound {
@ -40,7 +39,6 @@ func NewDirectOutboundSimple(mode DirectOutboundMode) PluggableOutbound {
}, },
} }
} }
*/
// resolve is our built-in DNS resolver for handling the case when // resolve is our built-in DNS resolver for handling the case when
// AddrEx.ResolveInfo is nil. // AddrEx.ResolveInfo is nil.
@ -51,23 +49,14 @@ func (d *directOutbound) resolve(reqAddr *AddrEx) {
return return
} }
r := &ResolveInfo{} r := &ResolveInfo{}
for _, ip := range ips { r.IPv4, r.IPv6 = splitIPv4IPv6(ips)
if r.IPv4 == nil && ip.To4() != nil { if r.IPv4 == nil && r.IPv6 == nil {
r.IPv4 = ip r.Err = errors.New("no IPv4 or IPv6 address available")
}
if r.IPv6 == nil && ip.To4() == nil {
// We must NOT use ip.To16() here because it will always
// return a 16-byte slice, even if the original IP is IPv4.
r.IPv6 = ip
}
if r.IPv4 != nil && r.IPv6 != nil {
break
}
} }
reqAddr.ResolveInfo = r reqAddr.ResolveInfo = r
} }
func (d *directOutbound) DialTCP(reqAddr *AddrEx) (net.Conn, error) { func (d *directOutbound) TCP(reqAddr *AddrEx) (net.Conn, error) {
if reqAddr.ResolveInfo == nil { if reqAddr.ResolveInfo == nil {
// AddrEx.ResolveInfo is nil (no resolver in the pipeline), // AddrEx.ResolveInfo is nil (no resolver in the pipeline),
// we need to resolve the address ourselves. // we need to resolve the address ourselves.
@ -252,7 +241,7 @@ func (u *directOutboundUDPConn) Close() error {
return u.UDPConn.Close() return u.UDPConn.Close()
} }
func (d *directOutbound) ListenUDP() (UDPConn, error) { func (d *directOutbound) UDP(reqAddr *AddrEx) (UDPConn, error) {
c, err := net.ListenUDP("udp", nil) c, err := net.ListenUDP("udp", nil)
if err != nil { if err != nil {
return nil, err return nil, err

24
extras/outbounds/utils.go Normal file
View file

@ -0,0 +1,24 @@
package outbounds
import "net"
// splitIPv4IPv6 gets the first IPv4 and IPv6 address from a list of IP addresses.
// Both of the return values can be nil when no IPv4 or IPv6 address is found.
func splitIPv4IPv6(ips []net.IP) (ipv4, ipv6 net.IP) {
for _, ip := range ips {
if ip.To4() != nil {
if ipv4 == nil {
ipv4 = ip
}
} else {
if ipv6 == nil {
ipv6 = ip
}
}
if ipv4 != nil && ipv6 != nil {
// We have everything we need.
break
}
}
return
}

View file

@ -0,0 +1,84 @@
package outbounds
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestSplitIPv4IPv6(t *testing.T) {
type args struct {
ips []net.IP
}
tests := []struct {
name string
args args
wantIpv4 net.IP
wantIpv6 net.IP
}{
{
name: "IPv4 only",
args: args{
ips: []net.IP{
net.ParseIP("4.5.6.7"),
net.ParseIP("9.9.9.9"),
},
},
wantIpv4: net.ParseIP("4.5.6.7"),
wantIpv6: nil,
},
{
name: "IPv6 only",
args: args{
ips: []net.IP{
net.ParseIP("2001:db8::68"),
net.ParseIP("2001:db8::69"),
},
},
wantIpv4: nil,
wantIpv6: net.ParseIP("2001:db8::68"),
},
{
name: "Both 1",
args: args{
ips: []net.IP{
net.ParseIP("2001:db8::68"),
net.ParseIP("2001:db8::69"),
net.ParseIP("4.5.6.7"),
net.ParseIP("9.9.9.9"),
},
},
wantIpv4: net.ParseIP("4.5.6.7"),
wantIpv6: net.ParseIP("2001:db8::68"),
},
{
name: "Both 2",
args: args{
ips: []net.IP{
net.ParseIP("2001:db8::69"),
net.ParseIP("9.9.9.9"),
net.ParseIP("2001:db8::68"),
net.ParseIP("4.5.6.7"),
},
},
wantIpv4: net.ParseIP("9.9.9.9"),
wantIpv6: net.ParseIP("2001:db8::69"),
},
{
name: "Empty",
args: args{
ips: []net.IP{},
},
wantIpv4: nil,
wantIpv6: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotIpv4, gotIpv6 := splitIPv4IPv6(tt.args.ips)
assert.Equalf(t, tt.wantIpv4, gotIpv4, "splitIPv4IPv6(%v)", tt.args.ips)
assert.Equalf(t, tt.wantIpv6, gotIpv6, "splitIPv4IPv6(%v)", tt.args.ips)
})
}
}