Add shadowsocks service

This commit is contained in:
世界 2022-04-29 12:06:10 +08:00
parent 5be6eb2d64
commit df1e1cfafd
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
21 changed files with 1028 additions and 265 deletions

View file

@ -4,32 +4,35 @@ import (
"context"
"net"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/protocol/socks"
)
type Service interface {
M.TCPConnectionHandler
}
type MultiUserService interface {
Service
AddUser(key []byte)
RemoveUser(key []byte)
socks.UDPHandler
}
type Handler interface {
M.TCPConnectionHandler
socks.UDPConnectionHandler
E.Handler
}
type NoneService struct {
handler Handler
udp *udpnat.Service[string]
}
func NewNoneService(handler Handler) Service {
return &NoneService{
s := &NoneService{
handler: handler,
}
s.udp = udpnat.New[string](s)
return s
}
func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
@ -41,3 +44,37 @@ func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata
metadata.Destination = destination
return s.handler.NewConnection(ctx, conn, metadata)
}
func (s *NoneService) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
destination, err := socks.AddressSerializer.ReadAddrPort(buffer)
if err != nil {
return err
}
metadata.Protocol = "shadowsocks"
metadata.Destination = destination
return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter {
return &serverPacketWriter{conn, metadata.Source}
}, buffer, metadata)
}
type serverPacketWriter struct {
socks.PacketConn
sourceAddr *M.AddrPort
}
func (s *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
header := buf.With(buffer.ExtendHeader(socks.AddressSerializer.AddrPortLen(destination)))
err := socks.AddressSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
return s.PacketConn.WritePacket(buffer, s.sourceAddr)
}
func (s *NoneService) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error {
return s.handler.NewPacketConnection(conn, metadata)
}
func (s *NoneService) HandleError(err error) {
s.handler.HandleError(err)
}

View file

@ -156,13 +156,13 @@ func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn
}
func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn {
return &clientPacketConn{conn, m}
return &clientPacketConn{m, conn}
}
func (m *Method) EncodePacket(buffer *buf.Buffer) error {
key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)
c := m.constructor(common.Dup(key))
c.Seal(buffer.From(m.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
c.Seal(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
buffer.Extend(c.Overhead())
return nil
}
@ -299,20 +299,18 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
}
type clientPacketConn struct {
*Method
net.Conn
method *Method
}
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
_header := buf.StackNew()
header := common.Dup(_header)
common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
err := socks.AddressSerializer.WriteAddrPort(header, destination)
header := buffer.ExtendHeader(c.keySaltLength + socks.AddressSerializer.AddrPortLen(destination))
common.Must1(io.ReadFull(c.secureRNG, header[:c.keySaltLength]))
err := socks.AddressSerializer.WriteAddrPort(buf.With(header[c.keySaltLength:]), destination)
if err != nil {
return err
}
buffer = buffer.WriteBufferAtFirst(header)
err = c.method.EncodePacket(buffer)
err = c.EncodePacket(buffer)
if err != nil {
return err
}
@ -325,7 +323,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
return nil, err
}
buffer.Truncate(n)
err = c.method.DecodePacket(buffer)
err = c.DecodePacket(buffer)
if err != nil {
return nil, err
}

View file

@ -13,6 +13,7 @@ import (
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/socks"
"golang.org/x/crypto/chacha20poly1305"
@ -25,6 +26,7 @@ type Service struct {
key []byte
secureRNG io.Reader
replayFilter replay.Filter
udp *udpnat.Service[string]
handler shadowsocks.Handler
}
@ -34,6 +36,7 @@ func NewService(method string, key []byte, password []byte, secureRNG io.Reader,
secureRNG: secureRNG,
handler: handler,
}
s.udp = udpnat.New[string](s)
if replayFilter {
s.replayFilter = replay.NewBloomRing()
}
@ -163,3 +166,59 @@ func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
return c.reader.WriteTo(w)
}
func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
if buffer.Len() < s.keySaltLength {
return E.New("bad packet")
}
key := Kdf(s.key, buffer.To(s.keySaltLength), s.keySaltLength)
c := s.constructor(common.Dup(key))
/*data := buf.New()
packet, err := c.Open(data.Index(0), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil)
if err != nil {
return err
}
data.Truncate(len(packet))
metadata.Protocol = "shadowsocks"
return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter {
return &serverPacketWriter{s, conn, metadata.Source}
}, data, metadata)*/
packet, err := c.Open(buffer.Index(s.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil)
if err != nil {
return err
}
buffer.Advance(s.keySaltLength)
buffer.Truncate(len(packet))
metadata.Protocol = "shadowsocks"
return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter {
return &serverPacketWriter{s, conn, metadata.Source}
}, buffer, metadata)
}
type serverPacketWriter struct {
*Service
socks.PacketConn
source *M.AddrPort
}
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
header := buffer.ExtendHeader(w.keySaltLength + socks.AddressSerializer.AddrPortLen(destination))
common.Must1(io.ReadFull(w.secureRNG, header[:w.keySaltLength]))
err := socks.AddressSerializer.WriteAddrPort(buf.With(header[w.keySaltLength:]), destination)
if err != nil {
return err
}
key := Kdf(w.key, buffer.To(w.keySaltLength), w.keySaltLength)
c := w.constructor(common.Dup(key))
c.Seal(buffer.From(w.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(w.keySaltLength), nil)
buffer.Extend(c.Overhead())
return w.PacketConn.WritePacket(buffer, w.source)
}
func (s *Service) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error {
return s.handler.NewPacketConnection(conn, metadata)
}
func (s *Service) HandleError(err error) {
s.handler.HandleError(err)
}

View file

@ -109,9 +109,9 @@ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Meth
}
func Blake3DeriveKey(psk, salt []byte, keyLength int) []byte {
sessionKey := make([]byte, 2*KeySaltSize)
sessionKey := buf.Make(len(psk) + len(salt))
copy(sessionKey, psk)
copy(sessionKey[KeySaltSize:], salt)
copy(sessionKey[len(psk):], salt)
outKey := buf.Make(keyLength)
blake3.DeriveKey(outKey, "shadowsocks 2022 session subkey", sessionKey)
return outKey
@ -434,6 +434,9 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPo
return err
}
buffer = buffer.WriteBufferAtFirst(header)
if err != nil {
return err
}
if c.method.udpCipher != nil {
c.method.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
buffer.Extend(c.method.udpCipher.Overhead())
@ -574,9 +577,9 @@ func (s *udpSession) nextPacketId() uint64 {
}
func (m *Method) newUDPSession() *udpSession {
session := &udpSession{
sessionId: rand.Uint64(),
}
session := &udpSession{}
common.Must(binary.Read(m.secureRNG, binary.BigEndian, &session.sessionId))
session.packetId--
if m.udpCipher == nil {
sessionId := make([]byte, 8)
binary.BigEndian.PutUint64(sessionId, session.sessionId)

View file

@ -2,32 +2,43 @@ package shadowaead_2022
import (
"context"
"crypto/aes"
"crypto/cipher"
"encoding/binary"
"io"
"math"
"net"
"sync"
"sync/atomic"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/gsync"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/socks"
wgReplay "golang.zx2c4.com/wireguard/replay"
)
type Service struct {
name string
secureRNG io.Reader
keyLength int
constructor func(key []byte) cipher.AEAD
psk []byte
replayFilter replay.Filter
handler shadowsocks.Handler
name string
secureRNG io.Reader
keyLength int
constructor func(key []byte) cipher.AEAD
blockConstructor func(key []byte) cipher.Block
udpCipher cipher.AEAD
udpBlockCipher cipher.Block
psk []byte
replayFilter replay.Filter
handler shadowsocks.Handler
udpNat *udpnat.Service[uint64]
sessions gsync.Map[uint64, *serverUDPSession]
}
func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowsocks.Handler) (shadowsocks.Service, error) {
@ -47,18 +58,20 @@ func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowso
case "2022-blake3-aes-128-gcm":
s.keyLength = 16
s.constructor = newAESGCM
// m.blockConstructor = newAES
// m.udpBlockCipher = newAES(m.psk)
s.blockConstructor = newAES
s.udpBlockCipher = newAES(s.psk)
case "2022-blake3-aes-256-gcm":
s.keyLength = 32
s.constructor = newAESGCM
// m.blockConstructor = newAES
// m.udpBlockCipher = newAES(m.psk)
s.blockConstructor = newAES
s.udpBlockCipher = newAES(s.psk)
case "2022-blake3-chacha20-poly1305":
s.keyLength = 32
s.constructor = newChacha20Poly1305
// m.udpCipher = newXChacha20Poly1305(m.psk)
s.udpCipher = newXChacha20Poly1305(s.psk)
}
s.udpNat = udpnat.New[uint64](s)
return s, nil
}
@ -194,3 +207,169 @@ func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
return c.reader.WriteTo(w)
}
func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
var packetHeader []byte
if s.udpCipher != nil {
_, err := s.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
if err != nil {
return E.Cause(err, "decrypt packet header")
}
buffer.Advance(PacketNonceSize)
} else {
packetHeader = buffer.To(aes.BlockSize)
s.udpBlockCipher.Decrypt(packetHeader, packetHeader)
}
var sessionId, packetId uint64
err := binary.Read(buffer, binary.BigEndian, &sessionId)
if err != nil {
return err
}
err = binary.Read(buffer, binary.BigEndian, &packetId)
if err != nil {
return err
}
session, loaded := s.sessions.LoadOrStore(sessionId, s.newUDPSession)
if !loaded {
session.remoteSessionId = sessionId
if packetHeader != nil {
key := Blake3DeriveKey(s.psk, packetHeader[:8], s.keyLength)
session.remoteCipher = s.constructor(common.Dup(key))
}
}
if !session.filter.ValidateCounter(packetId, math.MaxUint64) {
return ErrPacketIdNotUnique
}
if packetHeader != nil {
_, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
if err != nil {
return E.Cause(err, "decrypt packet")
}
}
var headerType byte
headerType, err = buffer.ReadByte()
if err != nil {
return err
}
if headerType != HeaderTypeClient {
return ErrBadHeaderType
}
var epoch uint64
err = binary.Read(buffer, binary.BigEndian, &epoch)
if err != nil {
return err
}
if math.Abs(float64(uint64(time.Now().Unix())-epoch)) > 30 {
return ErrBadTimestamp
}
var paddingLength uint16
err = binary.Read(buffer, binary.BigEndian, &paddingLength)
if err != nil {
return E.Cause(err, "read padding length")
}
buffer.Advance(int(paddingLength))
destination, err := socks.AddressSerializer.ReadAddrPort(buffer)
if err != nil {
return err
}
metadata.Destination = destination
return s.udpNat.NewPacket(sessionId, func() socks.PacketWriter {
return &serverPacketWriter{s, conn, session, metadata.Source}
}, buffer, metadata)
}
type serverPacketWriter struct {
*Service
socks.PacketConn
session *serverUDPSession
source *M.AddrPort
}
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
defer buffer.Release()
_header := buf.StackNew()
header := common.Dup(_header)
var dataIndex int
if w.udpCipher != nil {
common.Must1(header.ReadFullFrom(w.secureRNG, PacketNonceSize))
dataIndex = buffer.Len()
} else {
dataIndex = aes.BlockSize
}
common.Must(
binary.Write(header, binary.BigEndian, w.session.sessionId),
binary.Write(header, binary.BigEndian, w.session.nextPacketId()),
header.WriteByte(HeaderTypeServer),
binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())),
binary.Write(header, binary.BigEndian, w.session.remoteSessionId),
binary.Write(header, binary.BigEndian, uint16(0)), // padding length
)
err := socks.AddressSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
_, err = header.Write(buffer.Bytes())
if err != nil {
return err
}
if w.udpCipher != nil {
w.udpCipher.Seal(header.Index(dataIndex), header.To(dataIndex), header.From(dataIndex), nil)
header.Extend(w.udpCipher.Overhead())
} else {
packetHeader := header.To(aes.BlockSize)
w.session.cipher.Seal(header.Index(dataIndex), packetHeader[4:16], header.From(dataIndex), nil)
header.Extend(w.session.cipher.Overhead())
w.udpBlockCipher.Encrypt(packetHeader, packetHeader)
}
return w.PacketConn.WritePacket(header, w.source)
}
type serverUDPSession struct {
sessionId uint64
remoteSessionId uint64
packetId uint64
cipher cipher.AEAD
remoteCipher cipher.AEAD
filter wgReplay.Filter
}
func (s *serverUDPSession) nextPacketId() uint64 {
return atomic.AddUint64(&s.packetId, 1)
}
func (m *Service) newUDPSession() *serverUDPSession {
session := &serverUDPSession{}
common.Must(binary.Read(m.secureRNG, binary.BigEndian, &session.sessionId))
session.packetId--
if m.udpCipher == nil {
sessionId := make([]byte, 8)
binary.BigEndian.PutUint64(sessionId, session.sessionId)
key := Blake3DeriveKey(m.psk, sessionId, m.keyLength)
session.cipher = m.constructor(common.Dup(key))
}
return session
}
func (s *Service) NewPacketConnection(conn socks.PacketConn, metadata M.Metadata) error {
return s.handler.NewPacketConnection(conn, metadata)
}
func (s *Service) HandleError(err error) {
s.handler.HandleError(err)
}

View file

@ -8,12 +8,21 @@ import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/task"
)
type PacketConn interface {
type PacketReader interface {
ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error)
WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error
}
type PacketWriter interface {
WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error
}
type PacketConn interface {
PacketReader
PacketWriter
Close() error
LocalAddr() net.Addr
@ -23,6 +32,10 @@ type PacketConn interface {
SetWriteDeadline(t time.Time) error
}
type UDPHandler interface {
NewPacket(conn PacketConn, buffer *buf.Buffer, metadata M.Metadata) error
}
type UDPConnectionHandler interface {
NewPacketConnection(conn PacketConn, metadata M.Metadata) error
}
@ -47,6 +60,8 @@ func (s *PacketConnStub) SetWriteDeadline(t time.Time) error {
func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error {
return task.Run(ctx, func() error {
defer rw.CloseRead(conn)
defer rw.CloseWrite(dest)
_buffer := buf.StackNewMax()
buffer := common.Dup(_buffer)
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
@ -56,13 +71,15 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error
if err != nil {
return err
}
buffer.Truncate(data.Len())
buffer.Resize(buf.ReversedHeader+data.Start(), data.Len())
err = dest.WritePacket(buffer, destination)
if err != nil {
return err
}
}
}, func() error {
defer rw.CloseRead(dest)
defer rw.CloseWrite(conn)
_buffer := buf.StackNewMax()
buffer := common.Dup(_buffer)
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
@ -72,7 +89,7 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error
if err != nil {
return err
}
buffer.Truncate(data.Len())
buffer.Resize(buf.ReversedHeader+data.Start(), data.Len())
err = conn.WritePacket(buffer, destination)
if err != nil {
return err
@ -81,44 +98,211 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error
})
}
func CopyPacketConn0(dest PacketConn, conn PacketConn, onAction func(destination *M.AddrPort, n int)) error {
for {
buffer := buf.New()
destination, err := conn.ReadPacket(buffer)
if err != nil {
buffer.Release()
return err
func CopyNetPacketConn(ctx context.Context, dest net.PacketConn, conn PacketConn) error {
return task.Run(ctx, func() error {
defer rw.CloseRead(conn)
defer rw.CloseWrite(dest)
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
for {
buffer.FullReset()
destination, err := conn.ReadPacket(buffer)
if err != nil {
return err
}
_, err = dest.WriteTo(buffer.Bytes(), destination.UDPAddr())
if err != nil {
return err
}
}
size := buffer.Len()
err = dest.WritePacket(buffer, destination)
if err != nil {
buffer.Release()
return err
}, func() error {
defer rw.CloseRead(dest)
defer rw.CloseWrite(conn)
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
for {
data.FullReset()
n, addr, err := dest.ReadFrom(data.FreeBytes())
if err != nil {
return err
}
buffer.Resize(buf.ReversedHeader, n)
err = conn.WritePacket(buffer, M.AddrPortFromNetAddr(addr))
if err != nil {
return err
}
}
if onAction != nil {
onAction(destination, size)
}
}
})
}
type associatePacketConn struct {
net.PacketConn
type AssociateConn struct {
net.Conn
conn net.Conn
addr net.Addr
dest *M.AddrPort
}
func NewPacketConn(conn net.Conn, packetConn net.PacketConn) PacketConn {
return &associatePacketConn{
PacketConn: packetConn,
conn: conn,
func NewAssociateConn(conn net.Conn, packetConn net.Conn, destination *M.AddrPort) net.PacketConn {
return &AssociateConn{
Conn: packetConn,
conn: conn,
dest: destination,
}
}
func (c *associatePacketConn) RemoteAddr() net.Addr {
func (c *AssociateConn) RemoteAddr() net.Addr {
return c.addr
}
func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *AssociateConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, err = c.Conn.Read(p)
if err != nil {
return
}
reader := buf.As(p[3:n])
destination, err := AddressSerializer.ReadAddrPort(reader)
if err != nil {
return
}
addr = destination.UDPAddr()
n = copy(p, reader.Bytes())
return
}
func (c *AssociateConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3))
err = AddressSerializer.WriteAddrPort(buffer, M.AddrPortFromNetAddr(addr))
if err != nil {
return
}
_, err = buffer.Write(p)
if err != nil {
return
}
_, err = c.Conn.Write(buffer.Bytes())
return
}
func (c *AssociateConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *AssociateConn) Write(b []byte) (n int, err error) {
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3))
err = AddressSerializer.WriteAddrPort(buffer, c.dest)
if err != nil {
return
}
_, err = buffer.Write(b)
if err != nil {
return
}
_, err = c.Conn.Write(buffer.Bytes())
return
}
func (c *AssociateConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
n, err := buffer.ReadFrom(c.conn)
if err != nil {
return nil, err
}
buffer.Truncate(int(n))
buffer.Advance(3)
return AddressSerializer.ReadAddrPort(buffer)
}
func (c *AssociateConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
defer buffer.Release()
header := buf.With(buffer.ExtendHeader(3 + AddressSerializer.AddrPortLen(destination)))
common.Must(header.WriteZeroN(3))
common.Must(AddressSerializer.WriteAddrPort(header, destination))
return common.Error(c.Conn.Write(buffer.Bytes()))
}
type AssociatePacketConn struct {
net.PacketConn
conn net.Conn
addr net.Addr
dest *M.AddrPort
}
func NewAssociatePacketConn(conn net.Conn, packetConn net.PacketConn, destination *M.AddrPort) *AssociatePacketConn {
return &AssociatePacketConn{
PacketConn: packetConn,
conn: conn,
dest: destination,
}
}
func (c *AssociatePacketConn) RemoteAddr() net.Addr {
return c.addr
}
func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = c.PacketConn.ReadFrom(p)
if err != nil {
return
}
reader := buf.As(p[3:n])
destination, err := AddressSerializer.ReadAddrPort(reader)
if err != nil {
return
}
addr = destination.UDPAddr()
n = copy(p, reader.Bytes())
return
}
func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3))
err = AddressSerializer.WriteAddrPort(buffer, M.AddrPortFromNetAddr(addr))
if err != nil {
return
}
_, err = buffer.Write(p)
if err != nil {
return
}
_, err = c.PacketConn.WriteTo(buffer.Bytes(), c.addr)
return
}
func (c *AssociatePacketConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *AssociatePacketConn) Write(b []byte) (n int, err error) {
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3))
err = AddressSerializer.WriteAddrPort(buffer, c.dest)
if err != nil {
return
}
_, err = buffer.Write(b)
if err != nil {
return
}
_, err = c.PacketConn.WriteTo(buffer.Bytes(), c.addr)
return
}
func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
n, addr, err := c.PacketConn.ReadFrom(buffer.FreeBytes())
if err != nil {
return nil, err
@ -126,15 +310,14 @@ func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error
c.addr = addr
buffer.Truncate(n)
buffer.Advance(3)
return AddressSerializer.ReadAddrPort(buffer)
dest, err := AddressSerializer.ReadAddrPort(buffer)
return dest, err
}
func (c *associatePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
defer buffer.Release()
_header := buf.StackNew()
header := common.Dup(_header)
header := buf.With(buffer.ExtendHeader(3 + AddressSerializer.AddrPortLen(destination)))
common.Must(header.WriteZeroN(3))
common.Must(AddressSerializer.WriteAddrPort(header, addrPort))
buffer = buffer.WriteBufferAtFirst(header)
common.Must(AddressSerializer.WriteAddrPort(header, destination))
return common.Error(c.PacketConn.WriteTo(buffer.Bytes(), c.addr))
}

View file

@ -36,7 +36,7 @@ func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, handler
}
func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
return HandleConnection(ctx, conn, l.authenticator, l.bindAddr, l.handler, metadata)
return HandleConnection(ctx, conn, l.authenticator, M.AddrPortFromNetAddr(conn.LocalAddr()).Addr.Addr(), l.handler, metadata)
}
func (l *Listener) Start() error {
@ -131,9 +131,10 @@ func HandleConnection(ctx context.Context, conn net.Conn, authenticator auth.Aut
if err != nil {
return E.Cause(err, "write socks response")
}
metadata.Protocol = "socks"
metadata.Destination = request.Destination
go func() {
err := handler.NewPacketConnection(NewPacketConn(conn, udpConn), metadata)
err := handler.NewPacketConnection(NewAssociatePacketConn(conn, udpConn, request.Destination), metadata)
if err != nil {
handler.HandleError(err)
}