From e3a6eb8580bd557367ae8d471dfc1f5b38daa8e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 15 Jun 2022 16:19:52 +0800 Subject: [PATCH] Draft: Encrypted Protocol Extension --- go.mod | 2 +- go.sum | 4 +- shadowaead/aead.go | 18 ++- shadowaead_2022/encrypted_stream.go | 169 ++++++++++++++++++++++++++++ shadowaead_2022/protocol.go | 95 +++++++++++----- shadowaead_2022/protocol_option.go | 9 ++ shadowaead_2022/service.go | 74 +++++++++--- shadowaead_2022/service_multi.go | 23 ++-- 8 files changed, 341 insertions(+), 53 deletions(-) create mode 100644 shadowaead_2022/encrypted_stream.go create mode 100644 shadowaead_2022/protocol_option.go diff --git a/go.mod b/go.mod index 55e4fd4..758b1d7 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/sagernet/sing-shadowsocks go 1.18 require ( - github.com/sagernet/sing v0.0.0-20220614091938-64835a637bdc + github.com/sagernet/sing v0.0.0-20220614131337-ea019b365507 golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e lukechampine.com/blake3 v1.1.7 ) diff --git a/go.sum b/go.sum index 0f3fc2c..b0ac630 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,8 @@ github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.12 h1:p9dKCg8i4gmOxtv35DvrYoWqYzQrvEVdjQ762Y0OqZE= github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= -github.com/sagernet/sing v0.0.0-20220614091938-64835a637bdc h1:AdNTzzSw6SCZI71GB+Am7cr+oUDUrBUaOi17FxDtNMw= -github.com/sagernet/sing v0.0.0-20220614091938-64835a637bdc/go.mod h1:Bgwxr10oTxYlQ33MgsXW3GuS2w5St11qqk4DqzJOdU4= +github.com/sagernet/sing v0.0.0-20220614131337-ea019b365507 h1:rMYMyB6N0ARFg0bwgG1Ahl+h0HCXO74yzT8PYvxOuPs= +github.com/sagernet/sing v0.0.0-20220614131337-ea019b365507/go.mod h1:Bgwxr10oTxYlQ33MgsXW3GuS2w5St11qqk4DqzJOdU4= golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e h1:T8NU3HyQ8ClP4SEE+KbFlg6n0NhuTsN4MyznaarGsZM= golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/sys v0.0.0-20220610221304-9f5ed59c137d h1:Zu/JngovGLVi6t2J3nmAf3AoTDwuzw85YZ3b9o4yU7s= diff --git a/shadowaead/aead.go b/shadowaead/aead.go index 1e0fe49..8e07cdc 100644 --- a/shadowaead/aead.go +++ b/shadowaead/aead.go @@ -197,6 +197,12 @@ func (r *Reader) Discard(n int) error { } } +func (r *Reader) Buffer() *buf.Buffer { + buffer := buf.With(r.buffer) + buffer.Resize(r.index, r.cached) + return buffer +} + func (r *Reader) Cached() int { return r.cached } @@ -243,7 +249,7 @@ func (r *Reader) ReadWithLength(length uint16) error { return nil } -func (r *Reader) ReadChunk(chunk []byte) error { +func (r *Reader) ReadExternalChunk(chunk []byte) error { bb, err := r.cipher.Open(r.buffer[:0], r.nonce, chunk, nil) if err != nil { return err @@ -254,6 +260,16 @@ func (r *Reader) ReadChunk(chunk []byte) error { return nil } +func (r *Reader) ReadChunk(buffer *buf.Buffer, chunk []byte) error { + bb, err := r.cipher.Open(buffer.Index(buffer.Len()), r.nonce, chunk, nil) + if err != nil { + return err + } + increaseNonce(r.nonce) + buffer.Extend(len(bb)) + return nil +} + type Writer struct { upstream io.Writer cipher cipher.AEAD diff --git a/shadowaead_2022/encrypted_stream.go b/shadowaead_2022/encrypted_stream.go new file mode 100644 index 0000000..1b4fab2 --- /dev/null +++ b/shadowaead_2022/encrypted_stream.go @@ -0,0 +1,169 @@ +package shadowaead_2022 + +import ( + "encoding/binary" + "io" + + "github.com/sagernet/sing-shadowsocks/shadowaead" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" +) + +const ( + recordTypeHandshake = 22 + recordTypeApplicationData = 23 + + tlsVersion10 = 0x0301 + tlsVersion11 = 0x0302 + tlsVersion12 = 0x0303 + tlsVersion13 = 0x0304 + + tlsEncryptedLengthChunkLength = 5 + shadowaead.Overhead +) + +func isTLSHandshake(payload []byte) bool { + if len(payload) < 5 { + return false + } + if payload[0] != recordTypeHandshake { + return false + } + tlsVersion := binary.BigEndian.Uint16(payload[1:]) + if tlsVersion < tlsVersion10 || tlsVersion > tlsVersion13 { + return false + } + return true +} + +func readTLSChunkEnd(payload []byte) int { + pLen := len(payload) + index := 0 + for index < pLen { + if pLen-index < 5 { + break + } + dataLen := binary.BigEndian.Uint16(payload[index+3 : index+5]) + nextIndex := index + 5 + int(dataLen) + if nextIndex > pLen { + return index + } + index = nextIndex + } + return index +} + +type TLSEncryptedStreamReader struct { + upstream *shadowaead.Reader + raw io.Reader + buffer *buf.Buffer +} + +func NewTLSEncryptedStreamReader(upstream *shadowaead.Reader) *TLSEncryptedStreamReader { + var reader TLSEncryptedStreamReader + reader.upstream = upstream + reader.raw = upstream.Upstream().(io.Reader) + reader.buffer = upstream.Buffer() + return &reader +} + +func (r *TLSEncryptedStreamReader) Read(p []byte) (n int, err error) { + if !r.buffer.IsEmpty() { + return r.buffer.Read(p) + } + data := r.buffer.Slice() + _, err = io.ReadFull(r.raw, data[:tlsEncryptedLengthChunkLength]) + if err != nil { + return + } + r.buffer.FullReset() + err = r.upstream.ReadChunk(r.buffer, data[:tlsEncryptedLengthChunkLength]) + if err != nil { + return + } + recordType := data[0] + recordLen := int(binary.BigEndian.Uint16(data[3:5])) + if recordType == recordTypeApplicationData { + _, err = r.buffer.ReadFullFrom(r.raw, recordLen) + if err != nil { + return + } + } else { + _, err = io.ReadFull(r.raw, data[5:5+recordLen+shadowaead.Overhead]) + if err != nil { + return + } + err = r.upstream.ReadChunk(r.buffer, data[5:5+recordLen+shadowaead.Overhead]) + if err != nil { + return + } + } + return r.buffer.Read(p) +} + +type TLSEncryptedStreamWriter struct { + upstream *shadowaead.Writer + raw io.Writer + buffer *buf.Buffer + pipeIn *io.PipeReader + pipeOut *io.PipeWriter +} + +func NewTLSEncryptedStreamWriter(upstream *shadowaead.Writer) *TLSEncryptedStreamWriter { + var writer TLSEncryptedStreamWriter + writer.upstream = upstream + writer.raw = upstream.Upstream().(io.Writer) + writer.buffer = upstream.Buffer() + writer.pipeIn, writer.pipeOut = io.Pipe() + go writer.loopOut() + return &writer +} + +func (w *TLSEncryptedStreamWriter) Write(p []byte) (n int, err error) { + return w.pipeOut.Write(p) +} + +func (w *TLSEncryptedStreamWriter) loopOut() { + data := w.buffer.Slice() + var err error + for { + _, err = io.ReadFull(w.pipeIn, data[:5]) + if err != nil { + break + } + recordType := data[0] + recordLen := int(binary.BigEndian.Uint16(data[3:5])) + + w.buffer.FullReset() + w.upstream.WriteChunk(w.buffer, data[:5]) + + if recordType != recordTypeApplicationData { + _, err = io.ReadFull(w.pipeIn, data[tlsEncryptedLengthChunkLength:tlsEncryptedLengthChunkLength+recordLen]) + if err != nil { + return + } + w.upstream.WriteChunk(w.buffer, data[tlsEncryptedLengthChunkLength:tlsEncryptedLengthChunkLength+recordLen]) + } else { + _, err = w.buffer.ReadFullFrom(w.pipeIn, recordLen) + if err != nil { + break + } + } + + _, err = w.raw.Write(w.buffer.Bytes()) + if err != nil { + break + } + } + w.pipeIn.CloseWithError(err) +} + +func (w *TLSEncryptedStreamWriter) Close() error { + return common.Close( + w.upstream, + w.pipeOut, + ) +} + +func (w *TLSEncryptedStreamWriter) Upstream() any { + return w.upstream +} diff --git a/shadowaead_2022/protocol.go b/shadowaead_2022/protocol.go index 4cddbdb..df715b7 100644 --- a/shadowaead_2022/protocol.go +++ b/shadowaead_2022/protocol.go @@ -39,6 +39,9 @@ const ( PacketNonceSize = 24 MaxPacketSize = 65535 RequestHeaderFixedChunkLength = 1 + 8 + 2 + + HeaderTypeClientEncrypted = 10 + HeaderTypeServerEncrypted = 11 ) var ( @@ -58,7 +61,7 @@ var List = []string{ "2022-blake3-chacha20-poly1305", } -func NewWithPassword(method string, password string) (shadowsocks.Method, error) { +func NewWithPassword(method string, password string, options ...MethodOption) (shadowsocks.Method, error) { var pskList [][]byte if password == "" { return nil, ErrMissingPSK @@ -72,10 +75,10 @@ func NewWithPassword(method string, password string) (shadowsocks.Method, error) } pskList[i] = kb } - return New(method, pskList) + return New(method, pskList, options...) } -func New(method string, pskList [][]byte) (shadowsocks.Method, error) { +func New(method string, pskList [][]byte, options ...MethodOption) (shadowsocks.Method, error) { m := &Method{ name: method, replayFilter: replay.NewSimple(60 * time.Second), @@ -134,6 +137,9 @@ func New(method string, pskList [][]byte) (shadowsocks.Method, error) { } m.pskList = pskList + for _, option := range options { + option(m) + } return m, nil } @@ -162,15 +168,16 @@ func aeadCipher(block func(key []byte) (cipher.Block, error), aead func(block ci } type Method struct { - name string - keySaltLength int - constructor func(key []byte) (cipher.AEAD, error) - blockConstructor func(key []byte) (cipher.Block, error) - udpCipher cipher.AEAD - udpBlockCipher cipher.Block - pskList [][]byte - pskHash []byte - replayFilter replay.Filter + name string + keySaltLength int + constructor func(key []byte) (cipher.AEAD, error) + blockConstructor func(key []byte) (cipher.Block, error) + udpCipher cipher.AEAD + udpBlockCipher cipher.Block + pskList [][]byte + pskHash []byte + replayFilter replay.Filter + encryptedProtocolExtension bool } func (m *Method) Name() string { @@ -203,8 +210,8 @@ type clientConn struct { net.Conn destination M.Socksaddr requestSalt []byte - reader *shadowaead.Reader - writer *shadowaead.Writer + reader io.Reader + writer io.Writer } func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) error { @@ -239,6 +246,13 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) } func (c *clientConn) writeRequest(payload []byte) error { + var headerType byte + if c.encryptedProtocolExtension && isTLSHandshake(payload) { + headerType = HeaderTypeClientEncrypted + } else { + headerType = HeaderTypeClient + } + salt := make([]byte, c.keySaltLength) common.Must1(io.ReadFull(rand.Reader, salt)) @@ -264,13 +278,21 @@ func (c *clientConn) writeRequest(payload []byte) error { var _fixedLengthBuffer [RequestHeaderFixedChunkLength]byte fixedLengthBuffer := buf.With(common.Dup(_fixedLengthBuffer[:])) - common.Must(fixedLengthBuffer.WriteByte(HeaderTypeClient)) + common.Must(fixedLengthBuffer.WriteByte(headerType)) common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint64(time.Now().Unix()))) var paddingLen int if len(payload) < MaxPaddingLength { paddingLen = mRand.Intn(MaxPaddingLength) + 1 } - variableLengthHeaderLen := M.SocksaddrSerializer.AddrPortLen(c.destination) + 2 + paddingLen + len(payload) + variableLengthHeaderLen := M.SocksaddrSerializer.AddrPortLen(c.destination) + 2 + paddingLen + var payloadLen int + switch headerType { + case HeaderTypeClient: + payloadLen = len(payload) + case HeaderTypeClientEncrypted: + payloadLen = readTLSChunkEnd(payload) + } + variableLengthHeaderLen += payloadLen common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint16(variableLengthHeaderLen))) writer.WriteChunk(header, fixedLengthBuffer.Slice()) common.KeepAlive(_fixedLengthBuffer) @@ -282,8 +304,8 @@ func (c *clientConn) writeRequest(payload []byte) error { if paddingLen > 0 { variableLengthBuffer.Extend(paddingLen) } - if len(payload) > 0 { - common.Must1(variableLengthBuffer.Write(payload)) + if payloadLen > 0 { + common.Must1(variableLengthBuffer.Write(payload[:payloadLen])) } writer.WriteChunk(header, variableLengthBuffer.Slice()) common.KeepAlive(_variableLengthBuffer) @@ -295,7 +317,18 @@ func (c *clientConn) writeRequest(payload []byte) error { } c.requestSalt = salt - c.writer = writer + if headerType == HeaderTypeClient { + c.writer = writer + } else if headerType == HeaderTypeClientEncrypted { + encryptedWriter := NewTLSEncryptedStreamWriter(writer) + if payloadLen < len(payload) { + _, err = encryptedWriter.Write(payload[payloadLen:]) + if err != nil { + return err + } + } + c.writer = encryptedWriter + } return nil } @@ -346,7 +379,7 @@ func (c *clientConn) readResponse() error { if err != nil { return err } - if headerType != HeaderTypeServer { + if headerType != HeaderTypeServer && headerType != HeaderTypeServerEncrypted { return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType) } @@ -373,6 +406,7 @@ func (c *clientConn) readResponse() error { } requestSalt.Release() common.KeepAlive(_requestSalt) + c.requestSalt = nil var length uint16 err = binary.Read(reader, binary.BigEndian, &length) @@ -384,10 +418,11 @@ func (c *clientConn) readResponse() error { if err != nil { return err } - - c.requestSalt = nil - c.reader = reader - + if headerType == HeaderTypeServer { + c.reader = reader + } else if headerType == HeaderTypeServerEncrypted { + c.reader = NewTLSEncryptedStreamReader(reader) + } return nil } @@ -402,7 +437,7 @@ func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) { if err = c.readResponse(); err != nil { return } - return c.reader.WriteTo(w) + return bufio.Copy(w, c.reader) } func (c *clientConn) Write(p []byte) (n int, err error) { @@ -420,13 +455,21 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) { if c.writer == nil { return bufio.ReadFrom0(c, r) } - return c.writer.ReadFrom(r) + return bufio.Copy(c.writer, r) } func (c *clientConn) Upstream() any { return c.Conn } +func (c *clientConn) Close() error { + return common.Close( + c.Conn, + c.reader, + c.writer, + ) +} + type clientPacketConn struct { *Method net.Conn diff --git a/shadowaead_2022/protocol_option.go b/shadowaead_2022/protocol_option.go new file mode 100644 index 0000000..46e8e0b --- /dev/null +++ b/shadowaead_2022/protocol_option.go @@ -0,0 +1,9 @@ +package shadowaead_2022 + +type MethodOption func(*Method) + +func MethodOptionEncryptedProtocolExtension() MethodOption { + return func(method *Method) { + method.encryptedProtocolExtension = true + } +} diff --git a/shadowaead_2022/service.go b/shadowaead_2022/service.go index 616a542..1ddddd3 100644 --- a/shadowaead_2022/service.go +++ b/shadowaead_2022/service.go @@ -153,7 +153,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M ) common.KeepAlive(requestKey) - err = reader.ReadChunk(header[s.keySaltLength:]) + err = reader.ReadExternalChunk(header[s.keySaltLength:]) if err != nil { return err } @@ -163,7 +163,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M return E.Cause(err, "read header") } - if headerType != HeaderTypeClient { + if headerType != HeaderTypeClient && headerType != HeaderTypeClientEncrypted { return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType) } @@ -213,15 +213,24 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M return ErrNoPadding } - metadata.Protocol = "shadowsocks" - metadata.Destination = destination - return s.handler.NewConnection(ctx, &serverConn{ + protocolConn := &serverConn{ Service: s, Conn: conn, uPSK: s.psk, - reader: reader, + headerType: headerType, requestSalt: requestSalt, - }, metadata) + } + + switch headerType { + case HeaderTypeClient: + protocolConn.reader = reader + case HeaderTypeClientEncrypted: + protocolConn.reader = NewTLSEncryptedStreamReader(reader) + } + + metadata.Protocol = "shadowsocks" + metadata.Destination = destination + return s.handler.NewConnection(ctx, protocolConn, metadata) } type serverConn struct { @@ -229,8 +238,9 @@ type serverConn struct { net.Conn uPSK []byte access sync.Mutex - reader *shadowaead.Reader - writer *shadowaead.Writer + headerType byte + reader io.Reader + writer io.Writer requestSalt []byte } @@ -259,20 +269,31 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) { salt.Release() common.KeepAlive(_salt) + var headerType byte + var payloadLen int + switch c.headerType { + case HeaderTypeClient: + headerType = HeaderTypeServer + payloadLen = len(payload) + case HeaderTypeClientEncrypted: + headerType = HeaderTypeServerEncrypted + payloadLen = readTLSChunkEnd(payload) + } + _headerFixedChunk := buf.StackNewSize(1 + 8 + c.keySaltLength + 2) headerFixedChunk := common.Dup(_headerFixedChunk) - common.Must(headerFixedChunk.WriteByte(HeaderTypeServer)) + common.Must(headerFixedChunk.WriteByte(headerType)) common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint64(time.Now().Unix()))) common.Must1(headerFixedChunk.Write(c.requestSalt)) - common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint16(len(payload)))) + common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint16(payloadLen))) writer.WriteChunk(header, headerFixedChunk.Slice()) headerFixedChunk.Release() common.KeepAlive(_headerFixedChunk) c.requestSalt = nil - if len(payload) > 0 { - writer.WriteChunk(header, payload) + if payloadLen > 0 { + writer.WriteChunk(header, payload[:payloadLen]) } err = writer.BufferedWriter(header.Len()).Flush() @@ -280,7 +301,20 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) { return } - c.writer = writer + switch headerType { + case HeaderTypeServer: + c.writer = writer + case HeaderTypeServerEncrypted: + encryptedWriter := NewTLSEncryptedStreamWriter(writer) + if payloadLen < len(payload) { + _, err = encryptedWriter.Write(payload[payloadLen:]) + if err != nil { + return + } + } + c.writer = encryptedWriter + } + n = len(payload) return } @@ -302,11 +336,19 @@ func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) { if c.writer == nil { return bufio.ReadFrom0(c, r) } - return c.writer.ReadFrom(r) + return bufio.Copy(c.writer, r) } func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) { - return c.reader.WriteTo(w) + return bufio.Copy(w, c.reader) +} + +func (c *serverConn) Close() error { + return common.Close( + c.Conn, + c.reader, + c.writer, + ) } func (c *serverConn) Upstream() any { diff --git a/shadowaead_2022/service_multi.go b/shadowaead_2022/service_multi.go index ee5add6..890413c 100644 --- a/shadowaead_2022/service_multi.go +++ b/shadowaead_2022/service_multi.go @@ -162,7 +162,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta MaxPacketSize, ) - err = reader.ReadChunk(requestHeader[s.keySaltLength+aes.BlockSize:]) + err = reader.ReadExternalChunk(requestHeader[s.keySaltLength+aes.BlockSize:]) if err != nil { return err } @@ -172,7 +172,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta return E.Cause(err, "read header") } - if headerType != HeaderTypeClient { + if headerType != HeaderTypeClient && headerType != HeaderTypeClientEncrypted { return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType) } @@ -222,15 +222,24 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta userCtx.Context = ctx userCtx.User = user - metadata.Protocol = "shadowsocks" - metadata.Destination = destination - return s.handler.NewConnection(&userCtx, &serverConn{ + protocolConn := &serverConn{ Service: s.Service, Conn: conn, uPSK: uPSK, - reader: reader, + headerType: headerType, requestSalt: requestSalt, - }, metadata) + } + + switch headerType { + case HeaderTypeClient: + protocolConn.reader = reader + case HeaderTypeClientEncrypted: + protocolConn.reader = NewTLSEncryptedStreamReader(reader) + } + + metadata.Protocol = "shadowsocks" + metadata.Destination = destination + return s.handler.NewConnection(&userCtx, protocolConn, metadata) } func (s *MultiService[U]) WriteIsThreadUnsafe() {