From 8311d6e9709c974877320cd178759153684c2037 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 22 Jul 2022 16:09:32 +0800 Subject: [PATCH] Improve 4in6 processing (break change) --- common/metadata/addr.go | 41 ++++++++++++----------------------- common/metadata/serializer.go | 9 +++----- common/network/multi.go | 4 ++-- protocol/socks/handshake.go | 2 +- 4 files changed, 20 insertions(+), 36 deletions(-) diff --git a/common/metadata/addr.go b/common/metadata/addr.go index eb17a95..494b09f 100644 --- a/common/metadata/addr.go +++ b/common/metadata/addr.go @@ -22,11 +22,21 @@ func (ap Socksaddr) IsIP() bool { } func (ap Socksaddr) IsIPv4() bool { - return ap.Addr.Is4() + return ap.Addr.Is4() || ap.Addr.Is4In6() } func (ap Socksaddr) IsIPv6() bool { - return ap.Addr.Is6() + return ap.Addr.Is6() && !ap.Addr.Is4In6() +} + +func (ap Socksaddr) Unwrap() Socksaddr { + if ap.Addr.Is4In6() { + return Socksaddr{ + Addr: netip.AddrFrom4(ap.Addr.As4()), + Port: ap.Port, + } + } + return ap } func (ap Socksaddr) IsFqdn() bool { @@ -88,25 +98,14 @@ func UDPAddr(ap netip.AddrPort) *net.UDPAddr { } func AddrPortFrom(ip net.IP, port uint16) netip.AddrPort { - addr, _ := netip.AddrFromSlice(ip) - return netip.AddrPortFrom(addr, port) + return netip.AddrPortFrom(AddrFromIP(ip), port) } -func SocksaddrFrom(ip net.IP, port uint16) Socksaddr { - return SocksaddrFromNetIP(AddrPortFrom(ip, port)) -} - -func SocksaddrFromAddrPort(addr netip.Addr, port uint16) Socksaddr { +func SocksaddrFrom(addr netip.Addr, port uint16) Socksaddr { return SocksaddrFromNetIP(netip.AddrPortFrom(addr, port)) } func SocksaddrFromNetIP(ap netip.AddrPort) Socksaddr { - if ap.Addr().Is4In6() { - return Socksaddr{ - Addr: netip.AddrFrom4(ap.Addr().As4()), - Port: ap.Port(), - } - } return Socksaddr{ Addr: ap.Addr(), Port: ap.Port(), @@ -163,17 +162,11 @@ func AddrPortFromNet(netAddr net.Addr) netip.AddrPort { func AddrFromIP(ip net.IP) netip.Addr { addr, _ := netip.AddrFromSlice(ip) - if addr.Is4In6() { - addr = netip.AddrFrom4(addr.As4()) - } return addr } func ParseAddr(s string) netip.Addr { addr, _ := netip.ParseAddr(s) - if addr.Is4In6() { - addr = netip.AddrFrom4(addr.As4()) - } return addr } @@ -187,9 +180,6 @@ func ParseSocksaddr(address string) Socksaddr { func ParseSocksaddrHostPort(host string, port uint16) Socksaddr { netAddr, err := netip.ParseAddr(host) - if netAddr.Is4In6() { - netAddr = netip.AddrFrom4(netAddr.As4()) - } if err != nil { return Socksaddr{ Fqdn: host, @@ -206,9 +196,6 @@ func ParseSocksaddrHostPort(host string, port uint16) Socksaddr { func ParseSocksaddrHostPortStr(host string, portStr string) Socksaddr { port, _ := strconv.Atoi(portStr) netAddr, err := netip.ParseAddr(host) - if netAddr.Is4In6() { - netAddr = netip.AddrFrom4(netAddr.As4()) - } if err != nil { return Socksaddr{ Fqdn: host, diff --git a/common/metadata/serializer.go b/common/metadata/serializer.go index 1e7f934..6f7f54b 100644 --- a/common/metadata/serializer.go +++ b/common/metadata/serializer.go @@ -56,7 +56,7 @@ func (s *Serializer) WriteAddress(writer io.Writer, addr Socksaddr) error { return err } if addr.Addr.IsValid() { - err = rw.WriteBytes(writer, addr.Addr.AsSlice()) + err = rw.WriteBytes(writer, addr.Unwrap().Addr.AsSlice()) } else { err = WriteSocksString(writer, addr.Fqdn) } @@ -129,11 +129,8 @@ func (s *Serializer) ReadAddress(reader io.Reader) (Socksaddr, error) { if err != nil { return Socksaddr{}, E.Cause(err, "read ipv6 address") } - netAddr := netip.AddrFrom16(addr) - if netAddr.Is4In6() { - netAddr = netip.AddrFrom4(netAddr.As4()) - } - return Socksaddr{Addr: netAddr}, nil + + return Socksaddr{Addr: netip.AddrFrom16(addr)}.Unwrap(), nil default: return Socksaddr{}, E.New("unknown address family: ", af) } diff --git a/common/network/multi.go b/common/network/multi.go index b2afc11..2649cd0 100644 --- a/common/network/multi.go +++ b/common/network/multi.go @@ -18,7 +18,7 @@ func DialSerial(ctx context.Context, dialer Dialer, network string, destination var err error var connErrors []error for _, address := range destinationAddresses { - conn, err = dialer.DialContext(ctx, network, M.SocksaddrFromAddrPort(address, destination.Port)) + conn, err = dialer.DialContext(ctx, network, M.SocksaddrFrom(address, destination.Port)) if err != nil { connErrors = append(connErrors, err) continue @@ -33,7 +33,7 @@ func ListenSerial(ctx context.Context, dialer Dialer, destination M.Socksaddr, d var err error var connErrors []error for _, address := range destinationAddresses { - conn, err = dialer.ListenPacket(ctx, M.SocksaddrFromAddrPort(address, destination.Port)) + conn, err = dialer.ListenPacket(ctx, M.SocksaddrFrom(address, destination.Port)) if err != nil { connErrors = append(connErrors, err) continue diff --git a/protocol/socks/handshake.go b/protocol/socks/handshake.go index 5908a59..81869fb 100644 --- a/protocol/socks/handshake.go +++ b/protocol/socks/handshake.go @@ -112,7 +112,7 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent case socks4.CommandConnect: responseAddr := request.Destination if !responseAddr.IsIPv4() { - responseAddr = M.SocksaddrFromAddrPort(netip.IPv4Unspecified(), responseAddr.Port) + responseAddr = M.SocksaddrFrom(netip.IPv4Unspecified(), responseAddr.Port) } err = socks4.WriteResponse(conn, socks4.Response{ ReplyCode: socks4.ReplyCodeGranted,