diff --git a/common/network/name.go b/common/network/name.go new file mode 100644 index 0000000..28a745b --- /dev/null +++ b/common/network/name.go @@ -0,0 +1,26 @@ +package network + +import ( + "strings" + + E "github.com/sagernet/sing/common/exceptions" +) + +var ErrUnknownNetwork = E.New("unknown network") + +//goland:noinspection GoNameStartsWithPackageName +const ( + NetworkTCP = "tcp" + NetworkUDP = "udp" +) + +//goland:noinspection GoNameStartsWithPackageName +func NetworkName(network string) string { + if strings.HasPrefix(network, "tcp") { + return NetworkTCP + } else if strings.HasPrefix(network, "udp") { + return NetworkUDP + } else { + return network + } +} diff --git a/protocol/http/client.go b/protocol/http/client.go index cbca5d0..8277a65 100644 --- a/protocol/http/client.go +++ b/protocol/http/client.go @@ -8,7 +8,6 @@ import ( "net/http" "net/url" "os" - "strings" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" @@ -36,11 +35,16 @@ func NewClient(dialer N.Dialer, serverAddr M.Socksaddr, username string, passwor } func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { - if !strings.HasPrefix(network, "tcp") { + network = N.NetworkName(network) + switch network { + case N.NetworkTCP: + case N.NetworkUDP: return nil, os.ErrInvalid + default: + return nil, E.Extend(N.ErrUnknownNetwork, network) } var conn net.Conn - conn, err := c.dialer.DialContext(ctx, "tcp", c.serverAddr) + conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.serverAddr) if err != nil { return nil, err } diff --git a/protocol/http/handshake.go b/protocol/http/handshake.go index 0a90717..516e3be 100644 --- a/protocol/http/handshake.go +++ b/protocol/http/handshake.go @@ -97,9 +97,6 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, DialContext: func(context context.Context, network, address string) (net.Conn, error) { - if network != "tcp" && network != "tcp4" && network != "tcp6" { - return nil, E.New("unsupported network ", network) - } metadata.Destination = M.ParseSocksaddr(address) metadata.Protocol = "http" left, right := net.Pipe() diff --git a/protocol/socks/client.go b/protocol/socks/client.go index 8ac7ab8..45c0d90 100644 --- a/protocol/socks/client.go +++ b/protocol/socks/client.go @@ -100,16 +100,20 @@ func NewClientFromURL(dialer N.Dialer, rawURL string) (*Client, error) { } func (c *Client) DialContext(ctx context.Context, network string, address M.Socksaddr) (net.Conn, error) { + network = N.NetworkName(network) var command byte - if strings.HasPrefix(network, "tcp") { + switch network { + case N.NetworkTCP: command = socks4.CommandConnect - } else { + case N.NetworkUDP: if c.version != Version5 { return nil, E.New("socks4: udp unsupported") } command = socks5.CommandUDPAssociate + default: + return nil, E.Extend(N.ErrUnknownNetwork, network) } - tcpConn, err := c.dialer.DialContext(ctx, "tcp", c.serverAddr) + tcpConn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.serverAddr) if err != nil { return nil, err } @@ -138,7 +142,7 @@ func (c *Client) DialContext(ctx context.Context, network string, address M.Sock if command == socks5.CommandConnect { return tcpConn, nil } - udpConn, err := c.dialer.DialContext(ctx, "udp", response.Bind) + udpConn, err := c.dialer.DialContext(ctx, N.NetworkUDP, response.Bind) if err != nil { tcpConn.Close() return nil, err @@ -149,7 +153,7 @@ func (c *Client) DialContext(ctx context.Context, network string, address M.Sock } func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - conn, err := c.DialContext(ctx, "udp", destination) + conn, err := c.DialContext(ctx, N.NetworkUDP, destination) if err != nil { return nil, err } @@ -157,7 +161,7 @@ func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net } func (c *Client) BindContext(ctx context.Context, address M.Socksaddr) (net.Conn, error) { - tcpConn, err := c.dialer.DialContext(ctx, "tcp", c.serverAddr) + tcpConn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.serverAddr) if err != nil { return nil, err }