mirror of
https://github.com/SagerNet/sing-shadowsocks.git
synced 2025-04-03 20:07:40 +03:00
Init commit
This commit is contained in:
commit
48809b0a99
20 changed files with 3828 additions and 0 deletions
421
shadowaead/aead.go
Normal file
421
shadowaead/aead.go
Normal file
|
@ -0,0 +1,421 @@
|
|||
package shadowaead
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
// https://shadowsocks.org/en/wiki/AEAD-Ciphers.html
|
||||
const (
|
||||
MaxPacketSize = 16*1024 - 1
|
||||
PacketLengthBufferSize = 2
|
||||
)
|
||||
|
||||
const (
|
||||
// NonceSize
|
||||
// crypto/cipher.gcmStandardNonceSize
|
||||
// golang.org/x/crypto/chacha20poly1305.NonceSize
|
||||
NonceSize = 12
|
||||
|
||||
// Overhead
|
||||
// crypto/cipher.gcmTagSize
|
||||
// golang.org/x/crypto/chacha20poly1305.Overhead
|
||||
Overhead = 16
|
||||
)
|
||||
|
||||
type Reader struct {
|
||||
upstream io.Reader
|
||||
cipher cipher.AEAD
|
||||
buffer []byte
|
||||
nonce []byte
|
||||
index int
|
||||
cached int
|
||||
}
|
||||
|
||||
func NewReader(upstream io.Reader, cipher cipher.AEAD, maxPacketSize int) *Reader {
|
||||
return &Reader{
|
||||
upstream: upstream,
|
||||
cipher: cipher,
|
||||
buffer: make([]byte, maxPacketSize+PacketLengthBufferSize+Overhead*2),
|
||||
nonce: make([]byte, NonceSize),
|
||||
}
|
||||
}
|
||||
|
||||
func NewRawReader(upstream io.Reader, cipher cipher.AEAD, buffer []byte, nonce []byte) *Reader {
|
||||
return &Reader{
|
||||
upstream: upstream,
|
||||
cipher: cipher,
|
||||
buffer: buffer,
|
||||
nonce: nonce,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Reader) Upstream() any {
|
||||
return r.upstream
|
||||
}
|
||||
|
||||
func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) {
|
||||
if r.cached > 0 {
|
||||
writeN, writeErr := writer.Write(r.buffer[r.index : r.index+r.cached])
|
||||
if writeErr != nil {
|
||||
return int64(writeN), writeErr
|
||||
}
|
||||
n += int64(writeN)
|
||||
}
|
||||
for {
|
||||
start := PacketLengthBufferSize + Overhead
|
||||
_, err = io.ReadFull(r.upstream, r.buffer[:start])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
increaseNonce(r.nonce)
|
||||
length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
|
||||
end := length + Overhead
|
||||
_, err = io.ReadFull(r.upstream, r.buffer[:end])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
increaseNonce(r.nonce)
|
||||
writeN, writeErr := writer.Write(r.buffer[:length])
|
||||
if writeErr != nil {
|
||||
return int64(writeN), writeErr
|
||||
}
|
||||
n += int64(writeN)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Reader) readInternal() (err error) {
|
||||
start := PacketLengthBufferSize + Overhead
|
||||
_, err = io.ReadFull(r.upstream, r.buffer[:start])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
increaseNonce(r.nonce)
|
||||
length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
|
||||
end := length + Overhead
|
||||
_, err = io.ReadFull(r.upstream, r.buffer[:end])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
increaseNonce(r.nonce)
|
||||
r.cached = length
|
||||
r.index = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Reader) ReadByte() (byte, error) {
|
||||
if r.cached == 0 {
|
||||
err := r.readInternal()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
index := r.index
|
||||
r.index++
|
||||
r.cached--
|
||||
return r.buffer[index], nil
|
||||
}
|
||||
|
||||
func (r *Reader) Read(b []byte) (n int, err error) {
|
||||
if r.cached > 0 {
|
||||
n = copy(b, r.buffer[r.index:r.index+r.cached])
|
||||
r.cached -= n
|
||||
r.index += n
|
||||
return
|
||||
}
|
||||
start := PacketLengthBufferSize + Overhead
|
||||
_, err = io.ReadFull(r.upstream, r.buffer[:start])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
increaseNonce(r.nonce)
|
||||
length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
|
||||
end := length + Overhead
|
||||
|
||||
if len(b) >= end {
|
||||
data := b[:end]
|
||||
_, err = io.ReadFull(r.upstream, data)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = r.cipher.Open(b[:0], r.nonce, data, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
increaseNonce(r.nonce)
|
||||
return length, nil
|
||||
} else {
|
||||
_, err = io.ReadFull(r.upstream, r.buffer[:end])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
increaseNonce(r.nonce)
|
||||
n = copy(b, r.buffer[:length])
|
||||
r.cached = length - n
|
||||
r.index = n
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Reader) Discard(n int) error {
|
||||
for {
|
||||
if r.cached >= n {
|
||||
r.cached -= n
|
||||
r.index += n
|
||||
return nil
|
||||
} else if r.cached > 0 {
|
||||
n -= r.cached
|
||||
r.cached = 0
|
||||
r.index = 0
|
||||
}
|
||||
err := r.readInternal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Reader) Cached() int {
|
||||
return r.cached
|
||||
}
|
||||
|
||||
func (r *Reader) CachedSlice() []byte {
|
||||
return r.buffer[r.index : r.index+r.cached]
|
||||
}
|
||||
|
||||
func (r *Reader) ReadWithLengthChunk(lengthChunk []byte) error {
|
||||
_, err := r.cipher.Open(r.buffer[:0], r.nonce, lengthChunk, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
increaseNonce(r.nonce)
|
||||
length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
|
||||
end := length + Overhead
|
||||
_, err = io.ReadFull(r.upstream, r.buffer[:end])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
increaseNonce(r.nonce)
|
||||
r.cached = length
|
||||
r.index = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Reader) ReadWithLength(length uint16) error {
|
||||
end := length + Overhead
|
||||
_, err := io.ReadFull(r.upstream, r.buffer[:end])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
increaseNonce(r.nonce)
|
||||
r.cached = int(length)
|
||||
r.index = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Reader) ReadChunk(chunk []byte) error {
|
||||
bb, err := r.cipher.Open(r.buffer[:0], r.nonce, chunk, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
increaseNonce(r.nonce)
|
||||
r.cached = len(bb)
|
||||
r.index = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
type Writer struct {
|
||||
upstream io.Writer
|
||||
cipher cipher.AEAD
|
||||
maxPacketSize int
|
||||
buffer []byte
|
||||
nonce []byte
|
||||
}
|
||||
|
||||
func NewWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int) *Writer {
|
||||
return &Writer{
|
||||
upstream: upstream,
|
||||
cipher: cipher,
|
||||
buffer: make([]byte, maxPacketSize+PacketLengthBufferSize+Overhead*2),
|
||||
nonce: make([]byte, cipher.NonceSize()),
|
||||
maxPacketSize: maxPacketSize,
|
||||
}
|
||||
}
|
||||
|
||||
func NewRawWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int, buffer []byte, nonce []byte) *Writer {
|
||||
return &Writer{
|
||||
upstream: upstream,
|
||||
cipher: cipher,
|
||||
maxPacketSize: maxPacketSize,
|
||||
buffer: buffer,
|
||||
nonce: nonce,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Writer) Upstream() any {
|
||||
return w.upstream
|
||||
}
|
||||
|
||||
func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
for {
|
||||
offset := Overhead + PacketLengthBufferSize
|
||||
readN, readErr := r.Read(w.buffer[offset : offset+w.maxPacketSize])
|
||||
if readErr != nil {
|
||||
return 0, readErr
|
||||
}
|
||||
binary.BigEndian.PutUint16(w.buffer[:PacketLengthBufferSize], uint16(readN))
|
||||
w.cipher.Seal(w.buffer[:0], w.nonce, w.buffer[:PacketLengthBufferSize], nil)
|
||||
increaseNonce(w.nonce)
|
||||
packet := w.cipher.Seal(w.buffer[offset:offset], w.nonce, w.buffer[offset:offset+readN], nil)
|
||||
increaseNonce(w.nonce)
|
||||
_, err = w.upstream.Write(w.buffer[:offset+len(packet)])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += int64(readN)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Writer) Write(p []byte) (n int, err error) {
|
||||
if len(p) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for pLen := len(p); pLen > 0; {
|
||||
var data []byte
|
||||
if pLen > w.maxPacketSize {
|
||||
data = p[:w.maxPacketSize]
|
||||
p = p[w.maxPacketSize:]
|
||||
} else {
|
||||
data = p
|
||||
p = nil
|
||||
}
|
||||
binary.BigEndian.PutUint16(w.buffer[:PacketLengthBufferSize], uint16(len(data)))
|
||||
w.cipher.Seal(w.buffer[:0], w.nonce, w.buffer[:PacketLengthBufferSize], nil)
|
||||
increaseNonce(w.nonce)
|
||||
offset := Overhead + PacketLengthBufferSize
|
||||
packet := w.cipher.Seal(w.buffer[offset:offset], w.nonce, data, nil)
|
||||
increaseNonce(w.nonce)
|
||||
_, err = w.upstream.Write(w.buffer[:offset+len(packet)])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += len(data)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (w *Writer) Buffer() *buf.Buffer {
|
||||
return buf.With(w.buffer)
|
||||
}
|
||||
|
||||
func (w *Writer) WriteChunk(buffer *buf.Buffer, chunk []byte) {
|
||||
bb := w.cipher.Seal(buffer.Index(buffer.Len()), w.nonce, chunk, nil)
|
||||
buffer.Extend(len(bb))
|
||||
increaseNonce(w.nonce)
|
||||
}
|
||||
|
||||
func (w *Writer) BufferedWriter(reversed int) *BufferedWriter {
|
||||
return &BufferedWriter{
|
||||
upstream: w,
|
||||
reversed: reversed,
|
||||
data: w.buffer[PacketLengthBufferSize+Overhead : len(w.buffer)-Overhead],
|
||||
}
|
||||
}
|
||||
|
||||
type BufferedWriter struct {
|
||||
upstream *Writer
|
||||
data []byte
|
||||
reversed int
|
||||
index int
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) UpstreamWriter() io.Writer {
|
||||
return w.upstream
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) WriterReplaceable() bool {
|
||||
return w.index == 0
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Write(p []byte) (n int, err error) {
|
||||
var index int
|
||||
for {
|
||||
cachedN := copy(w.data[w.reversed+w.index:], p[index:])
|
||||
if cachedN == len(p[index:]) {
|
||||
w.index += cachedN
|
||||
return cachedN, nil
|
||||
}
|
||||
err = w.Flush()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
index += cachedN
|
||||
}
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Flush() error {
|
||||
if w.index == 0 {
|
||||
if w.reversed > 0 {
|
||||
_, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed])
|
||||
w.reversed = 0
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
buffer := w.upstream.buffer[w.reversed:]
|
||||
binary.BigEndian.PutUint16(buffer[:PacketLengthBufferSize], uint16(w.index))
|
||||
w.upstream.cipher.Seal(buffer[:0], w.upstream.nonce, buffer[:PacketLengthBufferSize], nil)
|
||||
increaseNonce(w.upstream.nonce)
|
||||
offset := Overhead + PacketLengthBufferSize
|
||||
packet := w.upstream.cipher.Seal(buffer[offset:offset], w.upstream.nonce, buffer[offset:offset+w.index], nil)
|
||||
increaseNonce(w.upstream.nonce)
|
||||
_, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed+offset+len(packet)])
|
||||
w.reversed = 0
|
||||
return err
|
||||
}
|
||||
|
||||
func increaseNonce(nonce []byte) {
|
||||
for i := range nonce {
|
||||
nonce[i]++
|
||||
if nonce[i] != 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
361
shadowaead/protocol.go
Normal file
361
shadowaead/protocol.go
Normal file
|
@ -0,0 +1,361 @@
|
|||
package shadowaead
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
|
||||
"github.com/sagernet/sing-shadowsocks"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
var List = []string{
|
||||
"aes-128-gcm",
|
||||
"aes-192-gcm",
|
||||
"aes-256-gcm",
|
||||
"chacha20-ietf-poly1305",
|
||||
"xchacha20-ietf-poly1305",
|
||||
}
|
||||
|
||||
func New(method string, key []byte, password string) (shadowsocks.Method, error) {
|
||||
m := &Method{
|
||||
name: method,
|
||||
}
|
||||
switch method {
|
||||
case "aes-128-gcm":
|
||||
m.keySaltLength = 16
|
||||
m.constructor = newAESGCM
|
||||
case "aes-192-gcm":
|
||||
m.keySaltLength = 24
|
||||
m.constructor = newAESGCM
|
||||
case "aes-256-gcm":
|
||||
m.keySaltLength = 32
|
||||
m.constructor = newAESGCM
|
||||
case "chacha20-ietf-poly1305":
|
||||
m.keySaltLength = 32
|
||||
m.constructor = func(key []byte) cipher.AEAD {
|
||||
cipher, err := chacha20poly1305.New(key)
|
||||
common.Must(err)
|
||||
return cipher
|
||||
}
|
||||
case "xchacha20-ietf-poly1305":
|
||||
m.keySaltLength = 32
|
||||
m.constructor = func(key []byte) cipher.AEAD {
|
||||
cipher, err := chacha20poly1305.NewX(key)
|
||||
common.Must(err)
|
||||
return cipher
|
||||
}
|
||||
}
|
||||
if len(key) == m.keySaltLength {
|
||||
m.key = key
|
||||
} else if len(key) > 0 {
|
||||
return nil, shadowsocks.ErrBadKey
|
||||
} else if password == "" {
|
||||
return nil, shadowsocks.ErrMissingPassword
|
||||
} else {
|
||||
m.key = shadowsocks.Key([]byte(password), m.keySaltLength)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func Kdf(key, iv []byte, keyLength int) []byte {
|
||||
info := []byte("ss-subkey")
|
||||
subKey := buf.Make(keyLength)
|
||||
kdf := hkdf.New(sha1.New, key, iv, common.Dup(info))
|
||||
runtime.KeepAlive(info)
|
||||
common.Must1(io.ReadFull(kdf, common.Dup(subKey)))
|
||||
return subKey
|
||||
}
|
||||
|
||||
func newAESGCM(key []byte) cipher.AEAD {
|
||||
block, err := aes.NewCipher(key)
|
||||
common.Must(err)
|
||||
aead, err := cipher.NewGCM(block)
|
||||
common.Must(err)
|
||||
return aead
|
||||
}
|
||||
|
||||
type Method struct {
|
||||
name string
|
||||
keySaltLength int
|
||||
constructor func(key []byte) cipher.AEAD
|
||||
key []byte
|
||||
}
|
||||
|
||||
func (m *Method) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *Method) KeyLength() int {
|
||||
return m.keySaltLength
|
||||
}
|
||||
|
||||
func (m *Method) ReadRequest(upstream io.Reader) (io.Reader, error) {
|
||||
_salt := buf.Make(m.keySaltLength)
|
||||
defer runtime.KeepAlive(_salt)
|
||||
salt := common.Dup(_salt)
|
||||
_, err := io.ReadFull(upstream, salt)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read salt")
|
||||
}
|
||||
key := Kdf(m.key, salt, m.keySaltLength)
|
||||
defer runtime.KeepAlive(key)
|
||||
return NewReader(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil
|
||||
}
|
||||
|
||||
func (m *Method) WriteResponse(upstream io.Writer) (io.Writer, error) {
|
||||
_salt := buf.Make(m.keySaltLength)
|
||||
defer runtime.KeepAlive(_salt)
|
||||
salt := common.Dup(_salt)
|
||||
common.Must1(io.ReadFull(rand.Reader, salt))
|
||||
_, err := upstream.Write(salt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key := Kdf(m.key, salt, m.keySaltLength)
|
||||
return NewWriter(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil
|
||||
}
|
||||
|
||||
func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
|
||||
shadowsocksConn := &clientConn{
|
||||
Conn: conn,
|
||||
method: m,
|
||||
destination: destination,
|
||||
}
|
||||
return shadowsocksConn, shadowsocksConn.writeRequest(nil)
|
||||
}
|
||||
|
||||
func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
|
||||
return &clientConn{
|
||||
Conn: conn,
|
||||
method: m,
|
||||
destination: destination,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
|
||||
return &clientPacketConn{m, conn}
|
||||
}
|
||||
|
||||
func (m *Method) EncodePacket(buffer *buf.Buffer) error {
|
||||
key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)
|
||||
c := m.constructor(common.Dup(key))
|
||||
runtime.KeepAlive(key)
|
||||
c.Seal(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
|
||||
buffer.Extend(Overhead)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Method) DecodePacket(buffer *buf.Buffer) error {
|
||||
if buffer.Len() < m.keySaltLength {
|
||||
return E.New("bad packet")
|
||||
}
|
||||
key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)
|
||||
c := m.constructor(common.Dup(key))
|
||||
runtime.KeepAlive(key)
|
||||
packet, err := c.Open(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer.Advance(m.keySaltLength)
|
||||
buffer.Truncate(len(packet))
|
||||
return nil
|
||||
}
|
||||
|
||||
type clientConn struct {
|
||||
net.Conn
|
||||
method *Method
|
||||
destination M.Socksaddr
|
||||
reader *Reader
|
||||
writer *Writer
|
||||
}
|
||||
|
||||
func (c *clientConn) writeRequest(payload []byte) error {
|
||||
_salt := make([]byte, c.method.keySaltLength)
|
||||
salt := common.Dup(_salt)
|
||||
common.Must1(io.ReadFull(rand.Reader, salt))
|
||||
|
||||
key := Kdf(c.method.key, salt, c.method.keySaltLength)
|
||||
runtime.KeepAlive(_salt)
|
||||
writer := NewWriter(
|
||||
c.Conn,
|
||||
c.method.constructor(common.Dup(key)),
|
||||
MaxPacketSize,
|
||||
)
|
||||
runtime.KeepAlive(key)
|
||||
header := writer.Buffer()
|
||||
header.Write(salt)
|
||||
bufferedWriter := writer.BufferedWriter(header.Len())
|
||||
|
||||
if len(payload) > 0 {
|
||||
err := M.SocksaddrSerializer.WriteAddrPort(bufferedWriter, c.destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = bufferedWriter.Write(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
err := M.SocksaddrSerializer.WriteAddrPort(bufferedWriter, c.destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err := bufferedWriter.Flush()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.writer = writer
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientConn) readResponse() error {
|
||||
if c.reader != nil {
|
||||
return nil
|
||||
}
|
||||
_salt := buf.Make(c.method.keySaltLength)
|
||||
defer runtime.KeepAlive(_salt)
|
||||
salt := common.Dup(_salt)
|
||||
_, err := io.ReadFull(c.Conn, salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key := Kdf(c.method.key, salt, c.method.keySaltLength)
|
||||
defer runtime.KeepAlive(key)
|
||||
c.reader = NewReader(
|
||||
c.Conn,
|
||||
c.method.constructor(common.Dup(key)),
|
||||
MaxPacketSize,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientConn) Read(p []byte) (n int, err error) {
|
||||
if err = c.readResponse(); err != nil {
|
||||
return
|
||||
}
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
if err = c.readResponse(); err != nil {
|
||||
return
|
||||
}
|
||||
return c.reader.WriteTo(w)
|
||||
}
|
||||
|
||||
func (c *clientConn) Write(p []byte) (n int, err error) {
|
||||
if c.writer != nil {
|
||||
return c.writer.Write(p)
|
||||
}
|
||||
|
||||
err = c.writeRequest(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if c.writer == nil {
|
||||
return rw.ReadFrom0(c, r)
|
||||
}
|
||||
return c.writer.ReadFrom(r)
|
||||
}
|
||||
|
||||
func (c *clientConn) Upstream() any {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
type clientPacketConn struct {
|
||||
*Method
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
header := buffer.ExtendHeader(c.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination))
|
||||
common.Must1(io.ReadFull(rand.Reader, header[:c.keySaltLength]))
|
||||
err := M.SocksaddrSerializer.WriteAddrPort(buf.With(header[c.keySaltLength:]), destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.EncodePacket(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return common.Error(c.Write(buffer.Bytes()))
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
||||
n, err := c.Read(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
return M.Socksaddr{}, err
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
err = c.DecodePacket(buffer)
|
||||
if err != nil {
|
||||
return M.Socksaddr{}, err
|
||||
}
|
||||
return M.SocksaddrSerializer.ReadAddrPort(buffer)
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, err = c.Read(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
b := buf.With(p[:n])
|
||||
err = c.DecodePacket(b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
destination, err := M.SocksaddrSerializer.ReadAddrPort(b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
addr = destination.UDPAddr()
|
||||
n = copy(p, b.Bytes())
|
||||
return
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
_buffer := buf.StackNew()
|
||||
defer runtime.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
err = M.SocksaddrSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = buffer.Write(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = c.EncodePacket(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = c.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) Upstream() any {
|
||||
return c.Conn
|
||||
}
|
246
shadowaead/service.go
Normal file
246
shadowaead/service.go
Normal file
|
@ -0,0 +1,246 @@
|
|||
package shadowaead
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-shadowsocks"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/udpnat"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
var ErrBadHeader = E.New("bad header")
|
||||
|
||||
type Service struct {
|
||||
name string
|
||||
keySaltLength int
|
||||
constructor func(key []byte) cipher.AEAD
|
||||
key []byte
|
||||
handler shadowsocks.Handler
|
||||
udpNat *udpnat.Service[netip.AddrPort]
|
||||
}
|
||||
|
||||
func NewService(method string, key []byte, password string, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) {
|
||||
s := &Service{
|
||||
name: method,
|
||||
handler: handler,
|
||||
udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler),
|
||||
}
|
||||
switch method {
|
||||
case "aes-128-gcm":
|
||||
s.keySaltLength = 16
|
||||
s.constructor = newAESGCM
|
||||
case "aes-192-gcm":
|
||||
s.keySaltLength = 24
|
||||
s.constructor = newAESGCM
|
||||
case "aes-256-gcm":
|
||||
s.keySaltLength = 32
|
||||
s.constructor = newAESGCM
|
||||
case "chacha20-ietf-poly1305":
|
||||
s.keySaltLength = 32
|
||||
s.constructor = func(key []byte) cipher.AEAD {
|
||||
cipher, err := chacha20poly1305.New(key)
|
||||
common.Must(err)
|
||||
return cipher
|
||||
}
|
||||
case "xchacha20-ietf-poly1305":
|
||||
s.keySaltLength = 32
|
||||
s.constructor = func(key []byte) cipher.AEAD {
|
||||
cipher, err := chacha20poly1305.NewX(key)
|
||||
common.Must(err)
|
||||
return cipher
|
||||
}
|
||||
}
|
||||
if len(key) == s.keySaltLength {
|
||||
s.key = key
|
||||
} else if len(key) > 0 {
|
||||
return nil, shadowsocks.ErrBadKey
|
||||
} else if password != "" {
|
||||
s.key = shadowsocks.Key([]byte(password), s.keySaltLength)
|
||||
} else {
|
||||
return nil, shadowsocks.ErrMissingPassword
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
||||
err := s.newConnection(ctx, conn, metadata)
|
||||
if err != nil {
|
||||
err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
||||
_header := buf.Make(s.keySaltLength + PacketLengthBufferSize + Overhead)
|
||||
defer runtime.KeepAlive(_header)
|
||||
header := common.Dup(_header)
|
||||
|
||||
n, err := conn.Read(header)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read header")
|
||||
} else if n < len(header) {
|
||||
return ErrBadHeader
|
||||
}
|
||||
|
||||
key := Kdf(s.key, header[:s.keySaltLength], s.keySaltLength)
|
||||
reader := NewReader(conn, s.constructor(common.Dup(key)), MaxPacketSize)
|
||||
|
||||
err = reader.ReadWithLengthChunk(header[s.keySaltLength:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
metadata.Protocol = "shadowsocks"
|
||||
metadata.Destination = destination
|
||||
|
||||
return s.handler.NewConnection(ctx, &serverConn{
|
||||
Service: s,
|
||||
Conn: conn,
|
||||
reader: reader,
|
||||
}, metadata)
|
||||
}
|
||||
|
||||
type serverConn struct {
|
||||
*Service
|
||||
net.Conn
|
||||
access sync.Mutex
|
||||
reader *Reader
|
||||
writer *Writer
|
||||
}
|
||||
|
||||
func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
|
||||
_salt := buf.Make(c.keySaltLength)
|
||||
salt := common.Dup(_salt)
|
||||
common.Must1(io.ReadFull(rand.Reader, salt))
|
||||
|
||||
key := Kdf(c.key, salt, c.keySaltLength)
|
||||
runtime.KeepAlive(_salt)
|
||||
|
||||
writer := NewWriter(
|
||||
c.Conn,
|
||||
c.constructor(common.Dup(key)),
|
||||
MaxPacketSize,
|
||||
)
|
||||
runtime.KeepAlive(key)
|
||||
|
||||
header := writer.Buffer()
|
||||
header.Write(salt)
|
||||
|
||||
bufferedWriter := writer.BufferedWriter(header.Len())
|
||||
if len(payload) > 0 {
|
||||
_, err = bufferedWriter.Write(payload)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = bufferedWriter.Flush()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.writer = writer
|
||||
return
|
||||
}
|
||||
|
||||
func (c *serverConn) Write(p []byte) (n int, err error) {
|
||||
if c.writer != nil {
|
||||
return c.writer.Write(p)
|
||||
}
|
||||
c.access.Lock()
|
||||
if c.writer != nil {
|
||||
c.access.Unlock()
|
||||
return c.writer.Write(p)
|
||||
}
|
||||
defer c.access.Unlock()
|
||||
return c.writeResponse(p)
|
||||
}
|
||||
|
||||
func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if c.writer == nil {
|
||||
return rw.ReadFrom0(c, r)
|
||||
}
|
||||
return c.writer.ReadFrom(r)
|
||||
}
|
||||
|
||||
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
return c.reader.WriteTo(w)
|
||||
}
|
||||
|
||||
func (c *serverConn) Upstream() any {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (s *Service) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
|
||||
err := s.newPacket(ctx, conn, buffer, metadata)
|
||||
if err != nil {
|
||||
err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Service) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
|
||||
if buffer.Len() < s.keySaltLength {
|
||||
return E.New("bad packet")
|
||||
}
|
||||
key := Kdf(s.key, buffer.To(s.keySaltLength), s.keySaltLength)
|
||||
c := s.constructor(common.Dup(key))
|
||||
runtime.KeepAlive(key)
|
||||
packet, err := c.Open(buffer.Index(s.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer.Advance(s.keySaltLength)
|
||||
buffer.Truncate(len(packet))
|
||||
|
||||
destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
metadata.Protocol = "shadowsocks"
|
||||
metadata.Destination = destination
|
||||
s.udpNat.NewPacket(ctx, metadata.Source.AddrPort(), func() N.PacketWriter {
|
||||
return &serverPacketWriter{s, conn, metadata.Source}
|
||||
}, buffer, metadata)
|
||||
return nil
|
||||
}
|
||||
|
||||
type serverPacketWriter struct {
|
||||
*Service
|
||||
N.PacketConn
|
||||
source M.Socksaddr
|
||||
}
|
||||
|
||||
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
header := buffer.ExtendHeader(w.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination))
|
||||
common.Must1(io.ReadFull(rand.Reader, header[:w.keySaltLength]))
|
||||
err := M.SocksaddrSerializer.WriteAddrPort(buf.With(header[w.keySaltLength:]), destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key := Kdf(w.key, buffer.To(w.keySaltLength), w.keySaltLength)
|
||||
c := w.constructor(common.Dup(key))
|
||||
runtime.KeepAlive(key)
|
||||
c.Seal(buffer.From(w.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(w.keySaltLength), nil)
|
||||
buffer.Extend(Overhead)
|
||||
return w.PacketConn.WritePacket(buffer, w.source)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue