mirror of
https://github.com/SagerNet/sing-shadowsocks.git
synced 2025-04-04 12:27:39 +03:00
Draft: Encrypted Protocol Extension
This commit is contained in:
parent
cd8ec2833c
commit
e3a6eb8580
8 changed files with 341 additions and 53 deletions
2
go.mod
2
go.mod
|
@ -3,7 +3,7 @@ module github.com/sagernet/sing-shadowsocks
|
|||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/sagernet/sing v0.0.0-20220614091938-64835a637bdc
|
||||
github.com/sagernet/sing v0.0.0-20220614131337-ea019b365507
|
||||
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e
|
||||
lukechampine.com/blake3 v1.1.7
|
||||
)
|
||||
|
|
4
go.sum
4
go.sum
|
@ -1,8 +1,8 @@
|
|||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.0.12 h1:p9dKCg8i4gmOxtv35DvrYoWqYzQrvEVdjQ762Y0OqZE=
|
||||
github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
|
||||
github.com/sagernet/sing v0.0.0-20220614091938-64835a637bdc h1:AdNTzzSw6SCZI71GB+Am7cr+oUDUrBUaOi17FxDtNMw=
|
||||
github.com/sagernet/sing v0.0.0-20220614091938-64835a637bdc/go.mod h1:Bgwxr10oTxYlQ33MgsXW3GuS2w5St11qqk4DqzJOdU4=
|
||||
github.com/sagernet/sing v0.0.0-20220614131337-ea019b365507 h1:rMYMyB6N0ARFg0bwgG1Ahl+h0HCXO74yzT8PYvxOuPs=
|
||||
github.com/sagernet/sing v0.0.0-20220614131337-ea019b365507/go.mod h1:Bgwxr10oTxYlQ33MgsXW3GuS2w5St11qqk4DqzJOdU4=
|
||||
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e h1:T8NU3HyQ8ClP4SEE+KbFlg6n0NhuTsN4MyznaarGsZM=
|
||||
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/sys v0.0.0-20220610221304-9f5ed59c137d h1:Zu/JngovGLVi6t2J3nmAf3AoTDwuzw85YZ3b9o4yU7s=
|
||||
|
|
|
@ -197,6 +197,12 @@ func (r *Reader) Discard(n int) error {
|
|||
}
|
||||
}
|
||||
|
||||
func (r *Reader) Buffer() *buf.Buffer {
|
||||
buffer := buf.With(r.buffer)
|
||||
buffer.Resize(r.index, r.cached)
|
||||
return buffer
|
||||
}
|
||||
|
||||
func (r *Reader) Cached() int {
|
||||
return r.cached
|
||||
}
|
||||
|
@ -243,7 +249,7 @@ func (r *Reader) ReadWithLength(length uint16) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *Reader) ReadChunk(chunk []byte) error {
|
||||
func (r *Reader) ReadExternalChunk(chunk []byte) error {
|
||||
bb, err := r.cipher.Open(r.buffer[:0], r.nonce, chunk, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -254,6 +260,16 @@ func (r *Reader) ReadChunk(chunk []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *Reader) ReadChunk(buffer *buf.Buffer, chunk []byte) error {
|
||||
bb, err := r.cipher.Open(buffer.Index(buffer.Len()), r.nonce, chunk, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
increaseNonce(r.nonce)
|
||||
buffer.Extend(len(bb))
|
||||
return nil
|
||||
}
|
||||
|
||||
type Writer struct {
|
||||
upstream io.Writer
|
||||
cipher cipher.AEAD
|
||||
|
|
169
shadowaead_2022/encrypted_stream.go
Normal file
169
shadowaead_2022/encrypted_stream.go
Normal file
|
@ -0,0 +1,169 @@
|
|||
package shadowaead_2022
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
"github.com/sagernet/sing-shadowsocks/shadowaead"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
const (
|
||||
recordTypeHandshake = 22
|
||||
recordTypeApplicationData = 23
|
||||
|
||||
tlsVersion10 = 0x0301
|
||||
tlsVersion11 = 0x0302
|
||||
tlsVersion12 = 0x0303
|
||||
tlsVersion13 = 0x0304
|
||||
|
||||
tlsEncryptedLengthChunkLength = 5 + shadowaead.Overhead
|
||||
)
|
||||
|
||||
func isTLSHandshake(payload []byte) bool {
|
||||
if len(payload) < 5 {
|
||||
return false
|
||||
}
|
||||
if payload[0] != recordTypeHandshake {
|
||||
return false
|
||||
}
|
||||
tlsVersion := binary.BigEndian.Uint16(payload[1:])
|
||||
if tlsVersion < tlsVersion10 || tlsVersion > tlsVersion13 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func readTLSChunkEnd(payload []byte) int {
|
||||
pLen := len(payload)
|
||||
index := 0
|
||||
for index < pLen {
|
||||
if pLen-index < 5 {
|
||||
break
|
||||
}
|
||||
dataLen := binary.BigEndian.Uint16(payload[index+3 : index+5])
|
||||
nextIndex := index + 5 + int(dataLen)
|
||||
if nextIndex > pLen {
|
||||
return index
|
||||
}
|
||||
index = nextIndex
|
||||
}
|
||||
return index
|
||||
}
|
||||
|
||||
type TLSEncryptedStreamReader struct {
|
||||
upstream *shadowaead.Reader
|
||||
raw io.Reader
|
||||
buffer *buf.Buffer
|
||||
}
|
||||
|
||||
func NewTLSEncryptedStreamReader(upstream *shadowaead.Reader) *TLSEncryptedStreamReader {
|
||||
var reader TLSEncryptedStreamReader
|
||||
reader.upstream = upstream
|
||||
reader.raw = upstream.Upstream().(io.Reader)
|
||||
reader.buffer = upstream.Buffer()
|
||||
return &reader
|
||||
}
|
||||
|
||||
func (r *TLSEncryptedStreamReader) Read(p []byte) (n int, err error) {
|
||||
if !r.buffer.IsEmpty() {
|
||||
return r.buffer.Read(p)
|
||||
}
|
||||
data := r.buffer.Slice()
|
||||
_, err = io.ReadFull(r.raw, data[:tlsEncryptedLengthChunkLength])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r.buffer.FullReset()
|
||||
err = r.upstream.ReadChunk(r.buffer, data[:tlsEncryptedLengthChunkLength])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
recordType := data[0]
|
||||
recordLen := int(binary.BigEndian.Uint16(data[3:5]))
|
||||
if recordType == recordTypeApplicationData {
|
||||
_, err = r.buffer.ReadFullFrom(r.raw, recordLen)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
_, err = io.ReadFull(r.raw, data[5:5+recordLen+shadowaead.Overhead])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = r.upstream.ReadChunk(r.buffer, data[5:5+recordLen+shadowaead.Overhead])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return r.buffer.Read(p)
|
||||
}
|
||||
|
||||
type TLSEncryptedStreamWriter struct {
|
||||
upstream *shadowaead.Writer
|
||||
raw io.Writer
|
||||
buffer *buf.Buffer
|
||||
pipeIn *io.PipeReader
|
||||
pipeOut *io.PipeWriter
|
||||
}
|
||||
|
||||
func NewTLSEncryptedStreamWriter(upstream *shadowaead.Writer) *TLSEncryptedStreamWriter {
|
||||
var writer TLSEncryptedStreamWriter
|
||||
writer.upstream = upstream
|
||||
writer.raw = upstream.Upstream().(io.Writer)
|
||||
writer.buffer = upstream.Buffer()
|
||||
writer.pipeIn, writer.pipeOut = io.Pipe()
|
||||
go writer.loopOut()
|
||||
return &writer
|
||||
}
|
||||
|
||||
func (w *TLSEncryptedStreamWriter) Write(p []byte) (n int, err error) {
|
||||
return w.pipeOut.Write(p)
|
||||
}
|
||||
|
||||
func (w *TLSEncryptedStreamWriter) loopOut() {
|
||||
data := w.buffer.Slice()
|
||||
var err error
|
||||
for {
|
||||
_, err = io.ReadFull(w.pipeIn, data[:5])
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
recordType := data[0]
|
||||
recordLen := int(binary.BigEndian.Uint16(data[3:5]))
|
||||
|
||||
w.buffer.FullReset()
|
||||
w.upstream.WriteChunk(w.buffer, data[:5])
|
||||
|
||||
if recordType != recordTypeApplicationData {
|
||||
_, err = io.ReadFull(w.pipeIn, data[tlsEncryptedLengthChunkLength:tlsEncryptedLengthChunkLength+recordLen])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.upstream.WriteChunk(w.buffer, data[tlsEncryptedLengthChunkLength:tlsEncryptedLengthChunkLength+recordLen])
|
||||
} else {
|
||||
_, err = w.buffer.ReadFullFrom(w.pipeIn, recordLen)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
_, err = w.raw.Write(w.buffer.Bytes())
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
w.pipeIn.CloseWithError(err)
|
||||
}
|
||||
|
||||
func (w *TLSEncryptedStreamWriter) Close() error {
|
||||
return common.Close(
|
||||
w.upstream,
|
||||
w.pipeOut,
|
||||
)
|
||||
}
|
||||
|
||||
func (w *TLSEncryptedStreamWriter) Upstream() any {
|
||||
return w.upstream
|
||||
}
|
|
@ -39,6 +39,9 @@ const (
|
|||
PacketNonceSize = 24
|
||||
MaxPacketSize = 65535
|
||||
RequestHeaderFixedChunkLength = 1 + 8 + 2
|
||||
|
||||
HeaderTypeClientEncrypted = 10
|
||||
HeaderTypeServerEncrypted = 11
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -58,7 +61,7 @@ var List = []string{
|
|||
"2022-blake3-chacha20-poly1305",
|
||||
}
|
||||
|
||||
func NewWithPassword(method string, password string) (shadowsocks.Method, error) {
|
||||
func NewWithPassword(method string, password string, options ...MethodOption) (shadowsocks.Method, error) {
|
||||
var pskList [][]byte
|
||||
if password == "" {
|
||||
return nil, ErrMissingPSK
|
||||
|
@ -72,10 +75,10 @@ func NewWithPassword(method string, password string) (shadowsocks.Method, error)
|
|||
}
|
||||
pskList[i] = kb
|
||||
}
|
||||
return New(method, pskList)
|
||||
return New(method, pskList, options...)
|
||||
}
|
||||
|
||||
func New(method string, pskList [][]byte) (shadowsocks.Method, error) {
|
||||
func New(method string, pskList [][]byte, options ...MethodOption) (shadowsocks.Method, error) {
|
||||
m := &Method{
|
||||
name: method,
|
||||
replayFilter: replay.NewSimple(60 * time.Second),
|
||||
|
@ -134,6 +137,9 @@ func New(method string, pskList [][]byte) (shadowsocks.Method, error) {
|
|||
}
|
||||
|
||||
m.pskList = pskList
|
||||
for _, option := range options {
|
||||
option(m)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
|
@ -171,6 +177,7 @@ type Method struct {
|
|||
pskList [][]byte
|
||||
pskHash []byte
|
||||
replayFilter replay.Filter
|
||||
encryptedProtocolExtension bool
|
||||
}
|
||||
|
||||
func (m *Method) Name() string {
|
||||
|
@ -203,8 +210,8 @@ type clientConn struct {
|
|||
net.Conn
|
||||
destination M.Socksaddr
|
||||
requestSalt []byte
|
||||
reader *shadowaead.Reader
|
||||
writer *shadowaead.Writer
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
}
|
||||
|
||||
func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) error {
|
||||
|
@ -239,6 +246,13 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte)
|
|||
}
|
||||
|
||||
func (c *clientConn) writeRequest(payload []byte) error {
|
||||
var headerType byte
|
||||
if c.encryptedProtocolExtension && isTLSHandshake(payload) {
|
||||
headerType = HeaderTypeClientEncrypted
|
||||
} else {
|
||||
headerType = HeaderTypeClient
|
||||
}
|
||||
|
||||
salt := make([]byte, c.keySaltLength)
|
||||
common.Must1(io.ReadFull(rand.Reader, salt))
|
||||
|
||||
|
@ -264,13 +278,21 @@ func (c *clientConn) writeRequest(payload []byte) error {
|
|||
|
||||
var _fixedLengthBuffer [RequestHeaderFixedChunkLength]byte
|
||||
fixedLengthBuffer := buf.With(common.Dup(_fixedLengthBuffer[:]))
|
||||
common.Must(fixedLengthBuffer.WriteByte(HeaderTypeClient))
|
||||
common.Must(fixedLengthBuffer.WriteByte(headerType))
|
||||
common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint64(time.Now().Unix())))
|
||||
var paddingLen int
|
||||
if len(payload) < MaxPaddingLength {
|
||||
paddingLen = mRand.Intn(MaxPaddingLength) + 1
|
||||
}
|
||||
variableLengthHeaderLen := M.SocksaddrSerializer.AddrPortLen(c.destination) + 2 + paddingLen + len(payload)
|
||||
variableLengthHeaderLen := M.SocksaddrSerializer.AddrPortLen(c.destination) + 2 + paddingLen
|
||||
var payloadLen int
|
||||
switch headerType {
|
||||
case HeaderTypeClient:
|
||||
payloadLen = len(payload)
|
||||
case HeaderTypeClientEncrypted:
|
||||
payloadLen = readTLSChunkEnd(payload)
|
||||
}
|
||||
variableLengthHeaderLen += payloadLen
|
||||
common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint16(variableLengthHeaderLen)))
|
||||
writer.WriteChunk(header, fixedLengthBuffer.Slice())
|
||||
common.KeepAlive(_fixedLengthBuffer)
|
||||
|
@ -282,8 +304,8 @@ func (c *clientConn) writeRequest(payload []byte) error {
|
|||
if paddingLen > 0 {
|
||||
variableLengthBuffer.Extend(paddingLen)
|
||||
}
|
||||
if len(payload) > 0 {
|
||||
common.Must1(variableLengthBuffer.Write(payload))
|
||||
if payloadLen > 0 {
|
||||
common.Must1(variableLengthBuffer.Write(payload[:payloadLen]))
|
||||
}
|
||||
writer.WriteChunk(header, variableLengthBuffer.Slice())
|
||||
common.KeepAlive(_variableLengthBuffer)
|
||||
|
@ -295,7 +317,18 @@ func (c *clientConn) writeRequest(payload []byte) error {
|
|||
}
|
||||
|
||||
c.requestSalt = salt
|
||||
if headerType == HeaderTypeClient {
|
||||
c.writer = writer
|
||||
} else if headerType == HeaderTypeClientEncrypted {
|
||||
encryptedWriter := NewTLSEncryptedStreamWriter(writer)
|
||||
if payloadLen < len(payload) {
|
||||
_, err = encryptedWriter.Write(payload[payloadLen:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
c.writer = encryptedWriter
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -346,7 +379,7 @@ func (c *clientConn) readResponse() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if headerType != HeaderTypeServer {
|
||||
if headerType != HeaderTypeServer && headerType != HeaderTypeServerEncrypted {
|
||||
return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType)
|
||||
}
|
||||
|
||||
|
@ -373,6 +406,7 @@ func (c *clientConn) readResponse() error {
|
|||
}
|
||||
requestSalt.Release()
|
||||
common.KeepAlive(_requestSalt)
|
||||
c.requestSalt = nil
|
||||
|
||||
var length uint16
|
||||
err = binary.Read(reader, binary.BigEndian, &length)
|
||||
|
@ -384,10 +418,11 @@ func (c *clientConn) readResponse() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.requestSalt = nil
|
||||
if headerType == HeaderTypeServer {
|
||||
c.reader = reader
|
||||
|
||||
} else if headerType == HeaderTypeServerEncrypted {
|
||||
c.reader = NewTLSEncryptedStreamReader(reader)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -402,7 +437,7 @@ func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
|
|||
if err = c.readResponse(); err != nil {
|
||||
return
|
||||
}
|
||||
return c.reader.WriteTo(w)
|
||||
return bufio.Copy(w, c.reader)
|
||||
}
|
||||
|
||||
func (c *clientConn) Write(p []byte) (n int, err error) {
|
||||
|
@ -420,13 +455,21 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
|
|||
if c.writer == nil {
|
||||
return bufio.ReadFrom0(c, r)
|
||||
}
|
||||
return c.writer.ReadFrom(r)
|
||||
return bufio.Copy(c.writer, r)
|
||||
}
|
||||
|
||||
func (c *clientConn) Upstream() any {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (c *clientConn) Close() error {
|
||||
return common.Close(
|
||||
c.Conn,
|
||||
c.reader,
|
||||
c.writer,
|
||||
)
|
||||
}
|
||||
|
||||
type clientPacketConn struct {
|
||||
*Method
|
||||
net.Conn
|
||||
|
|
9
shadowaead_2022/protocol_option.go
Normal file
9
shadowaead_2022/protocol_option.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
package shadowaead_2022
|
||||
|
||||
type MethodOption func(*Method)
|
||||
|
||||
func MethodOptionEncryptedProtocolExtension() MethodOption {
|
||||
return func(method *Method) {
|
||||
method.encryptedProtocolExtension = true
|
||||
}
|
||||
}
|
|
@ -153,7 +153,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
|
|||
)
|
||||
common.KeepAlive(requestKey)
|
||||
|
||||
err = reader.ReadChunk(header[s.keySaltLength:])
|
||||
err = reader.ReadExternalChunk(header[s.keySaltLength:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -163,7 +163,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
|
|||
return E.Cause(err, "read header")
|
||||
}
|
||||
|
||||
if headerType != HeaderTypeClient {
|
||||
if headerType != HeaderTypeClient && headerType != HeaderTypeClientEncrypted {
|
||||
return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
|
||||
}
|
||||
|
||||
|
@ -213,15 +213,24 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
|
|||
return ErrNoPadding
|
||||
}
|
||||
|
||||
metadata.Protocol = "shadowsocks"
|
||||
metadata.Destination = destination
|
||||
return s.handler.NewConnection(ctx, &serverConn{
|
||||
protocolConn := &serverConn{
|
||||
Service: s,
|
||||
Conn: conn,
|
||||
uPSK: s.psk,
|
||||
reader: reader,
|
||||
headerType: headerType,
|
||||
requestSalt: requestSalt,
|
||||
}, metadata)
|
||||
}
|
||||
|
||||
switch headerType {
|
||||
case HeaderTypeClient:
|
||||
protocolConn.reader = reader
|
||||
case HeaderTypeClientEncrypted:
|
||||
protocolConn.reader = NewTLSEncryptedStreamReader(reader)
|
||||
}
|
||||
|
||||
metadata.Protocol = "shadowsocks"
|
||||
metadata.Destination = destination
|
||||
return s.handler.NewConnection(ctx, protocolConn, metadata)
|
||||
}
|
||||
|
||||
type serverConn struct {
|
||||
|
@ -229,8 +238,9 @@ type serverConn struct {
|
|||
net.Conn
|
||||
uPSK []byte
|
||||
access sync.Mutex
|
||||
reader *shadowaead.Reader
|
||||
writer *shadowaead.Writer
|
||||
headerType byte
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
requestSalt []byte
|
||||
}
|
||||
|
||||
|
@ -259,20 +269,31 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
|
|||
salt.Release()
|
||||
common.KeepAlive(_salt)
|
||||
|
||||
var headerType byte
|
||||
var payloadLen int
|
||||
switch c.headerType {
|
||||
case HeaderTypeClient:
|
||||
headerType = HeaderTypeServer
|
||||
payloadLen = len(payload)
|
||||
case HeaderTypeClientEncrypted:
|
||||
headerType = HeaderTypeServerEncrypted
|
||||
payloadLen = readTLSChunkEnd(payload)
|
||||
}
|
||||
|
||||
_headerFixedChunk := buf.StackNewSize(1 + 8 + c.keySaltLength + 2)
|
||||
headerFixedChunk := common.Dup(_headerFixedChunk)
|
||||
common.Must(headerFixedChunk.WriteByte(HeaderTypeServer))
|
||||
common.Must(headerFixedChunk.WriteByte(headerType))
|
||||
common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint64(time.Now().Unix())))
|
||||
common.Must1(headerFixedChunk.Write(c.requestSalt))
|
||||
common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint16(len(payload))))
|
||||
common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint16(payloadLen)))
|
||||
|
||||
writer.WriteChunk(header, headerFixedChunk.Slice())
|
||||
headerFixedChunk.Release()
|
||||
common.KeepAlive(_headerFixedChunk)
|
||||
c.requestSalt = nil
|
||||
|
||||
if len(payload) > 0 {
|
||||
writer.WriteChunk(header, payload)
|
||||
if payloadLen > 0 {
|
||||
writer.WriteChunk(header, payload[:payloadLen])
|
||||
}
|
||||
|
||||
err = writer.BufferedWriter(header.Len()).Flush()
|
||||
|
@ -280,7 +301,20 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
switch headerType {
|
||||
case HeaderTypeServer:
|
||||
c.writer = writer
|
||||
case HeaderTypeServerEncrypted:
|
||||
encryptedWriter := NewTLSEncryptedStreamWriter(writer)
|
||||
if payloadLen < len(payload) {
|
||||
_, err = encryptedWriter.Write(payload[payloadLen:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
c.writer = encryptedWriter
|
||||
}
|
||||
|
||||
n = len(payload)
|
||||
return
|
||||
}
|
||||
|
@ -302,11 +336,19 @@ func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
|
|||
if c.writer == nil {
|
||||
return bufio.ReadFrom0(c, r)
|
||||
}
|
||||
return c.writer.ReadFrom(r)
|
||||
return bufio.Copy(c.writer, r)
|
||||
}
|
||||
|
||||
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
return c.reader.WriteTo(w)
|
||||
return bufio.Copy(w, c.reader)
|
||||
}
|
||||
|
||||
func (c *serverConn) Close() error {
|
||||
return common.Close(
|
||||
c.Conn,
|
||||
c.reader,
|
||||
c.writer,
|
||||
)
|
||||
}
|
||||
|
||||
func (c *serverConn) Upstream() any {
|
||||
|
|
|
@ -162,7 +162,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
|
|||
MaxPacketSize,
|
||||
)
|
||||
|
||||
err = reader.ReadChunk(requestHeader[s.keySaltLength+aes.BlockSize:])
|
||||
err = reader.ReadExternalChunk(requestHeader[s.keySaltLength+aes.BlockSize:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -172,7 +172,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
|
|||
return E.Cause(err, "read header")
|
||||
}
|
||||
|
||||
if headerType != HeaderTypeClient {
|
||||
if headerType != HeaderTypeClient && headerType != HeaderTypeClientEncrypted {
|
||||
return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
|
||||
}
|
||||
|
||||
|
@ -222,15 +222,24 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
|
|||
userCtx.Context = ctx
|
||||
userCtx.User = user
|
||||
|
||||
metadata.Protocol = "shadowsocks"
|
||||
metadata.Destination = destination
|
||||
return s.handler.NewConnection(&userCtx, &serverConn{
|
||||
protocolConn := &serverConn{
|
||||
Service: s.Service,
|
||||
Conn: conn,
|
||||
uPSK: uPSK,
|
||||
reader: reader,
|
||||
headerType: headerType,
|
||||
requestSalt: requestSalt,
|
||||
}, metadata)
|
||||
}
|
||||
|
||||
switch headerType {
|
||||
case HeaderTypeClient:
|
||||
protocolConn.reader = reader
|
||||
case HeaderTypeClientEncrypted:
|
||||
protocolConn.reader = NewTLSEncryptedStreamReader(reader)
|
||||
}
|
||||
|
||||
metadata.Protocol = "shadowsocks"
|
||||
metadata.Destination = destination
|
||||
return s.handler.NewConnection(&userCtx, protocolConn, metadata)
|
||||
}
|
||||
|
||||
func (s *MultiService[U]) WriteIsThreadUnsafe() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue