mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-04 04:17:38 +03:00
Add shadowsocks 2022 support
This commit is contained in:
parent
00cd0d4b8f
commit
bc80c3357c
17 changed files with 740 additions and 310 deletions
|
@ -1 +0,0 @@
|
|||
package shadowsocks
|
|
@ -27,11 +27,7 @@ func (m *NoneMethod) KeyLength() int {
|
|||
return 0
|
||||
}
|
||||
|
||||
func (m *NoneMethod) NewSession(key []byte) Session {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *NoneMethod) DialConn(_ Session, conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
|
||||
func (m *NoneMethod) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
|
||||
shadowsocksConn := &noneConn{
|
||||
Conn: conn,
|
||||
handshake: true,
|
||||
|
@ -40,14 +36,14 @@ func (m *NoneMethod) DialConn(_ Session, conn net.Conn, destination *M.AddrPort)
|
|||
return shadowsocksConn, shadowsocksConn.clientHandshake()
|
||||
}
|
||||
|
||||
func (m *NoneMethod) DialEarlyConn(_ Session, conn net.Conn, destination *M.AddrPort) net.Conn {
|
||||
func (m *NoneMethod) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn {
|
||||
return &noneConn{
|
||||
Conn: conn,
|
||||
destination: destination,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *NoneMethod) DialPacketConn(_ Session, conn net.Conn) socks.PacketConn {
|
||||
func (m *NoneMethod) DialPacketConn(conn net.Conn) socks.PacketConn {
|
||||
return &nonePacketConn{conn}
|
||||
}
|
||||
|
||||
|
|
|
@ -2,26 +2,21 @@ package shadowsocks
|
|||
|
||||
import (
|
||||
"crypto/md5"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/replay"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
)
|
||||
|
||||
type Session interface {
|
||||
Key() []byte
|
||||
ReplayFilter() replay.Filter
|
||||
}
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
)
|
||||
|
||||
type Method interface {
|
||||
Name() string
|
||||
KeyLength() int
|
||||
DialConn(session Session, conn net.Conn, destination *M.AddrPort) (net.Conn, error)
|
||||
DialEarlyConn(session Session, conn net.Conn, destination *M.AddrPort) net.Conn
|
||||
DialPacketConn(session Session, conn net.Conn) socks.PacketConn
|
||||
DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error)
|
||||
DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn
|
||||
DialPacketConn(conn net.Conn) socks.PacketConn
|
||||
}
|
||||
|
||||
func Key(password []byte, keySize int) []byte {
|
||||
|
|
|
@ -23,11 +23,11 @@ type Reader struct {
|
|||
cached int
|
||||
}
|
||||
|
||||
func NewReader(upstream io.Reader, cipher cipher.AEAD) *Reader {
|
||||
func NewReader(upstream io.Reader, cipher cipher.AEAD, maxPacketSize int) *Reader {
|
||||
return &Reader{
|
||||
upstream: upstream,
|
||||
cipher: cipher,
|
||||
data: make([]byte, MaxPacketSize+PacketLengthBufferSize+cipher.Overhead()*2),
|
||||
data: make([]byte, maxPacketSize+PacketLengthBufferSize+cipher.Overhead()*2),
|
||||
nonce: make([]byte, cipher.NonceSize()),
|
||||
}
|
||||
}
|
||||
|
@ -132,18 +132,20 @@ func (r *Reader) Read(b []byte) (n int, err error) {
|
|||
}
|
||||
|
||||
type AEADWriter struct {
|
||||
upstream io.Writer
|
||||
cipher cipher.AEAD
|
||||
data []byte
|
||||
nonce []byte
|
||||
upstream io.Writer
|
||||
cipher cipher.AEAD
|
||||
data []byte
|
||||
nonce []byte
|
||||
maxPacketSize int
|
||||
}
|
||||
|
||||
func NewAEADWriter(upstream io.Writer, cipher cipher.AEAD) *AEADWriter {
|
||||
func NewWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int) *AEADWriter {
|
||||
return &AEADWriter{
|
||||
upstream: upstream,
|
||||
cipher: cipher,
|
||||
data: make([]byte, MaxPacketSize+PacketLengthBufferSize+cipher.Overhead()*2),
|
||||
nonce: make([]byte, cipher.NonceSize()),
|
||||
upstream: upstream,
|
||||
cipher: cipher,
|
||||
data: make([]byte, maxPacketSize+PacketLengthBufferSize+cipher.Overhead()*2),
|
||||
nonce: make([]byte, cipher.NonceSize()),
|
||||
maxPacketSize: maxPacketSize,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -162,7 +164,7 @@ func (w *AEADWriter) SetWriter(writer io.Writer) {
|
|||
func (w *AEADWriter) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
for {
|
||||
offset := w.cipher.Overhead() + PacketLengthBufferSize
|
||||
readN, readErr := r.Read(w.data[offset : offset+MaxPacketSize])
|
||||
readN, readErr := r.Read(w.data[offset : offset+w.maxPacketSize])
|
||||
if readErr != nil {
|
||||
return 0, readErr
|
||||
}
|
||||
|
@ -184,7 +186,11 @@ func (w *AEADWriter) ReadFrom(r io.Reader) (n int64, err error) {
|
|||
}
|
||||
|
||||
func (w *AEADWriter) Write(p []byte) (n int, err error) {
|
||||
for _, data := range buf.ForeachN(p, MaxPacketSize) {
|
||||
if len(p) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, data := range buf.ForeachN(p, w.maxPacketSize) {
|
||||
binary.BigEndian.PutUint16(w.data[:PacketLengthBufferSize], uint16(len(data)))
|
||||
w.cipher.Seal(w.data[:0], w.nonce, w.data[:PacketLengthBufferSize], nil)
|
||||
increaseNonce(w.nonce)
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/sha1"
|
||||
"github.com/sagernet/sing/common/replay"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
@ -13,6 +12,7 @@ import (
|
|||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/replay"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
|
@ -28,11 +28,19 @@ var List = []string{
|
|||
"xchacha20-ietf-poly1305",
|
||||
}
|
||||
|
||||
func New(method string, secureRNG io.Reader) shadowsocks.Method {
|
||||
var (
|
||||
ErrBadKey = E.New("bad key")
|
||||
ErrMissingPassword = E.New("missing password")
|
||||
)
|
||||
|
||||
func New(method string, key []byte, password []byte, secureRNG io.Reader, replayFilter bool) (shadowsocks.Method, error) {
|
||||
m := &Method{
|
||||
name: method,
|
||||
secureRNG: secureRNG,
|
||||
}
|
||||
if replayFilter {
|
||||
m.replayFilter = replay.NewBloomRing()
|
||||
}
|
||||
switch method {
|
||||
case "aes-128-gcm":
|
||||
m.keySaltLength = 16
|
||||
|
@ -58,15 +66,16 @@ func New(method string, secureRNG io.Reader) shadowsocks.Method {
|
|||
return cipher
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func NewSession(key []byte, replayFilter bool) shadowsocks.Session {
|
||||
var filter replay.Filter
|
||||
if replayFilter {
|
||||
filter = replay.NewBloomRing()
|
||||
if len(key) == m.keySaltLength {
|
||||
m.key = key
|
||||
} else if len(key) > 0 {
|
||||
return nil, ErrBadKey
|
||||
} else if len(password) > 0 {
|
||||
m.key = shadowsocks.Key(password, m.keySaltLength)
|
||||
} else {
|
||||
return nil, ErrMissingPassword
|
||||
}
|
||||
return &session{key, filter}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func Kdf(key, iv []byte, keyLength int) []byte {
|
||||
|
@ -88,7 +97,9 @@ type Method struct {
|
|||
name string
|
||||
keySaltLength int
|
||||
constructor func(key []byte) cipher.AEAD
|
||||
key []byte
|
||||
secureRNG io.Reader
|
||||
replayFilter replay.Filter
|
||||
}
|
||||
|
||||
func (m *Method) Name() string {
|
||||
|
@ -99,29 +110,25 @@ func (m *Method) KeyLength() int {
|
|||
return m.keySaltLength
|
||||
}
|
||||
|
||||
func (m *Method) DialConn(account shadowsocks.Session, conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
|
||||
shadowsocksConn := &aeadConn{
|
||||
Conn: conn,
|
||||
method: m,
|
||||
key: account.Key(),
|
||||
replayFilter: account.ReplayFilter(),
|
||||
destination: destination,
|
||||
func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
|
||||
shadowsocksConn := &clientConn{
|
||||
Conn: conn,
|
||||
method: m,
|
||||
destination: destination,
|
||||
}
|
||||
return shadowsocksConn, shadowsocksConn.clientHandshake()
|
||||
return shadowsocksConn, shadowsocksConn.writeRequest(nil)
|
||||
}
|
||||
|
||||
func (m *Method) DialEarlyConn(account shadowsocks.Session, conn net.Conn, destination *M.AddrPort) net.Conn {
|
||||
return &aeadConn{
|
||||
Conn: conn,
|
||||
method: m,
|
||||
key: account.Key(),
|
||||
replayFilter: account.ReplayFilter(),
|
||||
destination: destination,
|
||||
func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn {
|
||||
return &clientConn{
|
||||
Conn: conn,
|
||||
method: m,
|
||||
destination: destination,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Method) DialPacketConn(account shadowsocks.Session, conn net.Conn) socks.PacketConn {
|
||||
return &aeadPacketConn{conn, account.Key(), m}
|
||||
func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn {
|
||||
return &aeadPacketConn{conn, m}
|
||||
}
|
||||
|
||||
func (m *Method) EncodePacket(key []byte, buffer *buf.Buffer) error {
|
||||
|
@ -145,160 +152,133 @@ func (m *Method) DecodePacket(key []byte, buffer *buf.Buffer) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
type session struct {
|
||||
key []byte
|
||||
replayFilter replay.Filter
|
||||
}
|
||||
|
||||
func (a *session) Key() []byte {
|
||||
return a.key
|
||||
}
|
||||
|
||||
func (a *session) ReplayFilter() replay.Filter {
|
||||
return a.replayFilter
|
||||
}
|
||||
|
||||
type aeadConn struct {
|
||||
type clientConn struct {
|
||||
net.Conn
|
||||
|
||||
method *Method
|
||||
key []byte
|
||||
destination *M.AddrPort
|
||||
|
||||
access sync.Mutex
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
replayFilter replay.Filter
|
||||
access sync.Mutex
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
}
|
||||
|
||||
func (c *aeadConn) clientHandshake() error {
|
||||
header := buf.New()
|
||||
defer header.Release()
|
||||
func (c *clientConn) writeRequest(payload []byte) error {
|
||||
request := buf.New()
|
||||
defer request.Release()
|
||||
|
||||
common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
|
||||
if c.replayFilter != nil {
|
||||
c.replayFilter.Check(header.Bytes())
|
||||
common.Must1(request.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
|
||||
if c.method.replayFilter != nil {
|
||||
c.method.replayFilter.Check(request.Bytes())
|
||||
}
|
||||
|
||||
c.writer = NewAEADWriter(
|
||||
&buf.BufferedWriter{
|
||||
Writer: c.Conn,
|
||||
Buffer: header,
|
||||
},
|
||||
c.method.constructor(Kdf(c.key, header.Bytes(), c.method.keySaltLength)),
|
||||
var writer io.Writer = c.Conn
|
||||
writer = &buf.BufferedWriter{
|
||||
Writer: writer,
|
||||
Buffer: request,
|
||||
}
|
||||
writer = NewWriter(
|
||||
writer,
|
||||
c.method.constructor(Kdf(c.method.key, request.Bytes(), c.method.keySaltLength)),
|
||||
MaxPacketSize,
|
||||
)
|
||||
|
||||
err := socks.AddressSerializer.WriteAddrPort(c.writer, c.destination)
|
||||
if len(payload) > 0 {
|
||||
header := buf.New()
|
||||
defer header.Release()
|
||||
|
||||
writer = &buf.BufferedWriter{
|
||||
Writer: writer,
|
||||
Buffer: header,
|
||||
}
|
||||
|
||||
err := socks.AddressSerializer.WriteAddrPort(writer, c.destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = writer.Write(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
err := socks.AddressSerializer.WriteAddrPort(writer, c.destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err := common.FlushVar(&writer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return common.FlushVar(&c.writer)
|
||||
c.writer = writer
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *aeadConn) serverHandshake() error {
|
||||
func (c *clientConn) readResponse() error {
|
||||
if c.reader == nil {
|
||||
salt := make([]byte, c.method.keySaltLength)
|
||||
_, err := io.ReadFull(c.Conn, salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.replayFilter != nil {
|
||||
if !c.replayFilter.Check(salt) {
|
||||
if c.method.replayFilter != nil {
|
||||
if !c.method.replayFilter.Check(salt) {
|
||||
return E.New("salt is not unique")
|
||||
}
|
||||
}
|
||||
c.reader = NewReader(c.Conn, c.method.constructor(Kdf(c.key, salt, c.method.keySaltLength)))
|
||||
c.reader = NewReader(
|
||||
c.Conn,
|
||||
c.method.constructor(Kdf(c.method.key, salt, c.method.keySaltLength)),
|
||||
MaxPacketSize,
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *aeadConn) Read(p []byte) (n int, err error) {
|
||||
if err = c.serverHandshake(); err != nil {
|
||||
func (c *clientConn) Read(p []byte) (n int, err error) {
|
||||
if err = c.readResponse(); err != nil {
|
||||
return
|
||||
}
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *aeadConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
if err = c.serverHandshake(); err != nil {
|
||||
func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
if err = c.readResponse(); err != nil {
|
||||
return
|
||||
}
|
||||
return c.reader.(io.WriterTo).WriteTo(w)
|
||||
}
|
||||
|
||||
func (c *aeadConn) Write(p []byte) (n int, err error) {
|
||||
func (c *clientConn) Write(p []byte) (n int, err error) {
|
||||
if c.writer != nil {
|
||||
goto direct
|
||||
return c.writer.Write(p)
|
||||
}
|
||||
|
||||
c.access.Lock()
|
||||
defer c.access.Unlock()
|
||||
|
||||
if c.writer != nil {
|
||||
goto direct
|
||||
c.access.Unlock()
|
||||
return c.writer.Write(p)
|
||||
}
|
||||
|
||||
// client handshake
|
||||
|
||||
{
|
||||
header := buf.New()
|
||||
defer header.Release()
|
||||
|
||||
request := buf.New()
|
||||
defer request.Release()
|
||||
|
||||
common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
|
||||
if c.replayFilter != nil {
|
||||
c.replayFilter.Check(header.Bytes())
|
||||
}
|
||||
|
||||
var writer io.Writer = c.Conn
|
||||
writer = &buf.BufferedWriter{
|
||||
Writer: writer,
|
||||
Buffer: header,
|
||||
}
|
||||
writer = NewAEADWriter(writer, c.method.constructor(Kdf(c.key, header.Bytes(), c.method.keySaltLength)))
|
||||
writer = &buf.BufferedWriter{
|
||||
Writer: writer,
|
||||
Buffer: request,
|
||||
}
|
||||
|
||||
err = socks.AddressSerializer.WriteAddrPort(writer, c.destination)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if len(p) > 0 {
|
||||
_, err = writer.Write(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = common.FlushVar(&writer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.writer = writer
|
||||
return len(p), nil
|
||||
err = c.writeRequest(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
direct:
|
||||
return c.writer.Write(p)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (c *aeadConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if c.writer == nil {
|
||||
panic("missing client handshake")
|
||||
panic("missing handshake")
|
||||
}
|
||||
return c.writer.(io.ReaderFrom).ReadFrom(r)
|
||||
}
|
||||
|
||||
type aeadPacketConn struct {
|
||||
net.Conn
|
||||
key []byte
|
||||
method *Method
|
||||
}
|
||||
|
||||
|
@ -311,7 +291,7 @@ func (c *aeadPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort
|
|||
return err
|
||||
}
|
||||
buffer = buffer.WriteBufferAtFirst(header)
|
||||
err = c.method.EncodePacket(c.key, buffer)
|
||||
err = c.method.EncodePacket(c.method.key, buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -324,7 +304,7 @@ func (c *aeadPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
|||
return nil, err
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
err = c.method.DecodePacket(c.key, buffer)
|
||||
err = c.method.DecodePacket(c.method.key, buffer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
519
protocol/shadowsocks/shadowaead_2022/protocol.go
Normal file
519
protocol/shadowsocks/shadowaead_2022/protocol.go
Normal file
|
@ -0,0 +1,519 @@
|
|||
package shadowaead_2022
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/replay"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
wgReplay "golang.zx2c4.com/wireguard/replay"
|
||||
"lukechampine.com/blake3"
|
||||
)
|
||||
|
||||
const (
|
||||
HeaderTypeClient = 0
|
||||
HeaderTypeServer = 1
|
||||
MaxPaddingLength = 900
|
||||
KeySaltSize = 32
|
||||
PacketNonceSize = 24
|
||||
MinRequestHeaderSize = 1 + 8
|
||||
MinResponseHeaderSize = MinRequestHeaderSize + KeySaltSize
|
||||
MaxPacketSize = 64 * 1024
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBadHeaderType = E.New("bad header type")
|
||||
ErrBadTimestamp = E.New("bad timestamp")
|
||||
ErrBadRequestSalt = E.New("bad request salt")
|
||||
ErrBadClientSessionId = E.New("bad client session id")
|
||||
ErrPacketIdNotUnique = E.New("packet id not unique")
|
||||
)
|
||||
|
||||
var List = []string{
|
||||
"2022-blake3-aes-128-gcm",
|
||||
"2022-blake3-aes-256-gcm",
|
||||
"2022-blake3-chacha20-poly1305",
|
||||
}
|
||||
|
||||
func New(method string, psk []byte, secureRNG io.Reader) (shadowsocks.Method, error) {
|
||||
m := &Method{
|
||||
name: method,
|
||||
key: psk,
|
||||
secureRNG: secureRNG,
|
||||
replayFilter: replay.NewCuckoo(30),
|
||||
}
|
||||
|
||||
if len(psk) != KeySaltSize {
|
||||
return nil, shadowaead.ErrBadKey
|
||||
}
|
||||
|
||||
switch method {
|
||||
case "2022-blake3-aes-128-gcm":
|
||||
m.keyLength = 16
|
||||
m.constructor = newAESGCM
|
||||
m.udpBlockConstructor = newAES
|
||||
case "2022-blake3-aes-256-gcm":
|
||||
m.keyLength = 32
|
||||
m.constructor = newAESGCM
|
||||
m.udpBlockConstructor = newAES
|
||||
case "2022-blake3-chacha20-poly1305":
|
||||
m.keyLength = 32
|
||||
m.constructor = func(key []byte) cipher.AEAD {
|
||||
cipher, err := chacha20poly1305.New(key)
|
||||
common.Must(err)
|
||||
return cipher
|
||||
}
|
||||
m.udpConstructor = func(key []byte) cipher.AEAD {
|
||||
cipher, err := chacha20poly1305.NewX(key)
|
||||
common.Must(err)
|
||||
return cipher
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func Blake3DeriveKey(secret, salt, outKey []byte) {
|
||||
sessionKey := make([]byte, len(secret)+len(salt))
|
||||
copy(sessionKey, secret)
|
||||
copy(sessionKey[len(secret):], salt)
|
||||
blake3.DeriveKey(outKey, "shadowsocks 2022 session subkey", sessionKey)
|
||||
}
|
||||
|
||||
func newAES(key []byte) cipher.Block {
|
||||
block, err := aes.NewCipher(key)
|
||||
common.Must(err)
|
||||
return block
|
||||
}
|
||||
|
||||
func newAESGCM(key []byte) cipher.AEAD {
|
||||
block, err := aes.NewCipher(key)
|
||||
common.Must(err)
|
||||
aead, err := cipher.NewGCM(block)
|
||||
common.Must(err)
|
||||
return aead
|
||||
}
|
||||
|
||||
type Method struct {
|
||||
name string
|
||||
keyLength int
|
||||
constructor func(key []byte) cipher.AEAD
|
||||
udpBlockConstructor func(key []byte) cipher.Block
|
||||
udpConstructor func(key []byte) cipher.AEAD
|
||||
key []byte
|
||||
secureRNG io.Reader
|
||||
replayFilter replay.Filter
|
||||
}
|
||||
|
||||
func (m *Method) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *Method) KeyLength() int {
|
||||
return m.keyLength
|
||||
}
|
||||
|
||||
func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
|
||||
shadowsocksConn := &clientConn{
|
||||
Conn: conn,
|
||||
method: m,
|
||||
destination: destination,
|
||||
}
|
||||
return shadowsocksConn, shadowsocksConn.writeRequest(nil)
|
||||
}
|
||||
|
||||
func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn {
|
||||
return &clientConn{
|
||||
Conn: conn,
|
||||
method: m,
|
||||
destination: destination,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn {
|
||||
return &clientPacketConn{conn, m, newUDPSession()}
|
||||
}
|
||||
|
||||
func (m *Method) EncodePacket(buffer *buf.Buffer) error {
|
||||
if m.udpConstructor == nil {
|
||||
// aes
|
||||
packetHeader := buffer.To(aes.BlockSize)
|
||||
subKey := make([]byte, m.keyLength)
|
||||
Blake3DeriveKey(m.key, packetHeader[:8], subKey)
|
||||
|
||||
cipher := m.constructor(subKey)
|
||||
cipher.Seal(buffer.Index(aes.BlockSize), packetHeader[4:16], buffer.From(aes.BlockSize), nil)
|
||||
buffer.Extend(cipher.Overhead())
|
||||
m.udpBlockConstructor(m.key).Encrypt(packetHeader, packetHeader)
|
||||
} else {
|
||||
// xchacha
|
||||
cipher := m.udpConstructor(m.key)
|
||||
cipher.Seal(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
|
||||
buffer.Extend(cipher.Overhead())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Method) DecodePacket(buffer *buf.Buffer) error {
|
||||
if m.udpBlockConstructor != nil {
|
||||
if buffer.Len() <= aes.BlockSize {
|
||||
return E.New("insufficient data: ", buffer.Len())
|
||||
}
|
||||
packetHeader := buffer.To(aes.BlockSize)
|
||||
m.udpBlockConstructor(m.key).Decrypt(packetHeader, packetHeader)
|
||||
subKey := make([]byte, m.keyLength)
|
||||
Blake3DeriveKey(m.key, packetHeader[:8], subKey)
|
||||
_, err := m.constructor(subKey).Open(buffer.Index(aes.BlockSize), packetHeader[4:16], buffer.From(aes.BlockSize), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
_, err := m.udpConstructor(m.key).Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer.Advance(PacketNonceSize)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type clientConn struct {
|
||||
net.Conn
|
||||
|
||||
method *Method
|
||||
destination *M.AddrPort
|
||||
|
||||
request sync.Mutex
|
||||
response sync.Mutex
|
||||
|
||||
requestSalt []byte
|
||||
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
}
|
||||
|
||||
func (c *clientConn) writeRequest(payload []byte) error {
|
||||
request := buf.New()
|
||||
defer request.Release()
|
||||
|
||||
salt := make([]byte, KeySaltSize)
|
||||
common.Must1(io.ReadFull(c.method.secureRNG, salt))
|
||||
c.method.replayFilter.Check(salt)
|
||||
common.Must1(request.Write(salt))
|
||||
|
||||
subKey := make([]byte, c.method.keyLength)
|
||||
Blake3DeriveKey(c.method.key, salt, subKey)
|
||||
|
||||
var writer io.Writer = c.Conn
|
||||
writer = &buf.BufferedWriter{
|
||||
Writer: writer,
|
||||
Buffer: request,
|
||||
}
|
||||
writer = shadowaead.NewWriter(
|
||||
writer,
|
||||
c.method.constructor(subKey),
|
||||
MaxPacketSize,
|
||||
)
|
||||
|
||||
header := buf.New()
|
||||
defer header.Release()
|
||||
|
||||
writer = &buf.BufferedWriter{
|
||||
Writer: writer,
|
||||
Buffer: header,
|
||||
}
|
||||
|
||||
common.Must(rw.WriteByte(writer, HeaderTypeClient))
|
||||
common.Must(binary.Write(writer, binary.BigEndian, uint64(time.Now().Unix())))
|
||||
|
||||
err := socks.AddressSerializer.WriteAddrPort(writer, c.destination)
|
||||
if err != nil {
|
||||
return E.Cause(err, "write destination")
|
||||
}
|
||||
|
||||
if len(payload) > 0 {
|
||||
err = binary.Write(writer, binary.BigEndian, uint16(0))
|
||||
if err != nil {
|
||||
return E.Cause(err, "write padding length")
|
||||
}
|
||||
_, err = writer.Write(payload)
|
||||
if err != nil {
|
||||
return E.Cause(err, "write payload")
|
||||
}
|
||||
} else {
|
||||
pLen := rand.Intn(MaxPaddingLength)
|
||||
err = binary.Write(writer, binary.BigEndian, uint16(pLen))
|
||||
if err != nil {
|
||||
return E.Cause(err, "write padding length")
|
||||
}
|
||||
_, err = io.CopyN(writer, c.method.secureRNG, int64(pLen))
|
||||
if err != nil {
|
||||
return E.Cause(err, "write padding")
|
||||
}
|
||||
}
|
||||
|
||||
err = common.FlushVar(&writer)
|
||||
if err != nil {
|
||||
return E.Cause(err, "client handshake")
|
||||
}
|
||||
c.requestSalt = salt
|
||||
c.writer = writer
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientConn) readResponse() error {
|
||||
if c.reader != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.response.Lock()
|
||||
defer c.response.Unlock()
|
||||
|
||||
if c.reader != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
salt := make([]byte, KeySaltSize)
|
||||
_, err := io.ReadFull(c.Conn, salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !c.method.replayFilter.Check(salt) {
|
||||
return E.New("salt is not unique")
|
||||
}
|
||||
|
||||
subKey := make([]byte, c.method.keyLength)
|
||||
Blake3DeriveKey(c.method.key, salt, subKey)
|
||||
|
||||
reader := shadowaead.NewReader(
|
||||
c.Conn,
|
||||
c.method.constructor(subKey),
|
||||
MaxPacketSize,
|
||||
)
|
||||
|
||||
headerType, err := rw.ReadByte(reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if headerType != HeaderTypeServer {
|
||||
return ErrBadHeaderType
|
||||
}
|
||||
|
||||
var epoch uint64
|
||||
err = binary.Read(reader, binary.BigEndian, &epoch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if math.Abs(float64(time.Now().Unix()-int64(epoch))) > 30 {
|
||||
return ErrBadTimestamp
|
||||
}
|
||||
|
||||
requestSalt := make([]byte, KeySaltSize)
|
||||
_, err = io.ReadFull(reader, requestSalt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if bytes.Compare(requestSalt, c.requestSalt) > 0 {
|
||||
return ErrBadRequestSalt
|
||||
}
|
||||
|
||||
c.reader = reader
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientConn) Read(p []byte) (n int, err error) {
|
||||
if err = c.readResponse(); err != nil {
|
||||
return
|
||||
}
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
if err = c.readResponse(); err != nil {
|
||||
return
|
||||
}
|
||||
return c.reader.(io.WriterTo).WriteTo(w)
|
||||
}
|
||||
|
||||
func (c *clientConn) Write(p []byte) (n int, err error) {
|
||||
if c.writer != nil {
|
||||
return c.writer.Write(p)
|
||||
}
|
||||
|
||||
c.request.Lock()
|
||||
|
||||
if c.writer != nil {
|
||||
c.request.Unlock()
|
||||
return c.writer.Write(p)
|
||||
}
|
||||
|
||||
defer c.request.Unlock()
|
||||
|
||||
err = c.writeRequest(p)
|
||||
if err == nil {
|
||||
n = len(p)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if c.writer == nil {
|
||||
panic("missing client handshake")
|
||||
}
|
||||
|
||||
return c.writer.(io.ReaderFrom).ReadFrom(r)
|
||||
}
|
||||
|
||||
type clientPacketConn struct {
|
||||
net.Conn
|
||||
method *Method
|
||||
session *udpSession
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
defer buffer.Release()
|
||||
header := buf.New()
|
||||
if c.method.udpConstructor != nil {
|
||||
common.Must1(header.ReadFullFrom(c.method.secureRNG, PacketNonceSize))
|
||||
}
|
||||
common.Must(
|
||||
binary.Write(header, binary.BigEndian, c.session.sessionId),
|
||||
binary.Write(header, binary.BigEndian, c.session.nextPacketId()),
|
||||
header.WriteByte(HeaderTypeClient),
|
||||
binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())),
|
||||
binary.Write(header, binary.BigEndian, uint16(0)), // padding length
|
||||
)
|
||||
c.session.filter.ValidateCounter(c.session.packetId, math.MaxUint64)
|
||||
err := socks.AddressSerializer.WriteAddrPort(header, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer = buffer.WriteBufferAtFirst(header)
|
||||
err = c.method.EncodePacket(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return common.Error(c.Write(buffer.Bytes()))
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||
n, err := c.Read(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
|
||||
err = c.method.DecodePacket(buffer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var sessionId uint64
|
||||
err = binary.Read(buffer, binary.BigEndian, &sessionId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var isLastSessionId bool
|
||||
if c.session.remoteSessionId == 0 {
|
||||
c.session.remoteSessionId = sessionId
|
||||
} else if sessionId != c.session.remoteSessionId {
|
||||
if sessionId == c.session.lastRemoteSessionId {
|
||||
isLastSessionId = true
|
||||
} else {
|
||||
c.session.lastRemoteSessionId = c.session.remoteSessionId
|
||||
c.session.remoteSessionId = sessionId
|
||||
c.session.lastFilter = c.session.filter
|
||||
c.session.filter = new(wgReplay.Filter)
|
||||
}
|
||||
}
|
||||
|
||||
var packetId uint64
|
||||
err = binary.Read(buffer, binary.BigEndian, &packetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !isLastSessionId {
|
||||
if !c.session.filter.ValidateCounter(packetId, math.MaxUint64) {
|
||||
return nil, ErrPacketIdNotUnique
|
||||
}
|
||||
} else {
|
||||
if !c.session.lastFilter.ValidateCounter(packetId, math.MaxUint64) {
|
||||
return nil, ErrPacketIdNotUnique
|
||||
}
|
||||
}
|
||||
|
||||
headerType, err := buffer.ReadBytes(1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if headerType[0] != HeaderTypeServer {
|
||||
return nil, ErrBadHeaderType
|
||||
}
|
||||
|
||||
var epoch uint64
|
||||
err = binary.Read(buffer, binary.BigEndian, &epoch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if math.Abs(float64(uint64(time.Now().Unix())-epoch)) > 30 {
|
||||
return nil, ErrBadTimestamp
|
||||
}
|
||||
|
||||
var clientSessionId uint64
|
||||
err = binary.Read(buffer, binary.BigEndian, &clientSessionId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if clientSessionId != c.session.sessionId {
|
||||
return nil, ErrBadClientSessionId
|
||||
}
|
||||
|
||||
var paddingLength uint16
|
||||
err = binary.Read(buffer, binary.BigEndian, &paddingLength)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read padding length")
|
||||
}
|
||||
buffer.Advance(int(paddingLength))
|
||||
return socks.AddressSerializer.ReadAddrPort(buffer)
|
||||
}
|
||||
|
||||
type udpSession struct {
|
||||
headerType byte
|
||||
sessionId uint64
|
||||
packetId uint64
|
||||
remoteSessionId uint64
|
||||
lastRemoteSessionId uint64
|
||||
filter *wgReplay.Filter
|
||||
lastFilter *wgReplay.Filter
|
||||
}
|
||||
|
||||
func (s *udpSession) nextPacketId() uint64 {
|
||||
return atomic.AddUint64(&s.packetId, 1)
|
||||
}
|
||||
|
||||
func newUDPSession() *udpSession {
|
||||
return &udpSession{
|
||||
sessionId: rand.Uint64(),
|
||||
filter: new(wgReplay.Filter),
|
||||
}
|
||||
}
|
|
@ -18,6 +18,7 @@ type Handler interface {
|
|||
}
|
||||
|
||||
type Listener struct {
|
||||
bindAddr netip.Addr
|
||||
tcpListener *tcp.Listener
|
||||
authenticator auth.Authenticator
|
||||
handler Handler
|
||||
|
@ -25,6 +26,7 @@ type Listener struct {
|
|||
|
||||
func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, handler Handler) *Listener {
|
||||
listener := &Listener{
|
||||
bindAddr: bind.Addr(),
|
||||
handler: handler,
|
||||
authenticator: authenticator,
|
||||
}
|
||||
|
@ -33,7 +35,7 @@ func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, handler
|
|||
}
|
||||
|
||||
func (l *Listener) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||
return HandleConnection(conn, l.authenticator, l.handler)
|
||||
return HandleConnection(conn, l.bindAddr, l.authenticator, l.handler)
|
||||
}
|
||||
|
||||
func (l *Listener) Start() error {
|
||||
|
@ -48,7 +50,7 @@ func (l *Listener) HandleError(err error) {
|
|||
l.handler.HandleError(err)
|
||||
}
|
||||
|
||||
func HandleConnection(conn net.Conn, authenticator auth.Authenticator, handler Handler) error {
|
||||
func HandleConnection(conn net.Conn, bind netip.Addr, authenticator auth.Authenticator, handler Handler) error {
|
||||
authRequest, err := ReadAuthRequest(conn)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read socks auth request")
|
||||
|
@ -111,7 +113,11 @@ func HandleConnection(conn net.Conn, authenticator auth.Authenticator, handler H
|
|||
Destination: request.Destination,
|
||||
})
|
||||
case CommandUDPAssociate:
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
network := "udp"
|
||||
if bind.Is4() {
|
||||
network = "udp4"
|
||||
}
|
||||
udpConn, err := net.ListenUDP(network, net.UDPAddrFromAddrPort(netip.AddrPortFrom(bind, 0)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue