diff --git a/common/bufio/conn.go b/common/bufio/conn.go index ffa83ea..fa5d157 100644 --- a/common/bufio/conn.go +++ b/common/bufio/conn.go @@ -354,7 +354,7 @@ func (w *ExtendedUDPConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { func (w *ExtendedUDPConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { defer buffer.Release() - if destination.Family().IsFqdn() { + if destination.IsFqdn() { udpAddr, err := net.ResolveUDPAddr("udp", destination.String()) if err != nil { return err diff --git a/common/metadata/addr.go b/common/metadata/addr.go index f1a2b57..eb17a95 100644 --- a/common/metadata/addr.go +++ b/common/metadata/addr.go @@ -4,12 +4,13 @@ import ( "net" "net/netip" "strconv" + "unsafe" ) type Socksaddr struct { Addr netip.Addr - Fqdn string Port uint16 + Fqdn string } func (ap Socksaddr) Network() string { @@ -20,25 +21,22 @@ func (ap Socksaddr) IsIP() bool { return ap.Addr.IsValid() } +func (ap Socksaddr) IsIPv4() bool { + return ap.Addr.Is4() +} + +func (ap Socksaddr) IsIPv6() bool { + return ap.Addr.Is6() +} + func (ap Socksaddr) IsFqdn() bool { - return !ap.IsIP() + return !ap.Addr.IsValid() } func (ap Socksaddr) IsValid() bool { return ap.Addr.IsValid() || ap.Fqdn != "" } -func (ap Socksaddr) Family() Family { - if ap.Addr.IsValid() { - if ap.Addr.Is4() || ap.Addr.Is4In6() { - return AddressFamilyIPv4 - } else { - return AddressFamilyIPv6 - } - } - return AddressFamilyFqdn -} - func (ap Socksaddr) AddrString() string { if ap.Addr.IsValid() { return ap.Addr.String() @@ -68,7 +66,7 @@ func (ap Socksaddr) UDPAddr() *net.UDPAddr { } func (ap Socksaddr) AddrPort() netip.AddrPort { - return netip.AddrPortFrom(ap.Addr, ap.Port) + return *(*netip.AddrPort)(unsafe.Pointer(&ap)) } func (ap Socksaddr) String() string { diff --git a/common/metadata/family.go b/common/metadata/family.go index 80011ee..c6551c2 100644 --- a/common/metadata/family.go +++ b/common/metadata/family.go @@ -1,27 +1,9 @@ package metadata -type Family byte +type Family = byte const ( - AddressFamilyIPv4 Family = iota - AddressFamilyIPv6 - AddressFamilyFqdn + AddressFamilyIPv4 Family = 0x01 + AddressFamilyIPv6 Family = 0x04 + AddressFamilyFqdn Family = 0x03 ) - -func (af Family) IsIPv4() bool { - return af == AddressFamilyIPv4 -} - -func (af Family) IsIPv6() bool { - return af == AddressFamilyIPv6 -} - -func (af Family) IsIP() bool { - return af != AddressFamilyFqdn -} - -func (af Family) IsFqdn() bool { - return af == AddressFamilyFqdn -} - -type FamilyParser func(byte) byte diff --git a/common/metadata/network.go b/common/metadata/network.go index df5f319..c171f7c 100644 --- a/common/metadata/network.go +++ b/common/metadata/network.go @@ -3,7 +3,7 @@ package metadata import "net/netip" func NetworkFromNetAddr(network string, addr netip.Addr) string { - if addr.Is4() && (addr.IsUnspecified() || addr.IsGlobalUnicast() || addr.IsLinkLocalUnicast()) { + if addr == netip.IPv4Unspecified() { return network + "4" } return network diff --git a/common/metadata/serializer.go b/common/metadata/serializer.go index 4c3dc37..1e7f934 100644 --- a/common/metadata/serializer.go +++ b/common/metadata/serializer.go @@ -43,7 +43,15 @@ func NewSerializer(options ...SerializerOption) *Serializer { } func (s *Serializer) WriteAddress(writer io.Writer, addr Socksaddr) error { - err := rw.WriteByte(writer, s.familyByteMap[addr.Family()]) + var family Family + if addr.IsIPv4() { + family = AddressFamilyIPv4 + } else if addr.IsIPv6() { + family = AddressFamilyIPv6 + } else { + family = AddressFamilyFqdn + } + err := rw.WriteByte(writer, family) if err != nil { return err } @@ -56,12 +64,11 @@ func (s *Serializer) WriteAddress(writer io.Writer, addr Socksaddr) error { } func (s *Serializer) AddressLen(addr Socksaddr) int { - switch addr.Family() { - case AddressFamilyIPv4: + if addr.IsIPv4() { return 5 - case AddressFamilyIPv6: + } else if addr.IsIPv6() { return 17 - default: + } else { return 2 + len(addr.Fqdn) } } diff --git a/common/uot/server.go b/common/uot/server.go index f5ff4e8..efd0c6e 100644 --- a/common/uot/server.go +++ b/common/uot/server.go @@ -54,7 +54,7 @@ func (c *ServerConn) loopInput() { if err != nil { break } - if destination.Family().IsFqdn() { + if destination.IsFqdn() { addr, err := net.ResolveUDPAddr("udp", destination.String()) if err != nil { continue diff --git a/protocol/socks/client.go b/protocol/socks/client.go index 82b6a10..8ac7ab8 100644 --- a/protocol/socks/client.go +++ b/protocol/socks/client.go @@ -113,7 +113,7 @@ func (c *Client) DialContext(ctx context.Context, network string, address M.Sock if err != nil { return nil, err } - if c.version == Version4 && address.Family().IsFqdn() { + if c.version == Version4 && address.IsFqdn() { tcpAddr, err := net.ResolveTCPAddr(network, address.String()) if err != nil { tcpConn.Close() diff --git a/protocol/socks/handshake.go b/protocol/socks/handshake.go index 6cc891d..5908a59 100644 --- a/protocol/socks/handshake.go +++ b/protocol/socks/handshake.go @@ -111,7 +111,7 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent switch request.Command { case socks4.CommandConnect: responseAddr := request.Destination - if !responseAddr.Family().IsIPv4() { + if !responseAddr.IsIPv4() { responseAddr = M.SocksaddrFromAddrPort(netip.IPv4Unspecified(), responseAddr.Port) } err = socks4.WriteResponse(conn, socks4.Response{ diff --git a/protocol/socks/socks4/protocol.go b/protocol/socks/socks4/protocol.go index 6d7fc8a..23d507b 100644 --- a/protocol/socks/socks4/protocol.go +++ b/protocol/socks/socks4/protocol.go @@ -81,7 +81,7 @@ func WriteRequest(writer io.Writer, request Request) error { if err != nil { return err } - if request.Destination.Family().IsIPv4() { + if request.Destination.IsIPv4() { dstIP := request.Destination.Addr.As4() _, err = writer.Write(dstIP[:]) if err != nil {