Add shadowsocks 2022 support

This commit is contained in:
世界 2022-04-11 18:43:17 +08:00
parent 00cd0d4b8f
commit bc80c3357c
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
17 changed files with 740 additions and 310 deletions

View file

@ -1 +0,0 @@
package shadowsocks

View file

@ -27,11 +27,7 @@ func (m *NoneMethod) KeyLength() int {
return 0
}
func (m *NoneMethod) NewSession(key []byte) Session {
return nil
}
func (m *NoneMethod) DialConn(_ Session, conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
func (m *NoneMethod) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
shadowsocksConn := &noneConn{
Conn: conn,
handshake: true,
@ -40,14 +36,14 @@ func (m *NoneMethod) DialConn(_ Session, conn net.Conn, destination *M.AddrPort)
return shadowsocksConn, shadowsocksConn.clientHandshake()
}
func (m *NoneMethod) DialEarlyConn(_ Session, conn net.Conn, destination *M.AddrPort) net.Conn {
func (m *NoneMethod) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn {
return &noneConn{
Conn: conn,
destination: destination,
}
}
func (m *NoneMethod) DialPacketConn(_ Session, conn net.Conn) socks.PacketConn {
func (m *NoneMethod) DialPacketConn(conn net.Conn) socks.PacketConn {
return &nonePacketConn{conn}
}

View file

@ -2,26 +2,21 @@ package shadowsocks
import (
"crypto/md5"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/protocol/socks"
"hash/crc32"
"io"
"math/rand"
"net"
)
type Session interface {
Key() []byte
ReplayFilter() replay.Filter
}
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks"
)
type Method interface {
Name() string
KeyLength() int
DialConn(session Session, conn net.Conn, destination *M.AddrPort) (net.Conn, error)
DialEarlyConn(session Session, conn net.Conn, destination *M.AddrPort) net.Conn
DialPacketConn(session Session, conn net.Conn) socks.PacketConn
DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error)
DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn
DialPacketConn(conn net.Conn) socks.PacketConn
}
func Key(password []byte, keySize int) []byte {

View file

@ -23,11 +23,11 @@ type Reader struct {
cached int
}
func NewReader(upstream io.Reader, cipher cipher.AEAD) *Reader {
func NewReader(upstream io.Reader, cipher cipher.AEAD, maxPacketSize int) *Reader {
return &Reader{
upstream: upstream,
cipher: cipher,
data: make([]byte, MaxPacketSize+PacketLengthBufferSize+cipher.Overhead()*2),
data: make([]byte, maxPacketSize+PacketLengthBufferSize+cipher.Overhead()*2),
nonce: make([]byte, cipher.NonceSize()),
}
}
@ -132,18 +132,20 @@ func (r *Reader) Read(b []byte) (n int, err error) {
}
type AEADWriter struct {
upstream io.Writer
cipher cipher.AEAD
data []byte
nonce []byte
upstream io.Writer
cipher cipher.AEAD
data []byte
nonce []byte
maxPacketSize int
}
func NewAEADWriter(upstream io.Writer, cipher cipher.AEAD) *AEADWriter {
func NewWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int) *AEADWriter {
return &AEADWriter{
upstream: upstream,
cipher: cipher,
data: make([]byte, MaxPacketSize+PacketLengthBufferSize+cipher.Overhead()*2),
nonce: make([]byte, cipher.NonceSize()),
upstream: upstream,
cipher: cipher,
data: make([]byte, maxPacketSize+PacketLengthBufferSize+cipher.Overhead()*2),
nonce: make([]byte, cipher.NonceSize()),
maxPacketSize: maxPacketSize,
}
}
@ -162,7 +164,7 @@ func (w *AEADWriter) SetWriter(writer io.Writer) {
func (w *AEADWriter) ReadFrom(r io.Reader) (n int64, err error) {
for {
offset := w.cipher.Overhead() + PacketLengthBufferSize
readN, readErr := r.Read(w.data[offset : offset+MaxPacketSize])
readN, readErr := r.Read(w.data[offset : offset+w.maxPacketSize])
if readErr != nil {
return 0, readErr
}
@ -184,7 +186,11 @@ func (w *AEADWriter) ReadFrom(r io.Reader) (n int64, err error) {
}
func (w *AEADWriter) Write(p []byte) (n int, err error) {
for _, data := range buf.ForeachN(p, MaxPacketSize) {
if len(p) == 0 {
return
}
for _, data := range buf.ForeachN(p, w.maxPacketSize) {
binary.BigEndian.PutUint16(w.data[:PacketLengthBufferSize], uint16(len(data)))
w.cipher.Seal(w.data[:0], w.nonce, w.data[:PacketLengthBufferSize], nil)
increaseNonce(w.nonce)

View file

@ -4,7 +4,6 @@ import (
"crypto/aes"
"crypto/cipher"
"crypto/sha1"
"github.com/sagernet/sing/common/replay"
"io"
"net"
"sync"
@ -13,6 +12,7 @@ import (
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/socks"
@ -28,11 +28,19 @@ var List = []string{
"xchacha20-ietf-poly1305",
}
func New(method string, secureRNG io.Reader) shadowsocks.Method {
var (
ErrBadKey = E.New("bad key")
ErrMissingPassword = E.New("missing password")
)
func New(method string, key []byte, password []byte, secureRNG io.Reader, replayFilter bool) (shadowsocks.Method, error) {
m := &Method{
name: method,
secureRNG: secureRNG,
}
if replayFilter {
m.replayFilter = replay.NewBloomRing()
}
switch method {
case "aes-128-gcm":
m.keySaltLength = 16
@ -58,15 +66,16 @@ func New(method string, secureRNG io.Reader) shadowsocks.Method {
return cipher
}
}
return m
}
func NewSession(key []byte, replayFilter bool) shadowsocks.Session {
var filter replay.Filter
if replayFilter {
filter = replay.NewBloomRing()
if len(key) == m.keySaltLength {
m.key = key
} else if len(key) > 0 {
return nil, ErrBadKey
} else if len(password) > 0 {
m.key = shadowsocks.Key(password, m.keySaltLength)
} else {
return nil, ErrMissingPassword
}
return &session{key, filter}
return m, nil
}
func Kdf(key, iv []byte, keyLength int) []byte {
@ -88,7 +97,9 @@ type Method struct {
name string
keySaltLength int
constructor func(key []byte) cipher.AEAD
key []byte
secureRNG io.Reader
replayFilter replay.Filter
}
func (m *Method) Name() string {
@ -99,29 +110,25 @@ func (m *Method) KeyLength() int {
return m.keySaltLength
}
func (m *Method) DialConn(account shadowsocks.Session, conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
shadowsocksConn := &aeadConn{
Conn: conn,
method: m,
key: account.Key(),
replayFilter: account.ReplayFilter(),
destination: destination,
func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
shadowsocksConn := &clientConn{
Conn: conn,
method: m,
destination: destination,
}
return shadowsocksConn, shadowsocksConn.clientHandshake()
return shadowsocksConn, shadowsocksConn.writeRequest(nil)
}
func (m *Method) DialEarlyConn(account shadowsocks.Session, conn net.Conn, destination *M.AddrPort) net.Conn {
return &aeadConn{
Conn: conn,
method: m,
key: account.Key(),
replayFilter: account.ReplayFilter(),
destination: destination,
func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn {
return &clientConn{
Conn: conn,
method: m,
destination: destination,
}
}
func (m *Method) DialPacketConn(account shadowsocks.Session, conn net.Conn) socks.PacketConn {
return &aeadPacketConn{conn, account.Key(), m}
func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn {
return &aeadPacketConn{conn, m}
}
func (m *Method) EncodePacket(key []byte, buffer *buf.Buffer) error {
@ -145,160 +152,133 @@ func (m *Method) DecodePacket(key []byte, buffer *buf.Buffer) error {
return nil
}
type session struct {
key []byte
replayFilter replay.Filter
}
func (a *session) Key() []byte {
return a.key
}
func (a *session) ReplayFilter() replay.Filter {
return a.replayFilter
}
type aeadConn struct {
type clientConn struct {
net.Conn
method *Method
key []byte
destination *M.AddrPort
access sync.Mutex
reader io.Reader
writer io.Writer
replayFilter replay.Filter
access sync.Mutex
reader io.Reader
writer io.Writer
}
func (c *aeadConn) clientHandshake() error {
header := buf.New()
defer header.Release()
func (c *clientConn) writeRequest(payload []byte) error {
request := buf.New()
defer request.Release()
common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
if c.replayFilter != nil {
c.replayFilter.Check(header.Bytes())
common.Must1(request.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
if c.method.replayFilter != nil {
c.method.replayFilter.Check(request.Bytes())
}
c.writer = NewAEADWriter(
&buf.BufferedWriter{
Writer: c.Conn,
Buffer: header,
},
c.method.constructor(Kdf(c.key, header.Bytes(), c.method.keySaltLength)),
var writer io.Writer = c.Conn
writer = &buf.BufferedWriter{
Writer: writer,
Buffer: request,
}
writer = NewWriter(
writer,
c.method.constructor(Kdf(c.method.key, request.Bytes(), c.method.keySaltLength)),
MaxPacketSize,
)
err := socks.AddressSerializer.WriteAddrPort(c.writer, c.destination)
if len(payload) > 0 {
header := buf.New()
defer header.Release()
writer = &buf.BufferedWriter{
Writer: writer,
Buffer: header,
}
err := socks.AddressSerializer.WriteAddrPort(writer, c.destination)
if err != nil {
return err
}
_, err = writer.Write(payload)
if err != nil {
return err
}
} else {
err := socks.AddressSerializer.WriteAddrPort(writer, c.destination)
if err != nil {
return err
}
}
err := common.FlushVar(&writer)
if err != nil {
return err
}
return common.FlushVar(&c.writer)
c.writer = writer
return nil
}
func (c *aeadConn) serverHandshake() error {
func (c *clientConn) readResponse() error {
if c.reader == nil {
salt := make([]byte, c.method.keySaltLength)
_, err := io.ReadFull(c.Conn, salt)
if err != nil {
return err
}
if c.replayFilter != nil {
if !c.replayFilter.Check(salt) {
if c.method.replayFilter != nil {
if !c.method.replayFilter.Check(salt) {
return E.New("salt is not unique")
}
}
c.reader = NewReader(c.Conn, c.method.constructor(Kdf(c.key, salt, c.method.keySaltLength)))
c.reader = NewReader(
c.Conn,
c.method.constructor(Kdf(c.method.key, salt, c.method.keySaltLength)),
MaxPacketSize,
)
}
return nil
}
func (c *aeadConn) Read(p []byte) (n int, err error) {
if err = c.serverHandshake(); err != nil {
func (c *clientConn) Read(p []byte) (n int, err error) {
if err = c.readResponse(); err != nil {
return
}
return c.reader.Read(p)
}
func (c *aeadConn) WriteTo(w io.Writer) (n int64, err error) {
if err = c.serverHandshake(); err != nil {
func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
if err = c.readResponse(); err != nil {
return
}
return c.reader.(io.WriterTo).WriteTo(w)
}
func (c *aeadConn) Write(p []byte) (n int, err error) {
func (c *clientConn) Write(p []byte) (n int, err error) {
if c.writer != nil {
goto direct
return c.writer.Write(p)
}
c.access.Lock()
defer c.access.Unlock()
if c.writer != nil {
goto direct
c.access.Unlock()
return c.writer.Write(p)
}
// client handshake
{
header := buf.New()
defer header.Release()
request := buf.New()
defer request.Release()
common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
if c.replayFilter != nil {
c.replayFilter.Check(header.Bytes())
}
var writer io.Writer = c.Conn
writer = &buf.BufferedWriter{
Writer: writer,
Buffer: header,
}
writer = NewAEADWriter(writer, c.method.constructor(Kdf(c.key, header.Bytes(), c.method.keySaltLength)))
writer = &buf.BufferedWriter{
Writer: writer,
Buffer: request,
}
err = socks.AddressSerializer.WriteAddrPort(writer, c.destination)
if err != nil {
return
}
if len(p) > 0 {
_, err = writer.Write(p)
if err != nil {
return
}
}
err = common.FlushVar(&writer)
if err != nil {
return
}
c.writer = writer
return len(p), nil
err = c.writeRequest(p)
if err != nil {
return
}
direct:
return c.writer.Write(p)
return len(p), nil
}
func (c *aeadConn) ReadFrom(r io.Reader) (n int64, err error) {
func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writer == nil {
panic("missing client handshake")
panic("missing handshake")
}
return c.writer.(io.ReaderFrom).ReadFrom(r)
}
type aeadPacketConn struct {
net.Conn
key []byte
method *Method
}
@ -311,7 +291,7 @@ func (c *aeadPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort
return err
}
buffer = buffer.WriteBufferAtFirst(header)
err = c.method.EncodePacket(c.key, buffer)
err = c.method.EncodePacket(c.method.key, buffer)
if err != nil {
return err
}
@ -324,7 +304,7 @@ func (c *aeadPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
return nil, err
}
buffer.Truncate(n)
err = c.method.DecodePacket(c.key, buffer)
err = c.method.DecodePacket(c.method.key, buffer)
if err != nil {
return nil, err
}

View file

@ -0,0 +1,519 @@
package shadowaead_2022
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"encoding/binary"
"io"
"math"
"math/rand"
"net"
"sync"
"sync/atomic"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/socks"
"golang.org/x/crypto/chacha20poly1305"
wgReplay "golang.zx2c4.com/wireguard/replay"
"lukechampine.com/blake3"
)
const (
HeaderTypeClient = 0
HeaderTypeServer = 1
MaxPaddingLength = 900
KeySaltSize = 32
PacketNonceSize = 24
MinRequestHeaderSize = 1 + 8
MinResponseHeaderSize = MinRequestHeaderSize + KeySaltSize
MaxPacketSize = 64 * 1024
)
var (
ErrBadHeaderType = E.New("bad header type")
ErrBadTimestamp = E.New("bad timestamp")
ErrBadRequestSalt = E.New("bad request salt")
ErrBadClientSessionId = E.New("bad client session id")
ErrPacketIdNotUnique = E.New("packet id not unique")
)
var List = []string{
"2022-blake3-aes-128-gcm",
"2022-blake3-aes-256-gcm",
"2022-blake3-chacha20-poly1305",
}
func New(method string, psk []byte, secureRNG io.Reader) (shadowsocks.Method, error) {
m := &Method{
name: method,
key: psk,
secureRNG: secureRNG,
replayFilter: replay.NewCuckoo(30),
}
if len(psk) != KeySaltSize {
return nil, shadowaead.ErrBadKey
}
switch method {
case "2022-blake3-aes-128-gcm":
m.keyLength = 16
m.constructor = newAESGCM
m.udpBlockConstructor = newAES
case "2022-blake3-aes-256-gcm":
m.keyLength = 32
m.constructor = newAESGCM
m.udpBlockConstructor = newAES
case "2022-blake3-chacha20-poly1305":
m.keyLength = 32
m.constructor = func(key []byte) cipher.AEAD {
cipher, err := chacha20poly1305.New(key)
common.Must(err)
return cipher
}
m.udpConstructor = func(key []byte) cipher.AEAD {
cipher, err := chacha20poly1305.NewX(key)
common.Must(err)
return cipher
}
}
return m, nil
}
func Blake3DeriveKey(secret, salt, outKey []byte) {
sessionKey := make([]byte, len(secret)+len(salt))
copy(sessionKey, secret)
copy(sessionKey[len(secret):], salt)
blake3.DeriveKey(outKey, "shadowsocks 2022 session subkey", sessionKey)
}
func newAES(key []byte) cipher.Block {
block, err := aes.NewCipher(key)
common.Must(err)
return block
}
func newAESGCM(key []byte) cipher.AEAD {
block, err := aes.NewCipher(key)
common.Must(err)
aead, err := cipher.NewGCM(block)
common.Must(err)
return aead
}
type Method struct {
name string
keyLength int
constructor func(key []byte) cipher.AEAD
udpBlockConstructor func(key []byte) cipher.Block
udpConstructor func(key []byte) cipher.AEAD
key []byte
secureRNG io.Reader
replayFilter replay.Filter
}
func (m *Method) Name() string {
return m.name
}
func (m *Method) KeyLength() int {
return m.keyLength
}
func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
shadowsocksConn := &clientConn{
Conn: conn,
method: m,
destination: destination,
}
return shadowsocksConn, shadowsocksConn.writeRequest(nil)
}
func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn {
return &clientConn{
Conn: conn,
method: m,
destination: destination,
}
}
func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn {
return &clientPacketConn{conn, m, newUDPSession()}
}
func (m *Method) EncodePacket(buffer *buf.Buffer) error {
if m.udpConstructor == nil {
// aes
packetHeader := buffer.To(aes.BlockSize)
subKey := make([]byte, m.keyLength)
Blake3DeriveKey(m.key, packetHeader[:8], subKey)
cipher := m.constructor(subKey)
cipher.Seal(buffer.Index(aes.BlockSize), packetHeader[4:16], buffer.From(aes.BlockSize), nil)
buffer.Extend(cipher.Overhead())
m.udpBlockConstructor(m.key).Encrypt(packetHeader, packetHeader)
} else {
// xchacha
cipher := m.udpConstructor(m.key)
cipher.Seal(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
buffer.Extend(cipher.Overhead())
}
return nil
}
func (m *Method) DecodePacket(buffer *buf.Buffer) error {
if m.udpBlockConstructor != nil {
if buffer.Len() <= aes.BlockSize {
return E.New("insufficient data: ", buffer.Len())
}
packetHeader := buffer.To(aes.BlockSize)
m.udpBlockConstructor(m.key).Decrypt(packetHeader, packetHeader)
subKey := make([]byte, m.keyLength)
Blake3DeriveKey(m.key, packetHeader[:8], subKey)
_, err := m.constructor(subKey).Open(buffer.Index(aes.BlockSize), packetHeader[4:16], buffer.From(aes.BlockSize), nil)
if err != nil {
return err
}
} else {
_, err := m.udpConstructor(m.key).Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
if err != nil {
return err
}
buffer.Advance(PacketNonceSize)
}
return nil
}
type clientConn struct {
net.Conn
method *Method
destination *M.AddrPort
request sync.Mutex
response sync.Mutex
requestSalt []byte
reader io.Reader
writer io.Writer
}
func (c *clientConn) writeRequest(payload []byte) error {
request := buf.New()
defer request.Release()
salt := make([]byte, KeySaltSize)
common.Must1(io.ReadFull(c.method.secureRNG, salt))
c.method.replayFilter.Check(salt)
common.Must1(request.Write(salt))
subKey := make([]byte, c.method.keyLength)
Blake3DeriveKey(c.method.key, salt, subKey)
var writer io.Writer = c.Conn
writer = &buf.BufferedWriter{
Writer: writer,
Buffer: request,
}
writer = shadowaead.NewWriter(
writer,
c.method.constructor(subKey),
MaxPacketSize,
)
header := buf.New()
defer header.Release()
writer = &buf.BufferedWriter{
Writer: writer,
Buffer: header,
}
common.Must(rw.WriteByte(writer, HeaderTypeClient))
common.Must(binary.Write(writer, binary.BigEndian, uint64(time.Now().Unix())))
err := socks.AddressSerializer.WriteAddrPort(writer, c.destination)
if err != nil {
return E.Cause(err, "write destination")
}
if len(payload) > 0 {
err = binary.Write(writer, binary.BigEndian, uint16(0))
if err != nil {
return E.Cause(err, "write padding length")
}
_, err = writer.Write(payload)
if err != nil {
return E.Cause(err, "write payload")
}
} else {
pLen := rand.Intn(MaxPaddingLength)
err = binary.Write(writer, binary.BigEndian, uint16(pLen))
if err != nil {
return E.Cause(err, "write padding length")
}
_, err = io.CopyN(writer, c.method.secureRNG, int64(pLen))
if err != nil {
return E.Cause(err, "write padding")
}
}
err = common.FlushVar(&writer)
if err != nil {
return E.Cause(err, "client handshake")
}
c.requestSalt = salt
c.writer = writer
return nil
}
func (c *clientConn) readResponse() error {
if c.reader != nil {
return nil
}
c.response.Lock()
defer c.response.Unlock()
if c.reader != nil {
return nil
}
salt := make([]byte, KeySaltSize)
_, err := io.ReadFull(c.Conn, salt)
if err != nil {
return err
}
if !c.method.replayFilter.Check(salt) {
return E.New("salt is not unique")
}
subKey := make([]byte, c.method.keyLength)
Blake3DeriveKey(c.method.key, salt, subKey)
reader := shadowaead.NewReader(
c.Conn,
c.method.constructor(subKey),
MaxPacketSize,
)
headerType, err := rw.ReadByte(reader)
if err != nil {
return err
}
if headerType != HeaderTypeServer {
return ErrBadHeaderType
}
var epoch uint64
err = binary.Read(reader, binary.BigEndian, &epoch)
if err != nil {
return err
}
if math.Abs(float64(time.Now().Unix()-int64(epoch))) > 30 {
return ErrBadTimestamp
}
requestSalt := make([]byte, KeySaltSize)
_, err = io.ReadFull(reader, requestSalt)
if err != nil {
return err
}
if bytes.Compare(requestSalt, c.requestSalt) > 0 {
return ErrBadRequestSalt
}
c.reader = reader
return nil
}
func (c *clientConn) Read(p []byte) (n int, err error) {
if err = c.readResponse(); err != nil {
return
}
return c.reader.Read(p)
}
func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
if err = c.readResponse(); err != nil {
return
}
return c.reader.(io.WriterTo).WriteTo(w)
}
func (c *clientConn) Write(p []byte) (n int, err error) {
if c.writer != nil {
return c.writer.Write(p)
}
c.request.Lock()
if c.writer != nil {
c.request.Unlock()
return c.writer.Write(p)
}
defer c.request.Unlock()
err = c.writeRequest(p)
if err == nil {
n = len(p)
}
return
}
func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writer == nil {
panic("missing client handshake")
}
return c.writer.(io.ReaderFrom).ReadFrom(r)
}
type clientPacketConn struct {
net.Conn
method *Method
session *udpSession
}
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
defer buffer.Release()
header := buf.New()
if c.method.udpConstructor != nil {
common.Must1(header.ReadFullFrom(c.method.secureRNG, PacketNonceSize))
}
common.Must(
binary.Write(header, binary.BigEndian, c.session.sessionId),
binary.Write(header, binary.BigEndian, c.session.nextPacketId()),
header.WriteByte(HeaderTypeClient),
binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())),
binary.Write(header, binary.BigEndian, uint16(0)), // padding length
)
c.session.filter.ValidateCounter(c.session.packetId, math.MaxUint64)
err := socks.AddressSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
buffer = buffer.WriteBufferAtFirst(header)
err = c.method.EncodePacket(buffer)
if err != nil {
return err
}
return common.Error(c.Write(buffer.Bytes()))
}
func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
n, err := c.Read(buffer.FreeBytes())
if err != nil {
return nil, err
}
buffer.Truncate(n)
err = c.method.DecodePacket(buffer)
if err != nil {
return nil, err
}
var sessionId uint64
err = binary.Read(buffer, binary.BigEndian, &sessionId)
if err != nil {
return nil, err
}
var isLastSessionId bool
if c.session.remoteSessionId == 0 {
c.session.remoteSessionId = sessionId
} else if sessionId != c.session.remoteSessionId {
if sessionId == c.session.lastRemoteSessionId {
isLastSessionId = true
} else {
c.session.lastRemoteSessionId = c.session.remoteSessionId
c.session.remoteSessionId = sessionId
c.session.lastFilter = c.session.filter
c.session.filter = new(wgReplay.Filter)
}
}
var packetId uint64
err = binary.Read(buffer, binary.BigEndian, &packetId)
if err != nil {
return nil, err
}
if !isLastSessionId {
if !c.session.filter.ValidateCounter(packetId, math.MaxUint64) {
return nil, ErrPacketIdNotUnique
}
} else {
if !c.session.lastFilter.ValidateCounter(packetId, math.MaxUint64) {
return nil, ErrPacketIdNotUnique
}
}
headerType, err := buffer.ReadBytes(1)
if err != nil {
return nil, err
}
if headerType[0] != HeaderTypeServer {
return nil, ErrBadHeaderType
}
var epoch uint64
err = binary.Read(buffer, binary.BigEndian, &epoch)
if err != nil {
return nil, err
}
if math.Abs(float64(uint64(time.Now().Unix())-epoch)) > 30 {
return nil, ErrBadTimestamp
}
var clientSessionId uint64
err = binary.Read(buffer, binary.BigEndian, &clientSessionId)
if err != nil {
return nil, err
}
if clientSessionId != c.session.sessionId {
return nil, ErrBadClientSessionId
}
var paddingLength uint16
err = binary.Read(buffer, binary.BigEndian, &paddingLength)
if err != nil {
return nil, E.Cause(err, "read padding length")
}
buffer.Advance(int(paddingLength))
return socks.AddressSerializer.ReadAddrPort(buffer)
}
type udpSession struct {
headerType byte
sessionId uint64
packetId uint64
remoteSessionId uint64
lastRemoteSessionId uint64
filter *wgReplay.Filter
lastFilter *wgReplay.Filter
}
func (s *udpSession) nextPacketId() uint64 {
return atomic.AddUint64(&s.packetId, 1)
}
func newUDPSession() *udpSession {
return &udpSession{
sessionId: rand.Uint64(),
filter: new(wgReplay.Filter),
}
}

View file

@ -18,6 +18,7 @@ type Handler interface {
}
type Listener struct {
bindAddr netip.Addr
tcpListener *tcp.Listener
authenticator auth.Authenticator
handler Handler
@ -25,6 +26,7 @@ type Listener struct {
func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, handler Handler) *Listener {
listener := &Listener{
bindAddr: bind.Addr(),
handler: handler,
authenticator: authenticator,
}
@ -33,7 +35,7 @@ func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, handler
}
func (l *Listener) NewConnection(conn net.Conn, metadata M.Metadata) error {
return HandleConnection(conn, l.authenticator, l.handler)
return HandleConnection(conn, l.bindAddr, l.authenticator, l.handler)
}
func (l *Listener) Start() error {
@ -48,7 +50,7 @@ func (l *Listener) HandleError(err error) {
l.handler.HandleError(err)
}
func HandleConnection(conn net.Conn, authenticator auth.Authenticator, handler Handler) error {
func HandleConnection(conn net.Conn, bind netip.Addr, authenticator auth.Authenticator, handler Handler) error {
authRequest, err := ReadAuthRequest(conn)
if err != nil {
return E.Cause(err, "read socks auth request")
@ -111,7 +113,11 @@ func HandleConnection(conn net.Conn, authenticator auth.Authenticator, handler H
Destination: request.Destination,
})
case CommandUDPAssociate:
udpConn, err := net.ListenUDP("udp", nil)
network := "udp"
if bind.Is4() {
network = "udp4"
}
udpConn, err := net.ListenUDP(network, net.UDPAddrFromAddrPort(netip.AddrPortFrom(bind, 0)))
if err != nil {
return err
}