mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-04 20:37:40 +03:00
Shadowsocks 2022 protocol change
This commit is contained in:
parent
7d06e20837
commit
04cd76f0aa
11 changed files with 294 additions and 115 deletions
|
@ -50,14 +50,9 @@ func From(data []byte) *Buffer {
|
|||
}
|
||||
|
||||
func As(data []byte) *Buffer {
|
||||
size := len(data)
|
||||
max := cap(data)
|
||||
if size != max {
|
||||
data = data[:max]
|
||||
}
|
||||
return &Buffer{
|
||||
data: data,
|
||||
end: size,
|
||||
end: len(data),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -3,7 +3,6 @@ package common
|
|||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
@ -117,20 +116,20 @@ func Error(_ any, err error) error {
|
|||
func Must(errs ...error) {
|
||||
for _, err := range errs {
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Must1(_ any, err error) {
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Must2(_ any, _ any, err error) {
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -32,6 +32,13 @@ type ServerConnError struct {
|
|||
Cause error
|
||||
}
|
||||
|
||||
func (e *ServerConnError) Close() error {
|
||||
if tcpConn, ok := e.Conn.(*net.TCPConn); ok {
|
||||
tcpConn.SetLinger(0)
|
||||
}
|
||||
return e.Conn.Close()
|
||||
}
|
||||
|
||||
func (e *ServerConnError) Unwrap() error {
|
||||
return e.Cause
|
||||
}
|
||||
|
|
|
@ -14,6 +14,18 @@ const (
|
|||
PacketLengthBufferSize = 2
|
||||
)
|
||||
|
||||
const (
|
||||
// NonceSize
|
||||
// crypto/cipher.gcmStandardNonceSize
|
||||
// golang.org/x/crypto/chacha20poly1305.NonceSize
|
||||
NonceSize = 12
|
||||
|
||||
// Overhead
|
||||
// crypto/cipher.gcmTagSize
|
||||
// golang.org/x/crypto/chacha20poly1305.Overhead
|
||||
Overhead = 16
|
||||
)
|
||||
|
||||
type Reader struct {
|
||||
upstream io.Reader
|
||||
cipher cipher.AEAD
|
||||
|
@ -27,8 +39,8 @@ func NewReader(upstream io.Reader, cipher cipher.AEAD, maxPacketSize int) *Reade
|
|||
return &Reader{
|
||||
upstream: upstream,
|
||||
cipher: cipher,
|
||||
buffer: make([]byte, maxPacketSize+PacketLengthBufferSize+cipher.Overhead()*2),
|
||||
nonce: make([]byte, cipher.NonceSize()),
|
||||
buffer: make([]byte, maxPacketSize+PacketLengthBufferSize+Overhead*2),
|
||||
nonce: make([]byte, NonceSize),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -62,7 +74,7 @@ func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) {
|
|||
n += int64(writeN)
|
||||
}
|
||||
for {
|
||||
start := PacketLengthBufferSize + r.cipher.Overhead()
|
||||
start := PacketLengthBufferSize + Overhead
|
||||
_, err = io.ReadFull(r.upstream, r.buffer[:start])
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -73,7 +85,7 @@ func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) {
|
|||
}
|
||||
increaseNonce(r.nonce)
|
||||
length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
|
||||
end := length + r.cipher.Overhead()
|
||||
end := length + Overhead
|
||||
_, err = io.ReadFull(r.upstream, r.buffer[:end])
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -92,7 +104,7 @@ func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) {
|
|||
}
|
||||
|
||||
func (r *Reader) readInternal() (err error) {
|
||||
start := PacketLengthBufferSize + r.cipher.Overhead()
|
||||
start := PacketLengthBufferSize + Overhead
|
||||
_, err = io.ReadFull(r.upstream, r.buffer[:start])
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -103,7 +115,7 @@ func (r *Reader) readInternal() (err error) {
|
|||
}
|
||||
increaseNonce(r.nonce)
|
||||
length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
|
||||
end := length + r.cipher.Overhead()
|
||||
end := length + Overhead
|
||||
_, err = io.ReadFull(r.upstream, r.buffer[:end])
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -138,7 +150,7 @@ func (r *Reader) Read(b []byte) (n int, err error) {
|
|||
r.index += n
|
||||
return
|
||||
}
|
||||
start := PacketLengthBufferSize + r.cipher.Overhead()
|
||||
start := PacketLengthBufferSize + Overhead
|
||||
_, err = io.ReadFull(r.upstream, r.buffer[:start])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
|
@ -149,7 +161,7 @@ func (r *Reader) Read(b []byte) (n int, err error) {
|
|||
}
|
||||
increaseNonce(r.nonce)
|
||||
length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
|
||||
end := length + r.cipher.Overhead()
|
||||
end := length + Overhead
|
||||
|
||||
if len(b) >= end {
|
||||
data := b[:end]
|
||||
|
@ -202,6 +214,59 @@ func (r *Reader) Cached() int {
|
|||
return r.cached
|
||||
}
|
||||
|
||||
func (r *Reader) CachedSlice() []byte {
|
||||
return r.buffer[r.index : r.index+r.cached]
|
||||
}
|
||||
|
||||
func (r *Reader) ReadWithLengthChunk(lengthChunk []byte) error {
|
||||
_, err := r.cipher.Open(r.buffer[:0], r.nonce, lengthChunk, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
increaseNonce(r.nonce)
|
||||
length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
|
||||
end := length + Overhead
|
||||
_, err = io.ReadFull(r.upstream, r.buffer[:end])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
increaseNonce(r.nonce)
|
||||
r.cached = length
|
||||
r.index = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Reader) ReadWithLength(length uint16) error {
|
||||
end := length + Overhead
|
||||
_, err := io.ReadFull(r.upstream, r.buffer[:end])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
increaseNonce(r.nonce)
|
||||
r.cached = int(length)
|
||||
r.index = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Reader) ReadChunk(chunk []byte) error {
|
||||
bb, err := r.cipher.Open(r.buffer[:0], r.nonce, chunk, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
increaseNonce(r.nonce)
|
||||
r.cached = len(bb)
|
||||
r.index = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
type Writer struct {
|
||||
upstream io.Writer
|
||||
cipher cipher.AEAD
|
||||
|
@ -214,7 +279,7 @@ func NewWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int) *Write
|
|||
return &Writer{
|
||||
upstream: upstream,
|
||||
cipher: cipher,
|
||||
buffer: make([]byte, maxPacketSize+PacketLengthBufferSize+cipher.Overhead()*2),
|
||||
buffer: make([]byte, maxPacketSize+PacketLengthBufferSize+Overhead*2),
|
||||
nonce: make([]byte, cipher.NonceSize()),
|
||||
maxPacketSize: maxPacketSize,
|
||||
}
|
||||
|
@ -244,7 +309,7 @@ func (w *Writer) SetWriter(writer io.Writer) {
|
|||
|
||||
func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
for {
|
||||
offset := w.cipher.Overhead() + PacketLengthBufferSize
|
||||
offset := Overhead + PacketLengthBufferSize
|
||||
readN, readErr := r.Read(w.buffer[offset : offset+w.maxPacketSize])
|
||||
if readErr != nil {
|
||||
return 0, readErr
|
||||
|
@ -271,7 +336,7 @@ func (w *Writer) Write(p []byte) (n int, err error) {
|
|||
binary.BigEndian.PutUint16(w.buffer[:PacketLengthBufferSize], uint16(len(data)))
|
||||
w.cipher.Seal(w.buffer[:0], w.nonce, w.buffer[:PacketLengthBufferSize], nil)
|
||||
increaseNonce(w.nonce)
|
||||
offset := w.cipher.Overhead() + PacketLengthBufferSize
|
||||
offset := Overhead + PacketLengthBufferSize
|
||||
packet := w.cipher.Seal(w.buffer[offset:offset], w.nonce, data, nil)
|
||||
increaseNonce(w.nonce)
|
||||
_, err = w.upstream.Write(w.buffer[:offset+len(packet)])
|
||||
|
@ -288,11 +353,17 @@ func (w *Writer) Buffer() *buf.Buffer {
|
|||
return buf.With(w.buffer)
|
||||
}
|
||||
|
||||
func (w *Writer) WriteChunk(buffer *buf.Buffer, chunk []byte) {
|
||||
bb := w.cipher.Seal(buffer.Index(buffer.Len()), w.nonce, chunk, nil)
|
||||
buffer.Extend(len(bb))
|
||||
increaseNonce(w.nonce)
|
||||
}
|
||||
|
||||
func (w *Writer) BufferedWriter(reversed int) *BufferedWriter {
|
||||
return &BufferedWriter{
|
||||
upstream: w,
|
||||
reversed: reversed,
|
||||
data: w.buffer[PacketLengthBufferSize+w.cipher.Overhead() : len(w.buffer)-w.cipher.Overhead()],
|
||||
data: w.buffer[PacketLengthBufferSize+Overhead : len(w.buffer)-Overhead],
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -340,7 +411,7 @@ func (w *BufferedWriter) Flush() error {
|
|||
binary.BigEndian.PutUint16(buffer[:PacketLengthBufferSize], uint16(w.index))
|
||||
w.upstream.cipher.Seal(buffer[:0], w.upstream.nonce, buffer[:PacketLengthBufferSize], nil)
|
||||
increaseNonce(w.upstream.nonce)
|
||||
offset := w.upstream.cipher.Overhead() + PacketLengthBufferSize
|
||||
offset := Overhead + PacketLengthBufferSize
|
||||
packet := w.upstream.cipher.Seal(buffer[offset:offset], w.upstream.nonce, buffer[offset:offset+w.index], nil)
|
||||
increaseNonce(w.upstream.nonce)
|
||||
_, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed+offset+len(packet)])
|
||||
|
|
|
@ -154,7 +154,7 @@ func (m *Method) EncodePacket(buffer *buf.Buffer) error {
|
|||
c := m.constructor(common.Dup(key))
|
||||
runtime.KeepAlive(key)
|
||||
c.Seal(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
|
||||
buffer.Extend(c.Overhead())
|
||||
buffer.Extend(Overhead)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,8 @@ import (
|
|||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
var ErrBadHeader = E.New("bad header")
|
||||
|
||||
type Service struct {
|
||||
name string
|
||||
keySaltLength int
|
||||
|
@ -83,17 +85,25 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M
|
|||
}
|
||||
|
||||
func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
||||
_salt := buf.Make(s.keySaltLength)
|
||||
defer runtime.KeepAlive(_salt)
|
||||
salt := common.Dup(_salt)
|
||||
_header := buf.Make(s.keySaltLength + PacketLengthBufferSize + Overhead)
|
||||
defer runtime.KeepAlive(_header)
|
||||
header := common.Dup(_header)
|
||||
|
||||
_, err := io.ReadFull(conn, salt)
|
||||
n, err := conn.Read(header)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read salt")
|
||||
return E.Cause(err, "read header")
|
||||
} else if n < len(header) {
|
||||
return ErrBadHeader
|
||||
}
|
||||
|
||||
key := Kdf(s.key, salt, s.keySaltLength)
|
||||
key := Kdf(s.key, header[:s.keySaltLength], s.keySaltLength)
|
||||
reader := NewReader(conn, s.constructor(common.Dup(key)), MaxPacketSize)
|
||||
|
||||
err = reader.ReadWithLengthChunk(header[s.keySaltLength:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -250,6 +260,6 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socks
|
|||
c := w.constructor(common.Dup(key))
|
||||
runtime.KeepAlive(key)
|
||||
c.Seal(buffer.From(w.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(w.keySaltLength), nil)
|
||||
buffer.Extend(c.Overhead())
|
||||
buffer.Extend(Overhead)
|
||||
return w.PacketConn.WritePacket(buffer, w.source)
|
||||
}
|
||||
|
|
|
@ -30,22 +30,12 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
HeaderTypeClient = 0
|
||||
HeaderTypeServer = 1
|
||||
MaxPaddingLength = 900
|
||||
PacketNonceSize = 24
|
||||
MaxPacketSize = 65535
|
||||
)
|
||||
|
||||
const (
|
||||
// crypto/cipher.gcmStandardNonceSize
|
||||
// golang.org/x/crypto/chacha20poly1305.NonceSize
|
||||
nonceSize = 12
|
||||
|
||||
// Overhead
|
||||
// crypto/cipher.gcmTagSize
|
||||
// golang.org/x/crypto/chacha20poly1305.Overhead
|
||||
overhead = 16
|
||||
HeaderTypeClient = 0
|
||||
HeaderTypeServer = 1
|
||||
MaxPaddingLength = 900
|
||||
PacketNonceSize = 24
|
||||
MaxPacketSize = 65535
|
||||
RequestHeaderFixedChunkLength = 1 + 8 + 2
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -259,38 +249,32 @@ func (c *clientConn) writeRequest(payload []byte) error {
|
|||
header.Write(salt)
|
||||
c.writeExtendedIdentityHeaders(header, salt)
|
||||
|
||||
bufferedWriter := writer.BufferedWriter(header.Len())
|
||||
|
||||
common.Must(rw.WriteByte(bufferedWriter, HeaderTypeClient))
|
||||
common.Must(binary.Write(bufferedWriter, binary.BigEndian, uint64(time.Now().Unix())))
|
||||
|
||||
err := M.SocksaddrSerializer.WriteAddrPort(bufferedWriter, c.destination)
|
||||
if err != nil {
|
||||
return E.Cause(err, "write destination")
|
||||
var _fixedLengthBuffer [RequestHeaderFixedChunkLength]byte
|
||||
fixedLengthBuffer := buf.With(common.Dup(_fixedLengthBuffer[:]))
|
||||
common.Must(fixedLengthBuffer.WriteByte(HeaderTypeClient))
|
||||
common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint64(time.Now().Unix())))
|
||||
var paddingLen int
|
||||
if len(payload) == 0 {
|
||||
paddingLen = rand.Intn(MaxPaddingLength + 1)
|
||||
}
|
||||
variableLengthHeaderLen := M.SocksaddrSerializer.AddrPortLen(c.destination) + 2 + paddingLen + len(payload)
|
||||
common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint16(variableLengthHeaderLen)))
|
||||
writer.WriteChunk(header, fixedLengthBuffer.Slice())
|
||||
runtime.KeepAlive(_fixedLengthBuffer)
|
||||
|
||||
if len(payload) > 0 {
|
||||
err = binary.Write(bufferedWriter, binary.BigEndian, uint16(0))
|
||||
if err != nil {
|
||||
return E.Cause(err, "write padding length")
|
||||
}
|
||||
_, err = bufferedWriter.Write(payload)
|
||||
if err != nil {
|
||||
return E.Cause(err, "write payload")
|
||||
}
|
||||
_variableLengthBuffer := buf.Make(variableLengthHeaderLen)
|
||||
variableLengthBuffer := buf.With(common.Dup(_variableLengthBuffer))
|
||||
common.Must(M.SocksaddrSerializer.WriteAddrPort(variableLengthBuffer, c.destination))
|
||||
common.Must(binary.Write(variableLengthBuffer, binary.BigEndian, uint16(paddingLen)))
|
||||
if paddingLen > 0 {
|
||||
common.Must1(io.CopyN(variableLengthBuffer, c.secureRNG, int64(paddingLen)))
|
||||
} else {
|
||||
pLen := rand.Intn(MaxPaddingLength + 1)
|
||||
err = binary.Write(bufferedWriter, binary.BigEndian, uint16(pLen))
|
||||
if err != nil {
|
||||
return E.Cause(err, "write padding length")
|
||||
}
|
||||
_, err = io.CopyN(bufferedWriter, c.secureRNG, int64(pLen))
|
||||
if err != nil {
|
||||
return E.Cause(err, "write padding")
|
||||
}
|
||||
common.Must1(variableLengthBuffer.Write(payload))
|
||||
}
|
||||
writer.WriteChunk(header, variableLengthBuffer.Slice())
|
||||
runtime.KeepAlive(_variableLengthBuffer)
|
||||
|
||||
err = bufferedWriter.Flush()
|
||||
err := writer.BufferedWriter(header.Len()).Flush()
|
||||
if err != nil {
|
||||
return E.Cause(err, "client handshake")
|
||||
}
|
||||
|
@ -325,6 +309,11 @@ func (c *clientConn) readResponse() error {
|
|||
)
|
||||
runtime.KeepAlive(key)
|
||||
|
||||
err = reader.ReadWithLength(uint16(1 + 8 + c.keySaltLength + 2))
|
||||
if err != nil {
|
||||
return E.Cause(err, "read response fixed length chunk")
|
||||
}
|
||||
|
||||
headerType, err := rw.ReadByte(reader)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -356,6 +345,17 @@ func (c *clientConn) readResponse() error {
|
|||
}
|
||||
runtime.KeepAlive(_requestSalt)
|
||||
|
||||
var length uint16
|
||||
err = binary.Read(reader, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = reader.ReadWithLength(length)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.requestSalt = nil
|
||||
c.reader = reader
|
||||
|
||||
|
@ -481,11 +481,11 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
|
|||
}
|
||||
if c.udpCipher != nil {
|
||||
c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
|
||||
buffer.Extend(c.udpCipher.Overhead())
|
||||
buffer.Extend(shadowaead.Overhead)
|
||||
} else {
|
||||
packetHeader := buffer.To(aes.BlockSize)
|
||||
c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
|
||||
buffer.Extend(c.session.cipher.Overhead())
|
||||
buffer.Extend(shadowaead.Overhead)
|
||||
c.udpBlockCipher.Encrypt(packetHeader, packetHeader)
|
||||
}
|
||||
return common.Error(c.Write(buffer.Bytes()))
|
||||
|
@ -505,7 +505,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
|||
return M.Socksaddr{}, E.Cause(err, "decrypt packet")
|
||||
}
|
||||
buffer.Advance(PacketNonceSize)
|
||||
buffer.Truncate(buffer.Len() - c.udpCipher.Overhead())
|
||||
buffer.Truncate(buffer.Len() - shadowaead.Overhead)
|
||||
} else {
|
||||
packetHeader = buffer.To(aes.BlockSize)
|
||||
c.udpBlockCipher.Decrypt(packetHeader, packetHeader)
|
||||
|
@ -536,7 +536,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
|||
if err != nil {
|
||||
return M.Socksaddr{}, E.Cause(err, "decrypt packet")
|
||||
}
|
||||
buffer.Truncate(buffer.Len() - remoteCipher.Overhead())
|
||||
buffer.Truncate(buffer.Len() - shadowaead.Overhead)
|
||||
}
|
||||
|
||||
var headerType byte
|
||||
|
@ -625,9 +625,9 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|||
destination := M.SocksaddrFromNet(addr)
|
||||
var overHead int
|
||||
if c.udpCipher != nil {
|
||||
overHead = PacketNonceSize + c.udpCipher.Overhead()
|
||||
overHead = PacketNonceSize + shadowaead.Overhead
|
||||
} else {
|
||||
overHead = c.session.cipher.Overhead()
|
||||
overHead = shadowaead.Overhead
|
||||
}
|
||||
overHead += 16 // packet header
|
||||
pskLen := len(c.pskList)
|
||||
|
@ -686,11 +686,11 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|||
}
|
||||
if c.udpCipher != nil {
|
||||
c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
|
||||
buffer.Extend(c.udpCipher.Overhead())
|
||||
buffer.Extend(shadowaead.Overhead)
|
||||
} else {
|
||||
packetHeader := buffer.To(aes.BlockSize)
|
||||
c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
|
||||
buffer.Extend(c.session.cipher.Overhead())
|
||||
buffer.Extend(shadowaead.Overhead)
|
||||
c.udpBlockCipher.Encrypt(packetHeader, packetHeader)
|
||||
}
|
||||
err = common.Error(c.Write(buffer.Bytes()))
|
||||
|
|
|
@ -28,7 +28,10 @@ import (
|
|||
wgReplay "golang.zx2c4.com/wireguard/replay"
|
||||
)
|
||||
|
||||
var ErrNoPadding = E.New("bad request: missing payload or padding")
|
||||
var (
|
||||
ErrNoPadding = E.New("bad request: missing payload or padding")
|
||||
ErrBadPadding = E.New("bad request: damaged padding")
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
name string
|
||||
|
@ -108,12 +111,17 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M
|
|||
}
|
||||
|
||||
func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
||||
requestSalt := buf.Make(s.keySaltLength)
|
||||
_, err := io.ReadFull(conn, requestSalt)
|
||||
header := buf.Make(s.keySaltLength + shadowaead.Overhead + RequestHeaderFixedChunkLength)
|
||||
|
||||
n, err := conn.Read(header)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read request salt")
|
||||
return E.Cause(err, "read header")
|
||||
} else if n < len(header) {
|
||||
return shadowaead.ErrBadHeader
|
||||
}
|
||||
|
||||
requestSalt := header[:s.keySaltLength]
|
||||
|
||||
if !s.replayFilter.Check(requestSalt) {
|
||||
return E.New("salt not unique")
|
||||
}
|
||||
|
@ -126,7 +134,12 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
|
|||
)
|
||||
runtime.KeepAlive(requestKey)
|
||||
|
||||
headerType, err := rw.ReadByte(reader)
|
||||
err = reader.ReadChunk(header[s.keySaltLength:])
|
||||
if err != nil {
|
||||
return E.Cause(err, "read request fixed length chunk")
|
||||
}
|
||||
|
||||
headerType, err := reader.ReadByte()
|
||||
if err != nil {
|
||||
return E.Cause(err, "read header")
|
||||
}
|
||||
|
@ -138,7 +151,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
|
|||
var epoch uint64
|
||||
err = binary.Read(reader, binary.BigEndian, &epoch)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read timestamp")
|
||||
return err
|
||||
}
|
||||
|
||||
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch))))
|
||||
|
@ -146,15 +159,30 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
|
|||
return ErrBadTimestamp
|
||||
}
|
||||
|
||||
var length uint16
|
||||
err = binary.Read(reader, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = reader.ReadWithLength(length)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read destination")
|
||||
return err
|
||||
}
|
||||
|
||||
var paddingLen uint16
|
||||
err = binary.Read(reader, binary.BigEndian, &paddingLen)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read padding length")
|
||||
return err
|
||||
}
|
||||
|
||||
if uint16(reader.Cached()) < paddingLen {
|
||||
return ErrNoPadding
|
||||
}
|
||||
|
||||
if paddingLen > 0 {
|
||||
|
@ -201,21 +229,23 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
|
|||
runtime.KeepAlive(key)
|
||||
header := writer.Buffer()
|
||||
header.Write(salt)
|
||||
bufferedWriter := writer.BufferedWriter(header.Len())
|
||||
|
||||
common.Must(rw.WriteByte(bufferedWriter, HeaderTypeServer))
|
||||
common.Must(binary.Write(bufferedWriter, binary.BigEndian, uint64(time.Now().Unix())))
|
||||
common.Must1(bufferedWriter.Write(c.requestSalt[:]))
|
||||
_headerFixedChunk := buf.Make(1 + 8 + c.keySaltLength + 2)
|
||||
headerFixedChunk := buf.With(common.Dup(_headerFixedChunk))
|
||||
common.Must(headerFixedChunk.WriteByte(HeaderTypeServer))
|
||||
common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint64(time.Now().Unix())))
|
||||
common.Must1(headerFixedChunk.Write(c.requestSalt))
|
||||
common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint16(len(payload))))
|
||||
|
||||
writer.WriteChunk(header, headerFixedChunk.Slice())
|
||||
runtime.KeepAlive(_headerFixedChunk)
|
||||
c.requestSalt = nil
|
||||
|
||||
if len(payload) > 0 {
|
||||
_, err = bufferedWriter.Write(payload)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
writer.WriteChunk(header, payload)
|
||||
}
|
||||
|
||||
err = bufferedWriter.Flush()
|
||||
err = writer.BufferedWriter(header.Len()).Flush()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -287,7 +317,7 @@ func (s *Service) newPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Me
|
|||
return E.Cause(err, "decrypt packet header")
|
||||
}
|
||||
buffer.Advance(PacketNonceSize)
|
||||
buffer.Truncate(buffer.Len() - s.udpCipher.Overhead())
|
||||
buffer.Truncate(buffer.Len() - shadowaead.Overhead)
|
||||
} else {
|
||||
packetHeader = buffer.To(aes.BlockSize)
|
||||
s.udpBlockCipher.Decrypt(packetHeader, packetHeader)
|
||||
|
@ -332,7 +362,7 @@ process:
|
|||
err = E.Cause(err, "decrypt packet")
|
||||
goto returnErr
|
||||
}
|
||||
buffer.Truncate(buffer.Len() - session.remoteCipher.Overhead())
|
||||
buffer.Truncate(buffer.Len() - shadowaead.Overhead)
|
||||
}
|
||||
|
||||
var headerType byte
|
||||
|
@ -421,11 +451,11 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socks
|
|||
|
||||
if w.udpCipher != nil {
|
||||
w.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
|
||||
buffer.Extend(w.udpCipher.Overhead())
|
||||
buffer.Extend(shadowaead.Overhead)
|
||||
} else {
|
||||
packetHeader := buffer.To(aes.BlockSize)
|
||||
w.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
|
||||
buffer.Extend(w.session.cipher.Overhead())
|
||||
buffer.Extend(shadowaead.Overhead)
|
||||
w.udpBlockCipher.Encrypt(packetHeader, packetHeader)
|
||||
}
|
||||
return w.PacketConn.WritePacket(buffer, w.session.remoteAddr)
|
||||
|
|
|
@ -91,22 +91,21 @@ func (s *MultiService[U]) NewConnection(ctx context.Context, conn net.Conn, meta
|
|||
}
|
||||
|
||||
func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
||||
requestSalt := make([]byte, s.keySaltLength)
|
||||
_, err := io.ReadFull(conn, requestSalt)
|
||||
requestHeader := make([]byte, s.keySaltLength+aes.BlockSize+shadowaead.Overhead+RequestHeaderFixedChunkLength)
|
||||
n, err := conn.Read(requestHeader)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read request salt")
|
||||
return err
|
||||
} else if n < len(requestHeader) {
|
||||
return shadowaead.ErrBadHeader
|
||||
}
|
||||
|
||||
requestSalt := requestHeader[:s.keySaltLength]
|
||||
if !s.replayFilter.Check(requestSalt) {
|
||||
return E.New("salt not unique")
|
||||
}
|
||||
|
||||
var _eiHeader [aes.BlockSize]byte
|
||||
eiHeader := common.Dup(_eiHeader[:])
|
||||
_, err = io.ReadFull(conn, eiHeader)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read extended identity header")
|
||||
}
|
||||
copy(eiHeader, requestHeader[s.keySaltLength:s.keySaltLength+aes.BlockSize])
|
||||
|
||||
keyMaterial := buf.Make(s.keySaltLength * 2)
|
||||
copy(keyMaterial, s.psk)
|
||||
|
@ -125,6 +124,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
|
|||
} else {
|
||||
return E.New("invalid request")
|
||||
}
|
||||
runtime.KeepAlive(_eiHeader)
|
||||
|
||||
requestKey := SessionKey(uPSK, requestSalt, s.keySaltLength)
|
||||
reader := shadowaead.NewReader(
|
||||
|
@ -132,7 +132,11 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
|
|||
s.constructor(common.Dup(requestKey)),
|
||||
MaxPacketSize,
|
||||
)
|
||||
runtime.KeepAlive(requestSalt)
|
||||
|
||||
err = reader.ReadChunk(requestHeader[s.keySaltLength+aes.BlockSize:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
headerType, err := rw.ReadByte(reader)
|
||||
if err != nil {
|
||||
|
@ -152,6 +156,16 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
|
|||
if diff > 30 {
|
||||
return ErrBadTimestamp
|
||||
}
|
||||
var length uint16
|
||||
err = binary.Read(reader, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read length")
|
||||
}
|
||||
|
||||
err = reader.ReadWithLength(length)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
|
@ -164,7 +178,9 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
|
|||
return E.Cause(err, "read padding length")
|
||||
}
|
||||
|
||||
if paddingLen > 0 {
|
||||
if reader.Cached() < int(paddingLen) {
|
||||
return ErrBadPadding
|
||||
} else if paddingLen > 0 {
|
||||
err = reader.Discard(int(paddingLen))
|
||||
if err != nil {
|
||||
return E.Cause(err, "discard padding")
|
||||
|
@ -257,7 +273,7 @@ process:
|
|||
err = E.Cause(err, "decrypt packet")
|
||||
goto returnErr
|
||||
}
|
||||
buffer.Truncate(buffer.Len() - session.remoteCipher.Overhead())
|
||||
buffer.Truncate(buffer.Len() - shadowaead.Overhead)
|
||||
}
|
||||
|
||||
var headerType byte
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/random"
|
||||
|
@ -40,7 +41,8 @@ func TestMultiService(t *testing.T) {
|
|||
go func() {
|
||||
err := multiService.NewConnection(context.Background(), serverConn, M.Metadata{})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
serverConn.Close()
|
||||
t.Error(E.Cause(err, "server"))
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
|
49
protocol/shadowsocks/shadowaead_2022/service_test.go
Normal file
49
protocol/shadowsocks/shadowaead_2022/service_test.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
package shadowaead_2022_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/random"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead_2022"
|
||||
)
|
||||
|
||||
func TestService(t *testing.T) {
|
||||
method := "2022-blake3-aes-128-gcm"
|
||||
var psk [16]byte
|
||||
random.Default.Read(psk[:])
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
service, err := shadowaead_2022.NewService(method, psk[:], "", random.Default, 500, &multiHandler{t, &wg})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client, err := shadowaead_2022.New(method, [][]byte{psk[:]}, "", random.Default)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wg.Add(1)
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer common.Close(serverConn, clientConn)
|
||||
go func() {
|
||||
err := service.NewConnection(context.Background(), serverConn, M.Metadata{})
|
||||
if err != nil {
|
||||
serverConn.Close()
|
||||
t.Error(E.Cause(err, "server"))
|
||||
return
|
||||
}
|
||||
}()
|
||||
_, err = client.DialConn(clientConn, M.ParseSocksaddr("test.com:443"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue