Draft: Encrypted Protocol Extension

This commit is contained in:
世界 2022-06-15 16:19:52 +08:00
parent cd8ec2833c
commit e3a6eb8580
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
8 changed files with 341 additions and 53 deletions

2
go.mod
View file

@ -3,7 +3,7 @@ module github.com/sagernet/sing-shadowsocks
go 1.18
require (
github.com/sagernet/sing v0.0.0-20220614091938-64835a637bdc
github.com/sagernet/sing v0.0.0-20220614131337-ea019b365507
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e
lukechampine.com/blake3 v1.1.7
)

4
go.sum
View file

@ -1,8 +1,8 @@
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.0.12 h1:p9dKCg8i4gmOxtv35DvrYoWqYzQrvEVdjQ762Y0OqZE=
github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
github.com/sagernet/sing v0.0.0-20220614091938-64835a637bdc h1:AdNTzzSw6SCZI71GB+Am7cr+oUDUrBUaOi17FxDtNMw=
github.com/sagernet/sing v0.0.0-20220614091938-64835a637bdc/go.mod h1:Bgwxr10oTxYlQ33MgsXW3GuS2w5St11qqk4DqzJOdU4=
github.com/sagernet/sing v0.0.0-20220614131337-ea019b365507 h1:rMYMyB6N0ARFg0bwgG1Ahl+h0HCXO74yzT8PYvxOuPs=
github.com/sagernet/sing v0.0.0-20220614131337-ea019b365507/go.mod h1:Bgwxr10oTxYlQ33MgsXW3GuS2w5St11qqk4DqzJOdU4=
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e h1:T8NU3HyQ8ClP4SEE+KbFlg6n0NhuTsN4MyznaarGsZM=
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/sys v0.0.0-20220610221304-9f5ed59c137d h1:Zu/JngovGLVi6t2J3nmAf3AoTDwuzw85YZ3b9o4yU7s=

View file

@ -197,6 +197,12 @@ func (r *Reader) Discard(n int) error {
}
}
func (r *Reader) Buffer() *buf.Buffer {
buffer := buf.With(r.buffer)
buffer.Resize(r.index, r.cached)
return buffer
}
func (r *Reader) Cached() int {
return r.cached
}
@ -243,7 +249,7 @@ func (r *Reader) ReadWithLength(length uint16) error {
return nil
}
func (r *Reader) ReadChunk(chunk []byte) error {
func (r *Reader) ReadExternalChunk(chunk []byte) error {
bb, err := r.cipher.Open(r.buffer[:0], r.nonce, chunk, nil)
if err != nil {
return err
@ -254,6 +260,16 @@ func (r *Reader) ReadChunk(chunk []byte) error {
return nil
}
func (r *Reader) ReadChunk(buffer *buf.Buffer, chunk []byte) error {
bb, err := r.cipher.Open(buffer.Index(buffer.Len()), r.nonce, chunk, nil)
if err != nil {
return err
}
increaseNonce(r.nonce)
buffer.Extend(len(bb))
return nil
}
type Writer struct {
upstream io.Writer
cipher cipher.AEAD

View file

@ -0,0 +1,169 @@
package shadowaead_2022
import (
"encoding/binary"
"io"
"github.com/sagernet/sing-shadowsocks/shadowaead"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
)
const (
recordTypeHandshake = 22
recordTypeApplicationData = 23
tlsVersion10 = 0x0301
tlsVersion11 = 0x0302
tlsVersion12 = 0x0303
tlsVersion13 = 0x0304
tlsEncryptedLengthChunkLength = 5 + shadowaead.Overhead
)
func isTLSHandshake(payload []byte) bool {
if len(payload) < 5 {
return false
}
if payload[0] != recordTypeHandshake {
return false
}
tlsVersion := binary.BigEndian.Uint16(payload[1:])
if tlsVersion < tlsVersion10 || tlsVersion > tlsVersion13 {
return false
}
return true
}
func readTLSChunkEnd(payload []byte) int {
pLen := len(payload)
index := 0
for index < pLen {
if pLen-index < 5 {
break
}
dataLen := binary.BigEndian.Uint16(payload[index+3 : index+5])
nextIndex := index + 5 + int(dataLen)
if nextIndex > pLen {
return index
}
index = nextIndex
}
return index
}
type TLSEncryptedStreamReader struct {
upstream *shadowaead.Reader
raw io.Reader
buffer *buf.Buffer
}
func NewTLSEncryptedStreamReader(upstream *shadowaead.Reader) *TLSEncryptedStreamReader {
var reader TLSEncryptedStreamReader
reader.upstream = upstream
reader.raw = upstream.Upstream().(io.Reader)
reader.buffer = upstream.Buffer()
return &reader
}
func (r *TLSEncryptedStreamReader) Read(p []byte) (n int, err error) {
if !r.buffer.IsEmpty() {
return r.buffer.Read(p)
}
data := r.buffer.Slice()
_, err = io.ReadFull(r.raw, data[:tlsEncryptedLengthChunkLength])
if err != nil {
return
}
r.buffer.FullReset()
err = r.upstream.ReadChunk(r.buffer, data[:tlsEncryptedLengthChunkLength])
if err != nil {
return
}
recordType := data[0]
recordLen := int(binary.BigEndian.Uint16(data[3:5]))
if recordType == recordTypeApplicationData {
_, err = r.buffer.ReadFullFrom(r.raw, recordLen)
if err != nil {
return
}
} else {
_, err = io.ReadFull(r.raw, data[5:5+recordLen+shadowaead.Overhead])
if err != nil {
return
}
err = r.upstream.ReadChunk(r.buffer, data[5:5+recordLen+shadowaead.Overhead])
if err != nil {
return
}
}
return r.buffer.Read(p)
}
type TLSEncryptedStreamWriter struct {
upstream *shadowaead.Writer
raw io.Writer
buffer *buf.Buffer
pipeIn *io.PipeReader
pipeOut *io.PipeWriter
}
func NewTLSEncryptedStreamWriter(upstream *shadowaead.Writer) *TLSEncryptedStreamWriter {
var writer TLSEncryptedStreamWriter
writer.upstream = upstream
writer.raw = upstream.Upstream().(io.Writer)
writer.buffer = upstream.Buffer()
writer.pipeIn, writer.pipeOut = io.Pipe()
go writer.loopOut()
return &writer
}
func (w *TLSEncryptedStreamWriter) Write(p []byte) (n int, err error) {
return w.pipeOut.Write(p)
}
func (w *TLSEncryptedStreamWriter) loopOut() {
data := w.buffer.Slice()
var err error
for {
_, err = io.ReadFull(w.pipeIn, data[:5])
if err != nil {
break
}
recordType := data[0]
recordLen := int(binary.BigEndian.Uint16(data[3:5]))
w.buffer.FullReset()
w.upstream.WriteChunk(w.buffer, data[:5])
if recordType != recordTypeApplicationData {
_, err = io.ReadFull(w.pipeIn, data[tlsEncryptedLengthChunkLength:tlsEncryptedLengthChunkLength+recordLen])
if err != nil {
return
}
w.upstream.WriteChunk(w.buffer, data[tlsEncryptedLengthChunkLength:tlsEncryptedLengthChunkLength+recordLen])
} else {
_, err = w.buffer.ReadFullFrom(w.pipeIn, recordLen)
if err != nil {
break
}
}
_, err = w.raw.Write(w.buffer.Bytes())
if err != nil {
break
}
}
w.pipeIn.CloseWithError(err)
}
func (w *TLSEncryptedStreamWriter) Close() error {
return common.Close(
w.upstream,
w.pipeOut,
)
}
func (w *TLSEncryptedStreamWriter) Upstream() any {
return w.upstream
}

View file

@ -39,6 +39,9 @@ const (
PacketNonceSize = 24
MaxPacketSize = 65535
RequestHeaderFixedChunkLength = 1 + 8 + 2
HeaderTypeClientEncrypted = 10
HeaderTypeServerEncrypted = 11
)
var (
@ -58,7 +61,7 @@ var List = []string{
"2022-blake3-chacha20-poly1305",
}
func NewWithPassword(method string, password string) (shadowsocks.Method, error) {
func NewWithPassword(method string, password string, options ...MethodOption) (shadowsocks.Method, error) {
var pskList [][]byte
if password == "" {
return nil, ErrMissingPSK
@ -72,10 +75,10 @@ func NewWithPassword(method string, password string) (shadowsocks.Method, error)
}
pskList[i] = kb
}
return New(method, pskList)
return New(method, pskList, options...)
}
func New(method string, pskList [][]byte) (shadowsocks.Method, error) {
func New(method string, pskList [][]byte, options ...MethodOption) (shadowsocks.Method, error) {
m := &Method{
name: method,
replayFilter: replay.NewSimple(60 * time.Second),
@ -134,6 +137,9 @@ func New(method string, pskList [][]byte) (shadowsocks.Method, error) {
}
m.pskList = pskList
for _, option := range options {
option(m)
}
return m, nil
}
@ -162,15 +168,16 @@ func aeadCipher(block func(key []byte) (cipher.Block, error), aead func(block ci
}
type Method struct {
name string
keySaltLength int
constructor func(key []byte) (cipher.AEAD, error)
blockConstructor func(key []byte) (cipher.Block, error)
udpCipher cipher.AEAD
udpBlockCipher cipher.Block
pskList [][]byte
pskHash []byte
replayFilter replay.Filter
name string
keySaltLength int
constructor func(key []byte) (cipher.AEAD, error)
blockConstructor func(key []byte) (cipher.Block, error)
udpCipher cipher.AEAD
udpBlockCipher cipher.Block
pskList [][]byte
pskHash []byte
replayFilter replay.Filter
encryptedProtocolExtension bool
}
func (m *Method) Name() string {
@ -203,8 +210,8 @@ type clientConn struct {
net.Conn
destination M.Socksaddr
requestSalt []byte
reader *shadowaead.Reader
writer *shadowaead.Writer
reader io.Reader
writer io.Writer
}
func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) error {
@ -239,6 +246,13 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte)
}
func (c *clientConn) writeRequest(payload []byte) error {
var headerType byte
if c.encryptedProtocolExtension && isTLSHandshake(payload) {
headerType = HeaderTypeClientEncrypted
} else {
headerType = HeaderTypeClient
}
salt := make([]byte, c.keySaltLength)
common.Must1(io.ReadFull(rand.Reader, salt))
@ -264,13 +278,21 @@ func (c *clientConn) writeRequest(payload []byte) error {
var _fixedLengthBuffer [RequestHeaderFixedChunkLength]byte
fixedLengthBuffer := buf.With(common.Dup(_fixedLengthBuffer[:]))
common.Must(fixedLengthBuffer.WriteByte(HeaderTypeClient))
common.Must(fixedLengthBuffer.WriteByte(headerType))
common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint64(time.Now().Unix())))
var paddingLen int
if len(payload) < MaxPaddingLength {
paddingLen = mRand.Intn(MaxPaddingLength) + 1
}
variableLengthHeaderLen := M.SocksaddrSerializer.AddrPortLen(c.destination) + 2 + paddingLen + len(payload)
variableLengthHeaderLen := M.SocksaddrSerializer.AddrPortLen(c.destination) + 2 + paddingLen
var payloadLen int
switch headerType {
case HeaderTypeClient:
payloadLen = len(payload)
case HeaderTypeClientEncrypted:
payloadLen = readTLSChunkEnd(payload)
}
variableLengthHeaderLen += payloadLen
common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint16(variableLengthHeaderLen)))
writer.WriteChunk(header, fixedLengthBuffer.Slice())
common.KeepAlive(_fixedLengthBuffer)
@ -282,8 +304,8 @@ func (c *clientConn) writeRequest(payload []byte) error {
if paddingLen > 0 {
variableLengthBuffer.Extend(paddingLen)
}
if len(payload) > 0 {
common.Must1(variableLengthBuffer.Write(payload))
if payloadLen > 0 {
common.Must1(variableLengthBuffer.Write(payload[:payloadLen]))
}
writer.WriteChunk(header, variableLengthBuffer.Slice())
common.KeepAlive(_variableLengthBuffer)
@ -295,7 +317,18 @@ func (c *clientConn) writeRequest(payload []byte) error {
}
c.requestSalt = salt
c.writer = writer
if headerType == HeaderTypeClient {
c.writer = writer
} else if headerType == HeaderTypeClientEncrypted {
encryptedWriter := NewTLSEncryptedStreamWriter(writer)
if payloadLen < len(payload) {
_, err = encryptedWriter.Write(payload[payloadLen:])
if err != nil {
return err
}
}
c.writer = encryptedWriter
}
return nil
}
@ -346,7 +379,7 @@ func (c *clientConn) readResponse() error {
if err != nil {
return err
}
if headerType != HeaderTypeServer {
if headerType != HeaderTypeServer && headerType != HeaderTypeServerEncrypted {
return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType)
}
@ -373,6 +406,7 @@ func (c *clientConn) readResponse() error {
}
requestSalt.Release()
common.KeepAlive(_requestSalt)
c.requestSalt = nil
var length uint16
err = binary.Read(reader, binary.BigEndian, &length)
@ -384,10 +418,11 @@ func (c *clientConn) readResponse() error {
if err != nil {
return err
}
c.requestSalt = nil
c.reader = reader
if headerType == HeaderTypeServer {
c.reader = reader
} else if headerType == HeaderTypeServerEncrypted {
c.reader = NewTLSEncryptedStreamReader(reader)
}
return nil
}
@ -402,7 +437,7 @@ func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
if err = c.readResponse(); err != nil {
return
}
return c.reader.WriteTo(w)
return bufio.Copy(w, c.reader)
}
func (c *clientConn) Write(p []byte) (n int, err error) {
@ -420,13 +455,21 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writer == nil {
return bufio.ReadFrom0(c, r)
}
return c.writer.ReadFrom(r)
return bufio.Copy(c.writer, r)
}
func (c *clientConn) Upstream() any {
return c.Conn
}
func (c *clientConn) Close() error {
return common.Close(
c.Conn,
c.reader,
c.writer,
)
}
type clientPacketConn struct {
*Method
net.Conn

View file

@ -0,0 +1,9 @@
package shadowaead_2022
type MethodOption func(*Method)
func MethodOptionEncryptedProtocolExtension() MethodOption {
return func(method *Method) {
method.encryptedProtocolExtension = true
}
}

View file

@ -153,7 +153,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
)
common.KeepAlive(requestKey)
err = reader.ReadChunk(header[s.keySaltLength:])
err = reader.ReadExternalChunk(header[s.keySaltLength:])
if err != nil {
return err
}
@ -163,7 +163,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
return E.Cause(err, "read header")
}
if headerType != HeaderTypeClient {
if headerType != HeaderTypeClient && headerType != HeaderTypeClientEncrypted {
return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
}
@ -213,15 +213,24 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
return ErrNoPadding
}
metadata.Protocol = "shadowsocks"
metadata.Destination = destination
return s.handler.NewConnection(ctx, &serverConn{
protocolConn := &serverConn{
Service: s,
Conn: conn,
uPSK: s.psk,
reader: reader,
headerType: headerType,
requestSalt: requestSalt,
}, metadata)
}
switch headerType {
case HeaderTypeClient:
protocolConn.reader = reader
case HeaderTypeClientEncrypted:
protocolConn.reader = NewTLSEncryptedStreamReader(reader)
}
metadata.Protocol = "shadowsocks"
metadata.Destination = destination
return s.handler.NewConnection(ctx, protocolConn, metadata)
}
type serverConn struct {
@ -229,8 +238,9 @@ type serverConn struct {
net.Conn
uPSK []byte
access sync.Mutex
reader *shadowaead.Reader
writer *shadowaead.Writer
headerType byte
reader io.Reader
writer io.Writer
requestSalt []byte
}
@ -259,20 +269,31 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
salt.Release()
common.KeepAlive(_salt)
var headerType byte
var payloadLen int
switch c.headerType {
case HeaderTypeClient:
headerType = HeaderTypeServer
payloadLen = len(payload)
case HeaderTypeClientEncrypted:
headerType = HeaderTypeServerEncrypted
payloadLen = readTLSChunkEnd(payload)
}
_headerFixedChunk := buf.StackNewSize(1 + 8 + c.keySaltLength + 2)
headerFixedChunk := common.Dup(_headerFixedChunk)
common.Must(headerFixedChunk.WriteByte(HeaderTypeServer))
common.Must(headerFixedChunk.WriteByte(headerType))
common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint64(time.Now().Unix())))
common.Must1(headerFixedChunk.Write(c.requestSalt))
common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint16(len(payload))))
common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint16(payloadLen)))
writer.WriteChunk(header, headerFixedChunk.Slice())
headerFixedChunk.Release()
common.KeepAlive(_headerFixedChunk)
c.requestSalt = nil
if len(payload) > 0 {
writer.WriteChunk(header, payload)
if payloadLen > 0 {
writer.WriteChunk(header, payload[:payloadLen])
}
err = writer.BufferedWriter(header.Len()).Flush()
@ -280,7 +301,20 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
return
}
c.writer = writer
switch headerType {
case HeaderTypeServer:
c.writer = writer
case HeaderTypeServerEncrypted:
encryptedWriter := NewTLSEncryptedStreamWriter(writer)
if payloadLen < len(payload) {
_, err = encryptedWriter.Write(payload[payloadLen:])
if err != nil {
return
}
}
c.writer = encryptedWriter
}
n = len(payload)
return
}
@ -302,11 +336,19 @@ func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writer == nil {
return bufio.ReadFrom0(c, r)
}
return c.writer.ReadFrom(r)
return bufio.Copy(c.writer, r)
}
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
return c.reader.WriteTo(w)
return bufio.Copy(w, c.reader)
}
func (c *serverConn) Close() error {
return common.Close(
c.Conn,
c.reader,
c.writer,
)
}
func (c *serverConn) Upstream() any {

View file

@ -162,7 +162,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
MaxPacketSize,
)
err = reader.ReadChunk(requestHeader[s.keySaltLength+aes.BlockSize:])
err = reader.ReadExternalChunk(requestHeader[s.keySaltLength+aes.BlockSize:])
if err != nil {
return err
}
@ -172,7 +172,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
return E.Cause(err, "read header")
}
if headerType != HeaderTypeClient {
if headerType != HeaderTypeClient && headerType != HeaderTypeClientEncrypted {
return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
}
@ -222,15 +222,24 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
userCtx.Context = ctx
userCtx.User = user
metadata.Protocol = "shadowsocks"
metadata.Destination = destination
return s.handler.NewConnection(&userCtx, &serverConn{
protocolConn := &serverConn{
Service: s.Service,
Conn: conn,
uPSK: uPSK,
reader: reader,
headerType: headerType,
requestSalt: requestSalt,
}, metadata)
}
switch headerType {
case HeaderTypeClient:
protocolConn.reader = reader
case HeaderTypeClientEncrypted:
protocolConn.reader = NewTLSEncryptedStreamReader(reader)
}
metadata.Protocol = "shadowsocks"
metadata.Destination = destination
return s.handler.NewConnection(&userCtx, protocolConn, metadata)
}
func (s *MultiService[U]) WriteIsThreadUnsafe() {