mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-03 20:07:38 +03:00
Add shadowsocks service
This commit is contained in:
parent
5be6eb2d64
commit
df1e1cfafd
21 changed files with 1028 additions and 265 deletions
|
@ -4,32 +4,35 @@ import (
|
|||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/udpnat"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
)
|
||||
|
||||
type Service interface {
|
||||
M.TCPConnectionHandler
|
||||
}
|
||||
|
||||
type MultiUserService interface {
|
||||
Service
|
||||
AddUser(key []byte)
|
||||
RemoveUser(key []byte)
|
||||
socks.UDPHandler
|
||||
}
|
||||
|
||||
type Handler interface {
|
||||
M.TCPConnectionHandler
|
||||
socks.UDPConnectionHandler
|
||||
E.Handler
|
||||
}
|
||||
|
||||
type NoneService struct {
|
||||
handler Handler
|
||||
udp *udpnat.Service[string]
|
||||
}
|
||||
|
||||
func NewNoneService(handler Handler) Service {
|
||||
return &NoneService{
|
||||
s := &NoneService{
|
||||
handler: handler,
|
||||
}
|
||||
s.udp = udpnat.New[string](s)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
||||
|
@ -41,3 +44,37 @@ func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata
|
|||
metadata.Destination = destination
|
||||
return s.handler.NewConnection(ctx, conn, metadata)
|
||||
}
|
||||
|
||||
func (s *NoneService) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
|
||||
destination, err := socks.AddressSerializer.ReadAddrPort(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
metadata.Protocol = "shadowsocks"
|
||||
metadata.Destination = destination
|
||||
return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter {
|
||||
return &serverPacketWriter{conn, metadata.Source}
|
||||
}, buffer, metadata)
|
||||
}
|
||||
|
||||
type serverPacketWriter struct {
|
||||
socks.PacketConn
|
||||
sourceAddr *M.AddrPort
|
||||
}
|
||||
|
||||
func (s *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
header := buf.With(buffer.ExtendHeader(socks.AddressSerializer.AddrPortLen(destination)))
|
||||
err := socks.AddressSerializer.WriteAddrPort(header, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.PacketConn.WritePacket(buffer, s.sourceAddr)
|
||||
}
|
||||
|
||||
func (s *NoneService) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error {
|
||||
return s.handler.NewPacketConnection(conn, metadata)
|
||||
}
|
||||
|
||||
func (s *NoneService) HandleError(err error) {
|
||||
s.handler.HandleError(err)
|
||||
}
|
||||
|
|
|
@ -156,13 +156,13 @@ func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn
|
|||
}
|
||||
|
||||
func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn {
|
||||
return &clientPacketConn{conn, m}
|
||||
return &clientPacketConn{m, conn}
|
||||
}
|
||||
|
||||
func (m *Method) EncodePacket(buffer *buf.Buffer) error {
|
||||
key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)
|
||||
c := m.constructor(common.Dup(key))
|
||||
c.Seal(buffer.From(m.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
|
||||
c.Seal(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
|
||||
buffer.Extend(c.Overhead())
|
||||
return nil
|
||||
}
|
||||
|
@ -299,20 +299,18 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
|
|||
}
|
||||
|
||||
type clientPacketConn struct {
|
||||
*Method
|
||||
net.Conn
|
||||
method *Method
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
_header := buf.StackNew()
|
||||
header := common.Dup(_header)
|
||||
common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
|
||||
err := socks.AddressSerializer.WriteAddrPort(header, destination)
|
||||
header := buffer.ExtendHeader(c.keySaltLength + socks.AddressSerializer.AddrPortLen(destination))
|
||||
common.Must1(io.ReadFull(c.secureRNG, header[:c.keySaltLength]))
|
||||
err := socks.AddressSerializer.WriteAddrPort(buf.With(header[c.keySaltLength:]), destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer = buffer.WriteBufferAtFirst(header)
|
||||
err = c.method.EncodePacket(buffer)
|
||||
err = c.EncodePacket(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -325,7 +323,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
|||
return nil, err
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
err = c.method.DecodePacket(buffer)
|
||||
err = c.DecodePacket(buffer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/replay"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/udpnat"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
|
@ -25,6 +26,7 @@ type Service struct {
|
|||
key []byte
|
||||
secureRNG io.Reader
|
||||
replayFilter replay.Filter
|
||||
udp *udpnat.Service[string]
|
||||
handler shadowsocks.Handler
|
||||
}
|
||||
|
||||
|
@ -34,6 +36,7 @@ func NewService(method string, key []byte, password []byte, secureRNG io.Reader,
|
|||
secureRNG: secureRNG,
|
||||
handler: handler,
|
||||
}
|
||||
s.udp = udpnat.New[string](s)
|
||||
if replayFilter {
|
||||
s.replayFilter = replay.NewBloomRing()
|
||||
}
|
||||
|
@ -163,3 +166,59 @@ func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
|
|||
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
return c.reader.WriteTo(w)
|
||||
}
|
||||
|
||||
func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
|
||||
if buffer.Len() < s.keySaltLength {
|
||||
return E.New("bad packet")
|
||||
}
|
||||
key := Kdf(s.key, buffer.To(s.keySaltLength), s.keySaltLength)
|
||||
c := s.constructor(common.Dup(key))
|
||||
/*data := buf.New()
|
||||
packet, err := c.Open(data.Index(0), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data.Truncate(len(packet))
|
||||
metadata.Protocol = "shadowsocks"
|
||||
return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter {
|
||||
return &serverPacketWriter{s, conn, metadata.Source}
|
||||
}, data, metadata)*/
|
||||
packet, err := c.Open(buffer.Index(s.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer.Advance(s.keySaltLength)
|
||||
buffer.Truncate(len(packet))
|
||||
metadata.Protocol = "shadowsocks"
|
||||
return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter {
|
||||
return &serverPacketWriter{s, conn, metadata.Source}
|
||||
}, buffer, metadata)
|
||||
}
|
||||
|
||||
type serverPacketWriter struct {
|
||||
*Service
|
||||
socks.PacketConn
|
||||
source *M.AddrPort
|
||||
}
|
||||
|
||||
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
header := buffer.ExtendHeader(w.keySaltLength + socks.AddressSerializer.AddrPortLen(destination))
|
||||
common.Must1(io.ReadFull(w.secureRNG, header[:w.keySaltLength]))
|
||||
err := socks.AddressSerializer.WriteAddrPort(buf.With(header[w.keySaltLength:]), destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key := Kdf(w.key, buffer.To(w.keySaltLength), w.keySaltLength)
|
||||
c := w.constructor(common.Dup(key))
|
||||
c.Seal(buffer.From(w.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(w.keySaltLength), nil)
|
||||
buffer.Extend(c.Overhead())
|
||||
return w.PacketConn.WritePacket(buffer, w.source)
|
||||
}
|
||||
|
||||
func (s *Service) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error {
|
||||
return s.handler.NewPacketConnection(conn, metadata)
|
||||
}
|
||||
|
||||
func (s *Service) HandleError(err error) {
|
||||
s.handler.HandleError(err)
|
||||
}
|
||||
|
|
|
@ -109,9 +109,9 @@ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Meth
|
|||
}
|
||||
|
||||
func Blake3DeriveKey(psk, salt []byte, keyLength int) []byte {
|
||||
sessionKey := make([]byte, 2*KeySaltSize)
|
||||
sessionKey := buf.Make(len(psk) + len(salt))
|
||||
copy(sessionKey, psk)
|
||||
copy(sessionKey[KeySaltSize:], salt)
|
||||
copy(sessionKey[len(psk):], salt)
|
||||
outKey := buf.Make(keyLength)
|
||||
blake3.DeriveKey(outKey, "shadowsocks 2022 session subkey", sessionKey)
|
||||
return outKey
|
||||
|
@ -434,6 +434,9 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPo
|
|||
return err
|
||||
}
|
||||
buffer = buffer.WriteBufferAtFirst(header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.method.udpCipher != nil {
|
||||
c.method.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
|
||||
buffer.Extend(c.method.udpCipher.Overhead())
|
||||
|
@ -574,9 +577,9 @@ func (s *udpSession) nextPacketId() uint64 {
|
|||
}
|
||||
|
||||
func (m *Method) newUDPSession() *udpSession {
|
||||
session := &udpSession{
|
||||
sessionId: rand.Uint64(),
|
||||
}
|
||||
session := &udpSession{}
|
||||
common.Must(binary.Read(m.secureRNG, binary.BigEndian, &session.sessionId))
|
||||
session.packetId--
|
||||
if m.udpCipher == nil {
|
||||
sessionId := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(sessionId, session.sessionId)
|
||||
|
|
|
@ -2,32 +2,43 @@ package shadowaead_2022
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/gsync"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/replay"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/udpnat"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
wgReplay "golang.zx2c4.com/wireguard/replay"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
name string
|
||||
secureRNG io.Reader
|
||||
keyLength int
|
||||
constructor func(key []byte) cipher.AEAD
|
||||
psk []byte
|
||||
replayFilter replay.Filter
|
||||
handler shadowsocks.Handler
|
||||
name string
|
||||
secureRNG io.Reader
|
||||
keyLength int
|
||||
constructor func(key []byte) cipher.AEAD
|
||||
blockConstructor func(key []byte) cipher.Block
|
||||
udpCipher cipher.AEAD
|
||||
udpBlockCipher cipher.Block
|
||||
psk []byte
|
||||
replayFilter replay.Filter
|
||||
handler shadowsocks.Handler
|
||||
udpNat *udpnat.Service[uint64]
|
||||
sessions gsync.Map[uint64, *serverUDPSession]
|
||||
}
|
||||
|
||||
func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowsocks.Handler) (shadowsocks.Service, error) {
|
||||
|
@ -47,18 +58,20 @@ func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowso
|
|||
case "2022-blake3-aes-128-gcm":
|
||||
s.keyLength = 16
|
||||
s.constructor = newAESGCM
|
||||
// m.blockConstructor = newAES
|
||||
// m.udpBlockCipher = newAES(m.psk)
|
||||
s.blockConstructor = newAES
|
||||
s.udpBlockCipher = newAES(s.psk)
|
||||
case "2022-blake3-aes-256-gcm":
|
||||
s.keyLength = 32
|
||||
s.constructor = newAESGCM
|
||||
// m.blockConstructor = newAES
|
||||
// m.udpBlockCipher = newAES(m.psk)
|
||||
s.blockConstructor = newAES
|
||||
s.udpBlockCipher = newAES(s.psk)
|
||||
case "2022-blake3-chacha20-poly1305":
|
||||
s.keyLength = 32
|
||||
s.constructor = newChacha20Poly1305
|
||||
// m.udpCipher = newXChacha20Poly1305(m.psk)
|
||||
s.udpCipher = newXChacha20Poly1305(s.psk)
|
||||
}
|
||||
|
||||
s.udpNat = udpnat.New[uint64](s)
|
||||
return s, nil
|
||||
}
|
||||
|
||||
|
@ -194,3 +207,169 @@ func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
|
|||
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
return c.reader.WriteTo(w)
|
||||
}
|
||||
|
||||
func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
|
||||
var packetHeader []byte
|
||||
if s.udpCipher != nil {
|
||||
_, err := s.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
|
||||
if err != nil {
|
||||
return E.Cause(err, "decrypt packet header")
|
||||
}
|
||||
buffer.Advance(PacketNonceSize)
|
||||
} else {
|
||||
packetHeader = buffer.To(aes.BlockSize)
|
||||
s.udpBlockCipher.Decrypt(packetHeader, packetHeader)
|
||||
}
|
||||
|
||||
var sessionId, packetId uint64
|
||||
err := binary.Read(buffer, binary.BigEndian, &sessionId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = binary.Read(buffer, binary.BigEndian, &packetId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
session, loaded := s.sessions.LoadOrStore(sessionId, s.newUDPSession)
|
||||
if !loaded {
|
||||
session.remoteSessionId = sessionId
|
||||
if packetHeader != nil {
|
||||
key := Blake3DeriveKey(s.psk, packetHeader[:8], s.keyLength)
|
||||
session.remoteCipher = s.constructor(common.Dup(key))
|
||||
}
|
||||
}
|
||||
|
||||
if !session.filter.ValidateCounter(packetId, math.MaxUint64) {
|
||||
return ErrPacketIdNotUnique
|
||||
}
|
||||
|
||||
if packetHeader != nil {
|
||||
_, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
|
||||
if err != nil {
|
||||
return E.Cause(err, "decrypt packet")
|
||||
}
|
||||
}
|
||||
|
||||
var headerType byte
|
||||
headerType, err = buffer.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if headerType != HeaderTypeClient {
|
||||
return ErrBadHeaderType
|
||||
}
|
||||
|
||||
var epoch uint64
|
||||
err = binary.Read(buffer, binary.BigEndian, &epoch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if math.Abs(float64(uint64(time.Now().Unix())-epoch)) > 30 {
|
||||
return ErrBadTimestamp
|
||||
}
|
||||
|
||||
var paddingLength uint16
|
||||
err = binary.Read(buffer, binary.BigEndian, &paddingLength)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read padding length")
|
||||
}
|
||||
buffer.Advance(int(paddingLength))
|
||||
|
||||
destination, err := socks.AddressSerializer.ReadAddrPort(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
metadata.Destination = destination
|
||||
|
||||
return s.udpNat.NewPacket(sessionId, func() socks.PacketWriter {
|
||||
return &serverPacketWriter{s, conn, session, metadata.Source}
|
||||
}, buffer, metadata)
|
||||
}
|
||||
|
||||
type serverPacketWriter struct {
|
||||
*Service
|
||||
socks.PacketConn
|
||||
session *serverUDPSession
|
||||
source *M.AddrPort
|
||||
}
|
||||
|
||||
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
defer buffer.Release()
|
||||
|
||||
_header := buf.StackNew()
|
||||
header := common.Dup(_header)
|
||||
|
||||
var dataIndex int
|
||||
if w.udpCipher != nil {
|
||||
common.Must1(header.ReadFullFrom(w.secureRNG, PacketNonceSize))
|
||||
dataIndex = buffer.Len()
|
||||
} else {
|
||||
dataIndex = aes.BlockSize
|
||||
}
|
||||
|
||||
common.Must(
|
||||
binary.Write(header, binary.BigEndian, w.session.sessionId),
|
||||
binary.Write(header, binary.BigEndian, w.session.nextPacketId()),
|
||||
header.WriteByte(HeaderTypeServer),
|
||||
binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())),
|
||||
binary.Write(header, binary.BigEndian, w.session.remoteSessionId),
|
||||
binary.Write(header, binary.BigEndian, uint16(0)), // padding length
|
||||
)
|
||||
|
||||
err := socks.AddressSerializer.WriteAddrPort(header, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = header.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if w.udpCipher != nil {
|
||||
w.udpCipher.Seal(header.Index(dataIndex), header.To(dataIndex), header.From(dataIndex), nil)
|
||||
header.Extend(w.udpCipher.Overhead())
|
||||
} else {
|
||||
packetHeader := header.To(aes.BlockSize)
|
||||
w.session.cipher.Seal(header.Index(dataIndex), packetHeader[4:16], header.From(dataIndex), nil)
|
||||
header.Extend(w.session.cipher.Overhead())
|
||||
w.udpBlockCipher.Encrypt(packetHeader, packetHeader)
|
||||
}
|
||||
return w.PacketConn.WritePacket(header, w.source)
|
||||
}
|
||||
|
||||
type serverUDPSession struct {
|
||||
sessionId uint64
|
||||
remoteSessionId uint64
|
||||
packetId uint64
|
||||
cipher cipher.AEAD
|
||||
remoteCipher cipher.AEAD
|
||||
filter wgReplay.Filter
|
||||
}
|
||||
|
||||
func (s *serverUDPSession) nextPacketId() uint64 {
|
||||
return atomic.AddUint64(&s.packetId, 1)
|
||||
}
|
||||
|
||||
func (m *Service) newUDPSession() *serverUDPSession {
|
||||
session := &serverUDPSession{}
|
||||
common.Must(binary.Read(m.secureRNG, binary.BigEndian, &session.sessionId))
|
||||
session.packetId--
|
||||
if m.udpCipher == nil {
|
||||
sessionId := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(sessionId, session.sessionId)
|
||||
key := Blake3DeriveKey(m.psk, sessionId, m.keyLength)
|
||||
session.cipher = m.constructor(common.Dup(key))
|
||||
}
|
||||
return session
|
||||
}
|
||||
|
||||
func (s *Service) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error {
|
||||
return s.handler.NewPacketConnection(conn, metadata)
|
||||
}
|
||||
|
||||
func (s *Service) HandleError(err error) {
|
||||
s.handler.HandleError(err)
|
||||
}
|
||||
|
|
|
@ -8,12 +8,21 @@ import (
|
|||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
)
|
||||
|
||||
type PacketConn interface {
|
||||
type PacketReader interface {
|
||||
ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error)
|
||||
WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error
|
||||
}
|
||||
|
||||
type PacketWriter interface {
|
||||
WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error
|
||||
}
|
||||
|
||||
type PacketConn interface {
|
||||
PacketReader
|
||||
PacketWriter
|
||||
|
||||
Close() error
|
||||
LocalAddr() net.Addr
|
||||
|
@ -23,6 +32,10 @@ type PacketConn interface {
|
|||
SetWriteDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
type UDPHandler interface {
|
||||
NewPacket(conn PacketConn, buffer *buf.Buffer, metadata M.Metadata) error
|
||||
}
|
||||
|
||||
type UDPConnectionHandler interface {
|
||||
NewPacketConnection(conn PacketConn, metadata M.Metadata) error
|
||||
}
|
||||
|
@ -47,6 +60,8 @@ func (s *PacketConnStub) SetWriteDeadline(t time.Time) error {
|
|||
|
||||
func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error {
|
||||
return task.Run(ctx, func() error {
|
||||
defer rw.CloseRead(conn)
|
||||
defer rw.CloseWrite(dest)
|
||||
_buffer := buf.StackNewMax()
|
||||
buffer := common.Dup(_buffer)
|
||||
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
|
||||
|
@ -56,13 +71,15 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer.Truncate(data.Len())
|
||||
buffer.Resize(buf.ReversedHeader+data.Start(), data.Len())
|
||||
err = dest.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}, func() error {
|
||||
defer rw.CloseRead(dest)
|
||||
defer rw.CloseWrite(conn)
|
||||
_buffer := buf.StackNewMax()
|
||||
buffer := common.Dup(_buffer)
|
||||
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
|
||||
|
@ -72,7 +89,7 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer.Truncate(data.Len())
|
||||
buffer.Resize(buf.ReversedHeader+data.Start(), data.Len())
|
||||
err = conn.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -81,44 +98,211 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error
|
|||
})
|
||||
}
|
||||
|
||||
func CopyPacketConn0(dest PacketConn, conn PacketConn, onAction func(destination *M.AddrPort, n int)) error {
|
||||
for {
|
||||
buffer := buf.New()
|
||||
destination, err := conn.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return err
|
||||
func CopyNetPacketConn(ctx context.Context, dest net.PacketConn, conn PacketConn) error {
|
||||
return task.Run(ctx, func() error {
|
||||
defer rw.CloseRead(conn)
|
||||
defer rw.CloseWrite(dest)
|
||||
|
||||
_buffer := buf.StackNew()
|
||||
buffer := common.Dup(_buffer)
|
||||
for {
|
||||
buffer.FullReset()
|
||||
destination, err := conn.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = dest.WriteTo(buffer.Bytes(), destination.UDPAddr())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
size := buffer.Len()
|
||||
err = dest.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return err
|
||||
}, func() error {
|
||||
defer rw.CloseRead(dest)
|
||||
defer rw.CloseWrite(conn)
|
||||
|
||||
_buffer := buf.StackNew()
|
||||
buffer := common.Dup(_buffer)
|
||||
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
|
||||
for {
|
||||
data.FullReset()
|
||||
n, addr, err := dest.ReadFrom(data.FreeBytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer.Resize(buf.ReversedHeader, n)
|
||||
err = conn.WritePacket(buffer, M.AddrPortFromNetAddr(addr))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if onAction != nil {
|
||||
onAction(destination, size)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type associatePacketConn struct {
|
||||
net.PacketConn
|
||||
type AssociateConn struct {
|
||||
net.Conn
|
||||
conn net.Conn
|
||||
addr net.Addr
|
||||
dest *M.AddrPort
|
||||
}
|
||||
|
||||
func NewPacketConn(conn net.Conn, packetConn net.PacketConn) PacketConn {
|
||||
return &associatePacketConn{
|
||||
PacketConn: packetConn,
|
||||
conn: conn,
|
||||
func NewAssociateConn(conn net.Conn, packetConn net.Conn, destination *M.AddrPort) net.PacketConn {
|
||||
return &AssociateConn{
|
||||
Conn: packetConn,
|
||||
conn: conn,
|
||||
dest: destination,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *associatePacketConn) RemoteAddr() net.Addr {
|
||||
func (c *AssociateConn) RemoteAddr() net.Addr {
|
||||
return c.addr
|
||||
}
|
||||
|
||||
func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||
func (c *AssociateConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, err = c.Conn.Read(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
reader := buf.As(p[3:n])
|
||||
destination, err := AddressSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
addr = destination.UDPAddr()
|
||||
n = copy(p, reader.Bytes())
|
||||
return
|
||||
}
|
||||
|
||||
func (c *AssociateConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
_buffer := buf.StackNew()
|
||||
buffer := common.Dup(_buffer)
|
||||
common.Must(buffer.WriteZeroN(3))
|
||||
err = AddressSerializer.WriteAddrPort(buffer, M.AddrPortFromNetAddr(addr))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = buffer.Write(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = c.Conn.Write(buffer.Bytes())
|
||||
return
|
||||
}
|
||||
|
||||
func (c *AssociateConn) Read(b []byte) (n int, err error) {
|
||||
n, _, err = c.ReadFrom(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *AssociateConn) Write(b []byte) (n int, err error) {
|
||||
_buffer := buf.StackNew()
|
||||
buffer := common.Dup(_buffer)
|
||||
common.Must(buffer.WriteZeroN(3))
|
||||
err = AddressSerializer.WriteAddrPort(buffer, c.dest)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = buffer.Write(b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = c.Conn.Write(buffer.Bytes())
|
||||
return
|
||||
}
|
||||
|
||||
func (c *AssociateConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||
n, err := buffer.ReadFrom(c.conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buffer.Truncate(int(n))
|
||||
buffer.Advance(3)
|
||||
return AddressSerializer.ReadAddrPort(buffer)
|
||||
}
|
||||
|
||||
func (c *AssociateConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
defer buffer.Release()
|
||||
header := buf.With(buffer.ExtendHeader(3 + AddressSerializer.AddrPortLen(destination)))
|
||||
common.Must(header.WriteZeroN(3))
|
||||
common.Must(AddressSerializer.WriteAddrPort(header, destination))
|
||||
return common.Error(c.Conn.Write(buffer.Bytes()))
|
||||
}
|
||||
|
||||
type AssociatePacketConn struct {
|
||||
net.PacketConn
|
||||
conn net.Conn
|
||||
addr net.Addr
|
||||
dest *M.AddrPort
|
||||
}
|
||||
|
||||
func NewAssociatePacketConn(conn net.Conn, packetConn net.PacketConn, destination *M.AddrPort) *AssociatePacketConn {
|
||||
return &AssociatePacketConn{
|
||||
PacketConn: packetConn,
|
||||
conn: conn,
|
||||
dest: destination,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *AssociatePacketConn) RemoteAddr() net.Addr {
|
||||
return c.addr
|
||||
}
|
||||
|
||||
func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, addr, err = c.PacketConn.ReadFrom(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
reader := buf.As(p[3:n])
|
||||
destination, err := AddressSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
addr = destination.UDPAddr()
|
||||
n = copy(p, reader.Bytes())
|
||||
return
|
||||
}
|
||||
|
||||
func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
_buffer := buf.StackNew()
|
||||
buffer := common.Dup(_buffer)
|
||||
common.Must(buffer.WriteZeroN(3))
|
||||
|
||||
err = AddressSerializer.WriteAddrPort(buffer, M.AddrPortFromNetAddr(addr))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = buffer.Write(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = c.PacketConn.WriteTo(buffer.Bytes(), c.addr)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *AssociatePacketConn) Read(b []byte) (n int, err error) {
|
||||
n, _, err = c.ReadFrom(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *AssociatePacketConn) Write(b []byte) (n int, err error) {
|
||||
_buffer := buf.StackNew()
|
||||
buffer := common.Dup(_buffer)
|
||||
common.Must(buffer.WriteZeroN(3))
|
||||
|
||||
err = AddressSerializer.WriteAddrPort(buffer, c.dest)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = buffer.Write(b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = c.PacketConn.WriteTo(buffer.Bytes(), c.addr)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||
n, addr, err := c.PacketConn.ReadFrom(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -126,15 +310,14 @@ func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error
|
|||
c.addr = addr
|
||||
buffer.Truncate(n)
|
||||
buffer.Advance(3)
|
||||
return AddressSerializer.ReadAddrPort(buffer)
|
||||
dest, err := AddressSerializer.ReadAddrPort(buffer)
|
||||
return dest, err
|
||||
}
|
||||
|
||||
func (c *associatePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
|
||||
func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
defer buffer.Release()
|
||||
_header := buf.StackNew()
|
||||
header := common.Dup(_header)
|
||||
header := buf.With(buffer.ExtendHeader(3 + AddressSerializer.AddrPortLen(destination)))
|
||||
common.Must(header.WriteZeroN(3))
|
||||
common.Must(AddressSerializer.WriteAddrPort(header, addrPort))
|
||||
buffer = buffer.WriteBufferAtFirst(header)
|
||||
common.Must(AddressSerializer.WriteAddrPort(header, destination))
|
||||
return common.Error(c.PacketConn.WriteTo(buffer.Bytes(), c.addr))
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, handler
|
|||
}
|
||||
|
||||
func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
||||
return HandleConnection(ctx, conn, l.authenticator, l.bindAddr, l.handler, metadata)
|
||||
return HandleConnection(ctx, conn, l.authenticator, M.AddrPortFromNetAddr(conn.LocalAddr()).Addr.Addr(), l.handler, metadata)
|
||||
}
|
||||
|
||||
func (l *Listener) Start() error {
|
||||
|
@ -131,9 +131,10 @@ func HandleConnection(ctx context.Context, conn net.Conn, authenticator auth.Aut
|
|||
if err != nil {
|
||||
return E.Cause(err, "write socks response")
|
||||
}
|
||||
metadata.Protocol = "socks"
|
||||
metadata.Destination = request.Destination
|
||||
go func() {
|
||||
err := handler.NewPacketConnection(NewPacketConn(conn, udpConn), metadata)
|
||||
err := handler.NewPacketConnection(NewAssociatePacketConn(conn, udpConn, request.Destination), metadata)
|
||||
if err != nil {
|
||||
handler.HandleError(err)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue