Refactor socksaddr

This commit is contained in:
世界 2022-05-04 19:12:27 +08:00
parent 9378ae739c
commit b35c53ca8f
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
54 changed files with 1191 additions and 666 deletions

View file

@ -45,13 +45,8 @@ func HandleRequest(ctx context.Context, request *http.Request, conn net.Conn, au
if portStr == "" {
portStr = "80"
}
destination, err := M.ParseAddrPort(request.URL.Hostname(), portStr)
if err != nil {
if err != nil {
return err
}
}
_, err = fmt.Fprintf(conn, "HTTP/%d.%d %03d %s\r\n\r\n", request.ProtoMajor, request.ProtoMinor, http.StatusOK, "Connection established")
destination := M.ParseSocksaddrHostPort(request.URL.Hostname(), portStr)
_, err := fmt.Fprintf(conn, "HTTP/%d.%d %03d %s\r\n\r\n", request.ProtoMajor, request.ProtoMinor, http.StatusOK, "Connection established")
if err != nil {
return E.Cause(err, "write http response")
}
@ -87,17 +82,11 @@ func HandleRequest(ctx context.Context, request *http.Request, conn net.Conn, au
if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil, E.New("unsupported network ", network)
}
destination, err := M.ParseAddress(address)
if err != nil {
return nil, err
}
metadata.Destination = M.ParseSocksaddr(address)
metadata.Protocol = "http"
left, right := net.Pipe()
go func() {
metadata.Destination = destination
metadata.Protocol = "http"
err = handler.NewConnection(ctx, right, metadata)
err := handler.NewConnection(ctx, right, metadata)
if err != nil {
handler.HandleError(&tcp.Error{Conn: right, Cause: err})
}

View file

@ -10,8 +10,9 @@ import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/protocol/socks5"
)
const MethodNone = "none"
@ -30,7 +31,7 @@ func (m *NoneMethod) KeyLength() int {
return 0
}
func (m *NoneMethod) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
func (m *NoneMethod) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
shadowsocksConn := &noneConn{
Conn: conn,
handshake: true,
@ -39,14 +40,14 @@ func (m *NoneMethod) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn,
return shadowsocksConn, shadowsocksConn.clientHandshake()
}
func (m *NoneMethod) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn {
func (m *NoneMethod) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
return &noneConn{
Conn: conn,
destination: destination,
}
}
func (m *NoneMethod) DialPacketConn(conn net.Conn) socks.PacketConn {
func (m *NoneMethod) DialPacketConn(conn net.Conn) N.PacketConn {
return &nonePacketConn{conn}
}
@ -55,11 +56,11 @@ type noneConn struct {
access sync.Mutex
handshake bool
destination *M.AddrPort
destination M.Socksaddr
}
func (c *noneConn) clientHandshake() error {
err := socks.AddressSerializer.WriteAddrPort(c.Conn, c.destination)
err := socks5.AddressSerializer.WriteAddrPort(c.Conn, c.destination)
if err != nil {
return err
}
@ -87,7 +88,7 @@ func (c *noneConn) Write(b []byte) (n int, err error) {
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
err = socks.AddressSerializer.WriteAddrPort(buffer, c.destination)
err = socks5.AddressSerializer.WriteAddrPort(buffer, c.destination)
if err != nil {
return
}
@ -132,19 +133,19 @@ type nonePacketConn struct {
net.Conn
}
func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
_, err := buffer.ReadFrom(c)
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
return socks.AddressSerializer.ReadAddrPort(buffer)
return socks5.AddressSerializer.ReadAddrPort(buffer)
}
func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, addrPort M.Socksaddr) error {
defer buffer.Release()
_header := buf.StackNewMax()
header := common.Dup(_header)
err := socks.AddressSerializer.WriteAddrPort(header, addrPort)
err := socks5.AddressSerializer.WriteAddrPort(header, addrPort)
if err != nil {
header.Release()
return err
@ -167,7 +168,7 @@ func NewNoneService(udpTimeout int64, handler Handler) Service {
}
func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
destination, err := socks.AddressSerializer.ReadAddrPort(conn)
destination, err := socks5.AddressSerializer.ReadAddrPort(conn)
if err != nil {
return err
}
@ -176,34 +177,34 @@ func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata
return s.handler.NewConnection(ctx, conn, metadata)
}
func (s *NoneService) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
destination, err := socks.AddressSerializer.ReadAddrPort(buffer)
func (s *NoneService) NewPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
destination, err := socks5.AddressSerializer.ReadAddrPort(buffer)
if err != nil {
return err
}
metadata.Protocol = "shadowsocks"
metadata.Destination = destination
s.udp.NewPacket(metadata.Source.AddrPort(), func() socks.PacketWriter {
s.udp.NewPacket(metadata.Source.AddrPort(), func() N.PacketWriter {
return &nonePacketWriter{conn, metadata.Source}
}, buffer, metadata)
return nil
}
type nonePacketWriter struct {
socks.PacketConn
sourceAddr *M.AddrPort
N.PacketConn
sourceAddr M.Socksaddr
}
func (s *nonePacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
header := buf.With(buffer.ExtendHeader(socks.AddressSerializer.AddrPortLen(destination)))
err := socks.AddressSerializer.WriteAddrPort(header, destination)
func (s *nonePacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
header := buf.With(buffer.ExtendHeader(socks5.AddressSerializer.AddrPortLen(destination)))
err := socks5.AddressSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
return s.PacketConn.WritePacket(buffer, s.sourceAddr)
}
func (s *NoneService) NewPacketConnection(ctx context.Context, conn socks.PacketConn, metadata M.Metadata) error {
func (s *NoneService) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error {
return s.handler.NewPacketConnection(ctx, conn, metadata)
}

View file

@ -8,15 +8,15 @@ import (
"net"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks"
N "github.com/sagernet/sing/common/network"
)
type Method interface {
Name() string
KeyLength() int
DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error)
DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn
DialPacketConn(conn net.Conn) socks.PacketConn
DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error)
DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn
DialPacketConn(conn net.Conn) N.PacketConn
}
func Key(password []byte, keySize int) []byte {

View file

@ -7,17 +7,17 @@ import (
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks"
N "github.com/sagernet/sing/common/network"
)
type Service interface {
M.TCPConnectionHandler
socks.UDPHandler
N.UDPHandler
}
type Handler interface {
M.TCPConnectionHandler
socks.UDPConnectionHandler
N.UDPConnectionHandler
E.Handler
}
@ -34,7 +34,7 @@ type UserContext[U comparable] struct {
type ServerConnError struct {
net.Conn
Source *M.AddrPort
Source M.Socksaddr
Cause error
}
@ -47,8 +47,8 @@ func (e *ServerConnError) Error() string {
}
type ServerPacketError struct {
socks.PacketConn
Source *M.AddrPort
N.PacketConn
Source M.Socksaddr
Cause error
}

View file

@ -12,10 +12,11 @@ import (
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/protocol/socks5"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/hkdf"
)
@ -138,7 +139,7 @@ func (m *Method) WriteResponse(upstream io.Writer) (io.Writer, error) {
return NewWriter(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil
}
func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
shadowsocksConn := &clientConn{
Conn: conn,
method: m,
@ -147,7 +148,7 @@ func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, err
return shadowsocksConn, shadowsocksConn.writeRequest(nil)
}
func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn {
func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
return &clientConn{
Conn: conn,
method: m,
@ -155,7 +156,7 @@ func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn
}
}
func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn {
func (m *Method) DialPacketConn(conn net.Conn) N.PacketConn {
return &clientPacketConn{m, conn}
}
@ -186,7 +187,7 @@ type clientConn struct {
net.Conn
method *Method
destination *M.AddrPort
destination M.Socksaddr
access sync.Mutex
reader *Reader
@ -209,7 +210,7 @@ func (c *clientConn) writeRequest(payload []byte) error {
bufferedWriter := writer.BufferedWriter(header.Len())
if len(payload) > 0 {
err := socks.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
err := socks5.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
if err != nil {
return err
}
@ -219,7 +220,7 @@ func (c *clientConn) writeRequest(payload []byte) error {
return err
}
} else {
err := socks.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
err := socks5.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
if err != nil {
return err
}
@ -325,10 +326,10 @@ type clientPacketConn struct {
net.Conn
}
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
header := buffer.ExtendHeader(c.keySaltLength + socks.AddressSerializer.AddrPortLen(destination))
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
header := buffer.ExtendHeader(c.keySaltLength + socks5.AddressSerializer.AddrPortLen(destination))
common.Must1(io.ReadFull(c.secureRNG, header[:c.keySaltLength]))
err := socks.AddressSerializer.WriteAddrPort(buf.With(header[c.keySaltLength:]), destination)
err := socks5.AddressSerializer.WriteAddrPort(buf.With(header[c.keySaltLength:]), destination)
if err != nil {
return err
}
@ -339,17 +340,17 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPo
return common.Error(c.Write(buffer.Bytes()))
}
func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
n, err := c.Read(buffer.FreeBytes())
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
buffer.Truncate(n)
err = c.DecodePacket(buffer)
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
return socks.AddressSerializer.ReadAddrPort(buffer)
return socks5.AddressSerializer.ReadAddrPort(buffer)
}
func (c *clientPacketConn) UpstreamReader() io.Reader {

View file

@ -12,11 +12,12 @@ import (
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/protocol/socks5"
"golang.org/x/crypto/chacha20poly1305"
)
@ -97,7 +98,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
key := Kdf(s.key, salt, s.keySaltLength)
reader := NewReader(conn, s.constructor(common.Dup(key)), MaxPacketSize)
destination, err := socks.AddressSerializer.ReadAddrPort(reader)
destination, err := socks5.AddressSerializer.ReadAddrPort(reader)
if err != nil {
return err
}
@ -198,7 +199,7 @@ func (c *serverConn) WriterReplaceable() bool {
return c.writer != nil
}
func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
func (s *Service) NewPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
err := s.newPacket(conn, buffer, metadata)
if err != nil {
err = &shadowsocks.ServerPacketError{PacketConn: conn, Source: metadata.Source, Cause: err}
@ -206,7 +207,7 @@ func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata
return err
}
func (s *Service) newPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
func (s *Service) newPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
if buffer.Len() < s.keySaltLength {
return E.New("bad packet")
}
@ -219,7 +220,7 @@ func (s *Service) newPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata
buffer.Advance(s.keySaltLength)
buffer.Truncate(len(packet))
metadata.Protocol = "shadowsocks"
s.udpNat.NewPacket(metadata.Source.AddrPort(), func() socks.PacketWriter {
s.udpNat.NewPacket(metadata.Source.AddrPort(), func() N.PacketWriter {
return &serverPacketWriter{s, conn, metadata.Source}
}, buffer, metadata)
return nil
@ -227,14 +228,14 @@ func (s *Service) newPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata
type serverPacketWriter struct {
*Service
socks.PacketConn
source *M.AddrPort
N.PacketConn
source M.Socksaddr
}
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
header := buffer.ExtendHeader(w.keySaltLength + socks.AddressSerializer.AddrPortLen(destination))
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
header := buffer.ExtendHeader(w.keySaltLength + socks5.AddressSerializer.AddrPortLen(destination))
common.Must1(io.ReadFull(w.secureRNG, header[:w.keySaltLength]))
err := socks.AddressSerializer.WriteAddrPort(buf.With(header[w.keySaltLength:]), destination)
err := socks5.AddressSerializer.WriteAddrPort(buf.With(header[w.keySaltLength:]), destination)
if err != nil {
return err
}

View file

@ -19,11 +19,12 @@ import (
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/log"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/protocol/socks5"
"golang.org/x/crypto/chacha20poly1305"
wgReplay "golang.zx2c4.com/wireguard/replay"
"lukechampine.com/blake3"
@ -163,7 +164,7 @@ func (m *Method) KeyLength() int {
return m.keyLength
}
func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
shadowsocksConn := &clientConn{
Conn: conn,
method: m,
@ -172,7 +173,7 @@ func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, err
return shadowsocksConn, shadowsocksConn.writeRequest(nil)
}
func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn {
func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
return &clientConn{
Conn: conn,
method: m,
@ -180,7 +181,7 @@ func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn
}
}
func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn {
func (m *Method) DialPacketConn(conn net.Conn) N.PacketConn {
return &clientPacketConn{conn, m, m.newUDPSession()}
}
@ -188,7 +189,7 @@ type clientConn struct {
net.Conn
method *Method
destination *M.AddrPort
destination M.Socksaddr
request sync.Mutex
response sync.Mutex
@ -267,7 +268,7 @@ func (c *clientConn) writeRequest(payload []byte) error {
common.Must(rw.WriteByte(bufferedWriter, HeaderTypeClient))
common.Must(binary.Write(bufferedWriter, binary.BigEndian, uint64(time.Now().Unix())))
err := socks.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
err := socks5.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
if err != nil {
return E.Cause(err, "write destination")
}
@ -465,7 +466,7 @@ type clientPacketConn struct {
session *udpSession
}
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if debug.Enabled {
logger.Trace("begin client packet")
}
@ -534,7 +535,7 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPo
binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())),
binary.Write(header, binary.BigEndian, uint16(0)), // padding length
)
err := socks.AddressSerializer.WriteAddrPort(header, destination)
err := socks5.AddressSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
@ -551,14 +552,16 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPo
buffer.Extend(c.session.cipher.Overhead())
c.method.udpBlockCipher.Encrypt(packetHeader, packetHeader)
}
logger.Trace("ended client packet")
if debug.Enabled {
logger.Trace("ended client packet")
}
return common.Error(c.Write(buffer.Bytes()))
}
func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
n, err := c.Read(buffer.FreeBytes())
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
buffer.Truncate(n)
@ -566,7 +569,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
if c.method.udpCipher != nil {
_, err = c.method.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
if err != nil {
return nil, E.Cause(err, "decrypt packet")
return M.Socksaddr{}, E.Cause(err, "decrypt packet")
}
buffer.Advance(PacketNonceSize)
} else {
@ -577,11 +580,11 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
var sessionId, packetId uint64
err = binary.Read(buffer, binary.BigEndian, &sessionId)
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
err = binary.Read(buffer, binary.BigEndian, &packetId)
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
var remoteCipher cipher.AEAD
@ -596,42 +599,42 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
}
_, err = remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
if err != nil {
return nil, E.Cause(err, "decrypt packet")
return M.Socksaddr{}, E.Cause(err, "decrypt packet")
}
}
var headerType byte
headerType, err = buffer.ReadByte()
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
if headerType != HeaderTypeServer {
return nil, ErrBadHeaderType
return M.Socksaddr{}, ErrBadHeaderType
}
var epoch uint64
err = binary.Read(buffer, binary.BigEndian, &epoch)
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
if math.Abs(float64(uint64(time.Now().Unix())-epoch)) > 30 {
return nil, ErrBadTimestamp
return M.Socksaddr{}, ErrBadTimestamp
}
if sessionId == c.session.remoteSessionId {
if !c.session.filter.ValidateCounter(packetId, math.MaxUint64) {
return nil, ErrPacketIdNotUnique
return M.Socksaddr{}, ErrPacketIdNotUnique
}
} else if sessionId == c.session.lastRemoteSessionId {
if !c.session.lastFilter.ValidateCounter(packetId, math.MaxUint64) {
return nil, ErrPacketIdNotUnique
return M.Socksaddr{}, ErrPacketIdNotUnique
}
remoteCipher = c.session.lastRemoteCipher
c.session.lastRemoteSeen = time.Now().Unix()
} else {
if c.session.remoteSessionId != 0 {
if time.Now().Unix()-c.session.lastRemoteSeen < 60 {
return nil, ErrTooManyServerSessions
return M.Socksaddr{}, ErrTooManyServerSessions
} else {
c.session.lastRemoteSessionId = c.session.remoteSessionId
c.session.lastFilter = c.session.filter
@ -648,20 +651,20 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
var clientSessionId uint64
err = binary.Read(buffer, binary.BigEndian, &clientSessionId)
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
if clientSessionId != c.session.sessionId {
return nil, ErrBadClientSessionId
return M.Socksaddr{}, ErrBadClientSessionId
}
var paddingLength uint16
err = binary.Read(buffer, binary.BigEndian, &paddingLength)
if err != nil {
return nil, E.Cause(err, "read padding length")
return M.Socksaddr{}, E.Cause(err, "read padding length")
}
buffer.Advance(int(paddingLength))
return socks.AddressSerializer.ReadAddrPort(buffer)
return socks5.AddressSerializer.ReadAddrPort(buffer)
}
type udpSession struct {

View file

@ -18,12 +18,13 @@ import (
"github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/protocol/socks5"
wgReplay "golang.zx2c4.com/wireguard/replay"
)
@ -132,7 +133,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
return ErrBadTimestamp
}
destination, err := socks.AddressSerializer.ReadAddrPort(reader)
destination, err := socks5.AddressSerializer.ReadAddrPort(reader)
if err != nil {
return E.Cause(err, "read destination")
}
@ -268,7 +269,7 @@ func (c *serverConn) WriterReplaceable() bool {
return c.writer != nil
}
func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
func (s *Service) NewPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
err := s.newPacket(conn, buffer, metadata)
if err != nil {
err = &shadowsocks.ServerPacketError{PacketConn: conn, Source: metadata.Source, Cause: err}
@ -276,7 +277,7 @@ func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata
return err
}
func (s *Service) newPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
func (s *Service) newPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
var packetHeader []byte
if s.udpCipher != nil {
_, err := s.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
@ -358,14 +359,14 @@ process:
}
buffer.Advance(int(paddingLength))
destination, err := socks.AddressSerializer.ReadAddrPort(buffer)
destination, err := socks5.AddressSerializer.ReadAddrPort(buffer)
if err != nil {
goto returnErr
}
metadata.Destination = destination
session.remoteAddr = metadata.Source
s.udpNat.NewPacket(sessionId, func() socks.PacketWriter {
s.udpNat.NewPacket(sessionId, func() N.PacketWriter {
return &serverPacketWriter{s, conn, session}
}, buffer, metadata)
return nil
@ -373,11 +374,11 @@ process:
type serverPacketWriter struct {
*Service
socks.PacketConn
N.PacketConn
session *serverUDPSession
}
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
_header := buf.StackNew()
@ -400,7 +401,7 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.Addr
binary.Write(header, binary.BigEndian, uint16(0)), // padding length
)
err := socks.AddressSerializer.WriteAddrPort(header, destination)
err := socks5.AddressSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
@ -425,7 +426,7 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.Addr
type serverUDPSession struct {
sessionId uint64
remoteSessionId uint64
remoteAddr *M.AddrPort
remoteAddr M.Socksaddr
packetId uint64
cipher cipher.AEAD
remoteCipher cipher.AEAD

View file

@ -13,10 +13,11 @@ import (
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/protocol/socks5"
"lukechampine.com/blake3"
)
@ -140,7 +141,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
return ErrBadTimestamp
}
destination, err := socks.AddressSerializer.ReadAddrPort(reader)
destination, err := socks5.AddressSerializer.ReadAddrPort(reader)
if err != nil {
return E.Cause(err, "read destination")
}
@ -173,7 +174,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
}, metadata)
}
func (s *MultiService[U]) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
func (s *MultiService[U]) NewPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
err := s.newPacket(conn, buffer, metadata)
if err != nil {
err = &shadowsocks.ServerPacketError{PacketConn: conn, Source: metadata.Source, Cause: err}
@ -181,7 +182,7 @@ func (s *MultiService[U]) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, m
return err
}
func (s *MultiService[U]) newPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
func (s *MultiService[U]) newPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
packetHeader := buffer.To(aes.BlockSize)
s.udpBlockCipher.Decrypt(packetHeader, packetHeader)
@ -272,7 +273,7 @@ process:
}
buffer.Advance(int(paddingLength))
destination, err := socks.AddressSerializer.ReadAddrPort(buffer)
destination, err := socks5.AddressSerializer.ReadAddrPort(buffer)
if err != nil {
goto returnErr
}
@ -284,7 +285,7 @@ process:
userCtx.Context = context.Background()
userCtx.User = user
s.udpNat.NewContextPacket(&userCtx, sessionId, func() socks.PacketWriter {
s.udpNat.NewContextPacket(&userCtx, sessionId, func() N.PacketWriter {
return &serverPacketWriter{s.Service, conn, session}
}, buffer, metadata)
return nil

View file

@ -1,124 +0,0 @@
package socks
import (
"context"
"net"
"os"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/task"
)
type PacketReader interface {
ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error)
}
type PacketWriter interface {
WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error
}
type PacketConn interface {
PacketReader
PacketWriter
Close() error
LocalAddr() net.Addr
RemoteAddr() net.Addr
SetDeadline(t time.Time) error
SetReadDeadline(t time.Time) error
SetWriteDeadline(t time.Time) error
}
type UDPHandler interface {
NewPacket(conn PacketConn, buffer *buf.Buffer, metadata M.Metadata) error
}
type UDPConnectionHandler interface {
NewPacketConnection(ctx context.Context, conn PacketConn, metadata M.Metadata) error
}
type PacketConnStub struct{}
func (s *PacketConnStub) RemoteAddr() net.Addr {
return &common.DummyAddr{}
}
func (s *PacketConnStub) SetDeadline(t time.Time) error {
return os.ErrInvalid
}
func (s *PacketConnStub) SetReadDeadline(t time.Time) error {
return os.ErrInvalid
}
func (s *PacketConnStub) SetWriteDeadline(t time.Time) error {
return os.ErrInvalid
}
func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error {
return task.Run(ctx, func() error {
defer rw.CloseRead(conn)
defer rw.CloseWrite(dest)
_buffer := buf.StackNewMax()
buffer := common.Dup(_buffer)
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
for {
data.FullReset()
destination, err := conn.ReadPacket(data)
if err != nil {
return err
}
buffer.Resize(buf.ReversedHeader+data.Start(), data.Len())
err = dest.WritePacket(buffer, destination)
if err != nil {
return err
}
}
}, func() error {
defer rw.CloseRead(dest)
defer rw.CloseWrite(conn)
_buffer := buf.StackNewMax()
buffer := common.Dup(_buffer)
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
for {
data.FullReset()
destination, err := dest.ReadPacket(data)
if err != nil {
return err
}
buffer.Resize(buf.ReversedHeader+data.Start(), data.Len())
err = conn.WritePacket(buffer, destination)
if err != nil {
return err
}
}
})
}
func CopyNetPacketConn(ctx context.Context, conn PacketConn, dest net.PacketConn) error {
return CopyPacketConn(ctx, conn, &PacketConnWrapper{dest})
}
type PacketConnWrapper struct {
net.PacketConn
}
func (p *PacketConnWrapper) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
_, addr, err := buffer.ReadPacketFrom(p)
if err != nil {
return nil, err
}
return M.AddrPortFromNetAddr(addr), err
}
func (p *PacketConnWrapper) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
return common.Error(p.WriteTo(buffer.Bytes(), destination.UDPAddr()))
}
func (p *PacketConnWrapper) RemoteAddr() net.Addr {
return &common.DummyAddr{}
}

View file

@ -1,75 +0,0 @@
package socks_test
import (
"net"
"sync"
"testing"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks"
)
func TestHandshake(t *testing.T) {
server, client := net.Pipe()
defer server.Close()
defer client.Close()
wg := new(sync.WaitGroup)
wg.Add(1)
method := socks.AuthTypeUsernamePassword
go func() {
response, err := socks.ClientHandshake(client, socks.Version5, socks.CommandConnect, M.AddrPortFrom(M.AddrFromFqdn("test"), 80), "user", "pswd")
if err != nil {
t.Fatal(err)
}
if response.ReplyCode != socks.ReplyCodeSuccess {
t.Fatal(response)
}
wg.Done()
}()
authRequest, err := socks.ReadAuthRequest(server)
if err != nil {
t.Fatal(err)
}
if len(authRequest.Methods) != 1 || authRequest.Methods[0] != method {
t.Fatal("bad methods: ", authRequest.Methods)
}
err = socks.WriteAuthResponse(server, &socks.AuthResponse{
Version: socks.Version5,
Method: method,
})
if err != nil {
t.Fatal(err)
}
usernamePasswordAuthRequest, err := socks.ReadUsernamePasswordAuthRequest(server)
if err != nil {
t.Fatal(err)
}
if usernamePasswordAuthRequest.Username != "user" || usernamePasswordAuthRequest.Password != "pswd" {
t.Fatal(authRequest)
}
err = socks.WriteUsernamePasswordAuthResponse(server, &socks.UsernamePasswordAuthResponse{
Status: socks.UsernamePasswordStatusSuccess,
})
if err != nil {
t.Fatal(err)
}
request, err := socks.ReadRequest(server)
if err != nil {
t.Fatal(err)
}
if request.Version != socks.Version5 || request.Command != socks.CommandConnect || request.Destination.Addr.Fqdn() != "test" || request.Destination.Port != 80 {
t.Fatal(request)
}
err = socks.WriteResponse(server, &socks.Response{
Version: socks.Version5,
ReplyCode: socks.ReplyCodeSuccess,
Bind: M.AddrPortFrom(M.AddrFromIP(net.IPv4zero), 0),
})
if err != nil {
t.Fatal(err)
}
wg.Wait()
}

View file

@ -1,4 +1,4 @@
package socks
package socks5
import (
"net"
@ -12,10 +12,10 @@ type AssociateConn struct {
net.Conn
conn net.Conn
addr net.Addr
dest *M.AddrPort
dest M.Socksaddr
}
func NewAssociateConn(conn net.Conn, packetConn net.Conn, destination *M.AddrPort) *AssociateConn {
func NewAssociateConn(conn net.Conn, packetConn net.Conn, destination M.Socksaddr) *AssociateConn {
return &AssociateConn{
Conn: packetConn,
conn: conn,
@ -46,7 +46,7 @@ func (c *AssociateConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3))
err = AddressSerializer.WriteAddrPort(buffer, M.AddrPortFromNetAddr(addr))
err = AddressSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
if err != nil {
return
}
@ -80,17 +80,17 @@ func (c *AssociateConn) Write(b []byte) (n int, err error) {
return
}
func (c *AssociateConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *AssociateConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
n, err := buffer.ReadFrom(c.conn)
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
buffer.Truncate(int(n))
buffer.Advance(3)
return AddressSerializer.ReadAddrPort(buffer)
}
func (c *AssociateConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
func (c *AssociateConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
header := buf.With(buffer.ExtendHeader(3 + AddressSerializer.AddrPortLen(destination)))
common.Must(header.WriteZeroN(3))
@ -102,10 +102,10 @@ type AssociatePacketConn struct {
net.PacketConn
conn net.Conn
addr net.Addr
dest *M.AddrPort
dest M.Socksaddr
}
func NewAssociatePacketConn(conn net.Conn, packetConn net.PacketConn, destination *M.AddrPort) *AssociatePacketConn {
func NewAssociatePacketConn(conn net.Conn, packetConn net.PacketConn, destination M.Socksaddr) *AssociatePacketConn {
return &AssociatePacketConn{
PacketConn: packetConn,
conn: conn,
@ -137,7 +137,7 @@ func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error
buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3))
err = AddressSerializer.WriteAddrPort(buffer, M.AddrPortFromNetAddr(addr))
err = AddressSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
if err != nil {
return
}
@ -171,10 +171,10 @@ func (c *AssociatePacketConn) Write(b []byte) (n int, err error) {
return
}
func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
n, addr, err := c.PacketConn.ReadFrom(buffer.FreeBytes())
if err != nil {
return nil, err
return M.Socksaddr{}, err
}
c.addr = addr
buffer.Truncate(n)
@ -183,7 +183,7 @@ func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error
return dest, err
}
func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
header := buf.With(buffer.ExtendHeader(3 + AddressSerializer.AddrPortLen(destination)))
common.Must(header.WriteZeroN(3))

View file

@ -1,4 +1,4 @@
package socks
package socks5
import (
"strconv"

View file

@ -1,4 +1,4 @@
package socks
package socks5
import "fmt"

View file

@ -1,4 +1,4 @@
package socks
package socks5
import (
"io"
@ -8,7 +8,7 @@ import (
M "github.com/sagernet/sing/common/metadata"
)
func ClientHandshake(conn io.ReadWriter, version byte, command byte, destination *M.AddrPort, username string, password string) (*Response, error) {
func ClientHandshake(conn io.ReadWriter, version byte, command byte, destination M.Socksaddr, username string, password string) (*Response, error) {
var method byte
if common.IsBlank(username) {
method = AuthTypeNotRequired
@ -56,7 +56,7 @@ func ClientHandshake(conn io.ReadWriter, version byte, command byte, destination
return ReadResponse(conn)
}
func ClientFastHandshake(writer io.Writer, version byte, command byte, destination *M.AddrPort, username string, password string) error {
func ClientFastHandshake(writer io.Writer, version byte, command byte, destination M.Socksaddr, username string, password string) error {
var method byte
if common.IsBlank(username) {
method = AuthTypeNotRequired

View file

@ -1,4 +1,4 @@
package socks
package socks5
import (
"context"
@ -10,12 +10,13 @@ import (
"github.com/sagernet/sing/common/auth"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/transport/tcp"
)
type Handler interface {
tcp.Handler
UDPConnectionHandler
N.UDPConnectionHandler
}
type Listener struct {
@ -36,7 +37,7 @@ func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, handler
}
func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
return HandleConnection(ctx, conn, l.authenticator, M.AddrPortFromNetAddr(conn.LocalAddr()).Addr.Addr(), l.handler, metadata)
return HandleConnection(ctx, conn, l.authenticator, M.AddrFromNetAddr(conn.LocalAddr()), l.handler, metadata)
}
func (l *Listener) Start() error {
@ -117,7 +118,7 @@ func handleConnection(authRequest *AuthRequest, ctx context.Context, conn net.Co
err = WriteResponse(conn, &Response{
Version: request.Version,
ReplyCode: ReplyCodeSuccess,
Bind: M.AddrPortFromNetAddr(conn.LocalAddr()),
Bind: M.SocksaddrFromNet(conn.LocalAddr()),
})
if err != nil {
return E.Cause(err, "write socks response")
@ -138,7 +139,7 @@ func handleConnection(authRequest *AuthRequest, ctx context.Context, conn net.Co
err = WriteResponse(conn, &Response{
Version: request.Version,
ReplyCode: ReplyCodeSuccess,
Bind: M.AddrPortFromNetAddr(udpConn.LocalAddr()),
Bind: M.SocksaddrFromNet(udpConn.LocalAddr()),
})
if err != nil {
return E.Cause(err, "write socks response")

View file

@ -0,0 +1 @@
package socks5

View file

@ -1,9 +1,9 @@
package socks
package socks5
import (
"bytes"
"io"
"net"
"net/netip"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
@ -203,7 +203,7 @@ func ReadUsernamePasswordAuthResponse(reader io.Reader) (*UsernamePasswordAuthRe
type Request struct {
Version byte
Command byte
Destination *M.AddrPort
Destination M.Socksaddr
}
func WriteRequest(writer io.Writer, request *Request) error {
@ -262,7 +262,7 @@ func ReadRequest(reader io.Reader) (*Request, error) {
type Response struct {
Version byte
ReplyCode ReplyCode
Bind *M.AddrPort
Bind M.Socksaddr
}
func WriteResponse(writer io.Writer, response *Response) error {
@ -278,8 +278,10 @@ func WriteResponse(writer io.Writer, response *Response) error {
if err != nil {
return err
}
if response.Bind == nil {
return AddressSerializer.WriteAddrPort(writer, M.AddrPortFrom(M.AddrFromIP(net.IPv4zero), 0))
if !response.Bind.IsValid() {
return AddressSerializer.WriteAddrPort(writer, M.Socksaddr{
Addr: netip.IPv4Unspecified(),
})
}
return AddressSerializer.WriteAddrPort(writer, response.Bind)
}
@ -320,7 +322,7 @@ func ReadResponse(reader io.Reader) (*Response, error) {
type AssociatePacket struct {
Fragment byte
Destination *M.AddrPort
Destination M.Socksaddr
Data []byte
}

View file

@ -0,0 +1,75 @@
package socks5_test
import (
"net"
"sync"
"testing"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks5"
)
func TestHandshake(t *testing.T) {
server, client := net.Pipe()
defer server.Close()
defer client.Close()
wg := new(sync.WaitGroup)
wg.Add(1)
method := socks5.AuthTypeUsernamePassword
go func() {
response, err := socks5.ClientHandshake(client, socks5.Version5, socks5.CommandConnect, M.AddrPortFrom(M.AddrFromFqdn("test"), 80), "user", "pswd")
if err != nil {
t.Fatal(err)
}
if response.ReplyCode != socks5.ReplyCodeSuccess {
t.Fatal(response)
}
wg.Done()
}()
authRequest, err := socks5.ReadAuthRequest(server)
if err != nil {
t.Fatal(err)
}
if len(authRequest.Methods) != 1 || authRequest.Methods[0] != method {
t.Fatal("bad methods: ", authRequest.Methods)
}
err = socks5.WriteAuthResponse(server, &socks5.AuthResponse{
Version: socks5.Version5,
Method: method,
})
if err != nil {
t.Fatal(err)
}
usernamePasswordAuthRequest, err := socks5.ReadUsernamePasswordAuthRequest(server)
if err != nil {
t.Fatal(err)
}
if usernamePasswordAuthRequest.Username != "user" || usernamePasswordAuthRequest.Password != "pswd" {
t.Fatal(authRequest)
}
err = socks5.WriteUsernamePasswordAuthResponse(server, &socks5.UsernamePasswordAuthResponse{
Status: socks5.UsernamePasswordStatusSuccess,
})
if err != nil {
t.Fatal(err)
}
request, err := socks5.ReadRequest(server)
if err != nil {
t.Fatal(err)
}
if request.Version != socks5.Version5 || request.Command != socks5.CommandConnect || request.Destination.Addr.Fqdn() != "test" || request.Destination.Port != 80 {
t.Fatal(request)
}
err = socks5.WriteResponse(server, &socks5.Response{
Version: socks5.Version5,
ReplyCode: socks5.ReplyCodeSuccess,
Bind: M.AddrPortFrom(M.AddrFromIP(net.IPv4zero), 0),
})
if err != nil {
t.Fatal(err)
}
wg.Wait()
}

View file

@ -12,7 +12,7 @@ import (
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/protocol/socks5"
)
const (
@ -26,11 +26,11 @@ var CRLF = []byte{'\r', '\n'}
type ClientConn struct {
net.Conn
key [KeyLength]byte
destination *M.AddrPort
destination M.Socksaddr
headerWritten bool
}
func NewClientConn(conn net.Conn, key [KeyLength]byte, destination *M.AddrPort) *ClientConn {
func NewClientConn(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr) *ClientConn {
return &ClientConn{
Conn: conn,
key: key,
@ -75,11 +75,11 @@ func NewClientPacketConn(conn net.Conn, key [KeyLength]byte) *ClientPacketConn {
}
}
func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
return ReadPacket(c.Conn, buffer)
}
func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if !c.headerWritten {
return ClientHandshakePacket(c.Conn, c.key, destination, buffer)
}
@ -98,7 +98,7 @@ func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
}
func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
err = c.WritePacket(buf.With(p), M.AddrPortFromNetAddr(addr))
err = c.WritePacket(buf.With(p), M.SocksaddrFromNet(addr))
if err == nil {
n = len(p)
}
@ -113,7 +113,7 @@ func Key(password string) [KeyLength]byte {
return key
}
func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destination *M.AddrPort, payload []byte) error {
func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destination M.Socksaddr, payload []byte) error {
_, err := conn.Write(key[:])
if err != nil {
return err
@ -126,7 +126,7 @@ func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destin
if err != nil {
return err
}
err = socks.AddressSerializer.WriteAddrPort(conn, destination)
err = socks5.AddressSerializer.WriteAddrPort(conn, destination)
if err != nil {
return err
}
@ -143,8 +143,8 @@ func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destin
return nil
}
func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination *M.AddrPort, payload []byte) error {
headerLen := KeyLength + socks.AddressSerializer.AddrPortLen(destination) + 5
func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload []byte) error {
headerLen := KeyLength + socks5.AddressSerializer.AddrPortLen(destination) + 5
var header *buf.Buffer
var writeHeader bool
if len(payload) > 0 && headerLen+len(payload) < 65535 {
@ -158,7 +158,7 @@ func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination *M.AddrPort
common.Must1(header.Write(key[:]))
common.Must1(header.Write(CRLF))
common.Must(header.WriteByte(CommandTCP))
common.Must(socks.AddressSerializer.WriteAddrPort(header, destination))
common.Must(socks5.AddressSerializer.WriteAddrPort(header, destination))
common.Must1(header.Write(CRLF))
common.Must1(header.Write(payload))
@ -176,8 +176,8 @@ func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination *M.AddrPort
return nil
}
func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination *M.AddrPort, payload *buf.Buffer) error {
headerLen := KeyLength + 2*socks.AddressSerializer.AddrPortLen(destination) + 9
func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error {
headerLen := KeyLength + 2*socks5.AddressSerializer.AddrPortLen(destination) + 9
payloadLen := payload.Len()
var header *buf.Buffer
var writeHeader bool
@ -191,9 +191,9 @@ func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination *M.Ad
common.Must1(header.Write(key[:]))
common.Must1(header.Write(CRLF))
common.Must(header.WriteByte(CommandUDP))
common.Must(socks.AddressSerializer.WriteAddrPort(header, destination))
common.Must(socks5.AddressSerializer.WriteAddrPort(header, destination))
common.Must1(header.Write(CRLF))
common.Must(socks.AddressSerializer.WriteAddrPort(header, destination))
common.Must(socks5.AddressSerializer.WriteAddrPort(header, destination))
common.Must(binary.Write(header, binary.BigEndian, uint16(payloadLen)))
common.Must1(header.Write(CRLF))
@ -211,33 +211,33 @@ func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination *M.Ad
return nil
}
func ReadPacket(conn net.Conn, buffer *buf.Buffer) (*M.AddrPort, error) {
destination, err := socks.AddressSerializer.ReadAddrPort(conn)
func ReadPacket(conn net.Conn, buffer *buf.Buffer) (M.Socksaddr, error) {
destination, err := socks5.AddressSerializer.ReadAddrPort(conn)
if err != nil {
return nil, E.Cause(err, "read destination")
return M.Socksaddr{}, E.Cause(err, "read destination")
}
var length uint16
err = binary.Read(conn, binary.BigEndian, &length)
if err != nil {
return nil, E.Cause(err, "read chunk length")
return M.Socksaddr{}, E.Cause(err, "read chunk length")
}
if buffer.FreeLen() < int(length) {
return nil, io.ErrShortBuffer
return M.Socksaddr{}, io.ErrShortBuffer
}
err = rw.SkipN(conn, 2)
if err != nil {
return nil, E.Cause(err, "skip crlf")
return M.Socksaddr{}, E.Cause(err, "skip crlf")
}
_, err = buffer.ReadFullFrom(conn, int(length))
return destination, err
}
func WritePacket(conn net.Conn, buffer *buf.Buffer, destination *M.AddrPort) error {
headerOverload := socks.AddressSerializer.AddrPortLen(destination) + 4
func WritePacket(conn net.Conn, buffer *buf.Buffer, destination M.Socksaddr) error {
headerOverload := socks5.AddressSerializer.AddrPortLen(destination) + 4
var header *buf.Buffer
var writeHeader bool
bufferLen := buffer.Len()
@ -248,7 +248,7 @@ func WritePacket(conn net.Conn, buffer *buf.Buffer, destination *M.AddrPort) err
_buffer := buf.Make(headerOverload)
header = buf.With(common.Dup(_buffer))
}
common.Must(socks.AddressSerializer.WriteAddrPort(header, destination))
common.Must(socks5.AddressSerializer.WriteAddrPort(header, destination))
common.Must(binary.Write(header, binary.BigEndian, uint16(bufferLen)))
common.Must1(header.Write(CRLF))
if writeHeader {

View file

@ -10,13 +10,14 @@ import (
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/protocol/socks5"
)
type Handler interface {
M.TCPConnectionHandler
socks.UDPConnectionHandler
N.UDPConnectionHandler
}
type Context[K comparable] struct {
@ -115,7 +116,7 @@ process:
goto returnErr
}
destination, err := socks.AddressSerializer.ReadAddrPort(conn)
destination, err := socks5.AddressSerializer.ReadAddrPort(conn)
if err != nil {
err = E.Cause(err, "read destination")
goto returnErr
@ -141,11 +142,11 @@ type PacketConn struct {
net.Conn
}
func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
return ReadPacket(c.Conn, buffer)
}
func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
return WritePacket(c.Conn, buffer, destination)
}