Add network name constants

This commit is contained in:
世界 2022-07-29 19:48:14 +08:00
parent 6045339c12
commit c922e42771
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
4 changed files with 43 additions and 12 deletions

26
common/network/name.go Normal file
View file

@ -0,0 +1,26 @@
package network
import (
"strings"
E "github.com/sagernet/sing/common/exceptions"
)
var ErrUnknownNetwork = E.New("unknown network")
//goland:noinspection GoNameStartsWithPackageName
const (
NetworkTCP = "tcp"
NetworkUDP = "udp"
)
//goland:noinspection GoNameStartsWithPackageName
func NetworkName(network string) string {
if strings.HasPrefix(network, "tcp") {
return NetworkTCP
} else if strings.HasPrefix(network, "udp") {
return NetworkUDP
} else {
return network
}
}

View file

@ -8,7 +8,6 @@ import (
"net/http"
"net/url"
"os"
"strings"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
@ -36,11 +35,16 @@ func NewClient(dialer N.Dialer, serverAddr M.Socksaddr, username string, passwor
}
func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if !strings.HasPrefix(network, "tcp") {
network = N.NetworkName(network)
switch network {
case N.NetworkTCP:
case N.NetworkUDP:
return nil, os.ErrInvalid
default:
return nil, E.Extend(N.ErrUnknownNetwork, network)
}
var conn net.Conn
conn, err := c.dialer.DialContext(ctx, "tcp", c.serverAddr)
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.serverAddr)
if err != nil {
return nil, err
}

View file

@ -97,9 +97,6 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DialContext: func(context context.Context, network, address string) (net.Conn, error) {
if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil, E.New("unsupported network ", network)
}
metadata.Destination = M.ParseSocksaddr(address)
metadata.Protocol = "http"
left, right := net.Pipe()

View file

@ -100,16 +100,20 @@ func NewClientFromURL(dialer N.Dialer, rawURL string) (*Client, error) {
}
func (c *Client) DialContext(ctx context.Context, network string, address M.Socksaddr) (net.Conn, error) {
network = N.NetworkName(network)
var command byte
if strings.HasPrefix(network, "tcp") {
switch network {
case N.NetworkTCP:
command = socks4.CommandConnect
} else {
case N.NetworkUDP:
if c.version != Version5 {
return nil, E.New("socks4: udp unsupported")
}
command = socks5.CommandUDPAssociate
default:
return nil, E.Extend(N.ErrUnknownNetwork, network)
}
tcpConn, err := c.dialer.DialContext(ctx, "tcp", c.serverAddr)
tcpConn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.serverAddr)
if err != nil {
return nil, err
}
@ -138,7 +142,7 @@ func (c *Client) DialContext(ctx context.Context, network string, address M.Sock
if command == socks5.CommandConnect {
return tcpConn, nil
}
udpConn, err := c.dialer.DialContext(ctx, "udp", response.Bind)
udpConn, err := c.dialer.DialContext(ctx, N.NetworkUDP, response.Bind)
if err != nil {
tcpConn.Close()
return nil, err
@ -149,7 +153,7 @@ func (c *Client) DialContext(ctx context.Context, network string, address M.Sock
}
func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
conn, err := c.DialContext(ctx, "udp", destination)
conn, err := c.DialContext(ctx, N.NetworkUDP, destination)
if err != nil {
return nil, err
}
@ -157,7 +161,7 @@ func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
}
func (c *Client) BindContext(ctx context.Context, address M.Socksaddr) (net.Conn, error) {
tcpConn, err := c.dialer.DialContext(ctx, "tcp", c.serverAddr)
tcpConn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.serverAddr)
if err != nil {
return nil, err
}