diff --git a/go.mod b/go.mod index b90f4d2..8a0062e 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/sagernet/sing-shadowsocks go 1.18 require ( - github.com/sagernet/sing v0.0.0-20220811135544-169983a8d773 + github.com/sagernet/sing v0.0.0-20220812082120-05f9836bff8f golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d lukechampine.com/blake3 v1.1.7 ) diff --git a/go.sum b/go.sum index 949a1f3..e8a5e14 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,8 @@ github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.12 h1:p9dKCg8i4gmOxtv35DvrYoWqYzQrvEVdjQ762Y0OqZE= github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= -github.com/sagernet/sing v0.0.0-20220811135544-169983a8d773 h1:n88c8oBC6GWbEW2+F4HkVAji+puSW3GYGrbnXmuVYsw= -github.com/sagernet/sing v0.0.0-20220811135544-169983a8d773/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY= +github.com/sagernet/sing v0.0.0-20220812082120-05f9836bff8f h1:ekLjKIYjtkZNRN1c1IoNcpAsVZNKtO+Qe5cuHOwX0EI= +github.com/sagernet/sing v0.0.0-20220812082120-05f9836bff8f/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/sys v0.0.0-20220731174439-a90be440212d h1:Sv5ogFZatcgIMMtBSTTAgMYsicp25MXBubjXNDKwm80= diff --git a/shadowaead/aead.go b/shadowaead/aead.go index e29789f..27c636b 100644 --- a/shadowaead/aead.go +++ b/shadowaead/aead.go @@ -353,6 +353,40 @@ func (w *Writer) Write(p []byte) (n int, err error) { return } +func (w *Writer) WriteVectorised(buffers []*buf.Buffer) error { + defer buf.ReleaseMulti(buffers) + var index int + var err error + for _, buffer := range buffers { + pLen := buffer.Len() + if pLen > w.maxPacketSize { + _, err = w.Write(buffer.Bytes()) + if err != nil { + return err + } + } else { + if cap(w.buffer) < index+PacketLengthBufferSize+pLen+2*Overhead { + _, err = w.upstream.Write(w.buffer[:index]) + index = 0 + if err != nil { + return err + } + } + binary.BigEndian.PutUint16(w.buffer[index:index+PacketLengthBufferSize], uint16(pLen)) + w.cipher.Seal(w.buffer[index:index], w.nonce, w.buffer[index:index+PacketLengthBufferSize], nil) + increaseNonce(w.nonce) + offset := index + Overhead + PacketLengthBufferSize + w.cipher.Seal(w.buffer[offset:offset], w.nonce, buffer.Bytes(), nil) + increaseNonce(w.nonce) + index = offset + pLen + Overhead + } + } + if index > 0 { + _, err = w.upstream.Write(w.buffer[:index]) + } + return err +} + func (w *Writer) Buffer() *buf.Buffer { return buf.With(w.buffer) } diff --git a/shadowaead_2022/protocol.go b/shadowaead_2022/protocol.go index 2de36dd..4f7c5a7 100644 --- a/shadowaead_2022/protocol.go +++ b/shadowaead_2022/protocol.go @@ -41,8 +41,8 @@ const ( RequestHeaderFixedChunkLength = 1 + 8 + 2 PacketMinimalHeaderSize = 30 - HeaderTypeClientEncrypted = 10 - HeaderTypeServerEncrypted = 11 + // HeaderTypeClientEncrypted = 10 + // HeaderTypeServerEncrypted = 11 ) var ( @@ -223,7 +223,7 @@ type clientConn struct { destination M.Socksaddr requestSalt []byte reader io.Reader - writer io.Writer + writer *shadowaead.Writer } func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) error { @@ -259,11 +259,11 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) func (c *clientConn) writeRequest(payload []byte) error { var headerType byte - if c.encryptedProtocolExtension && isTLSHandshake(payload) { - headerType = HeaderTypeClientEncrypted - } else { - headerType = HeaderTypeClient - } + //if c.encryptedProtocolExtension && isTLSHandshake(payload) { + // headerType = HeaderTypeClientEncrypted + //} else { + headerType = HeaderTypeClient + //} salt := make([]byte, c.keySaltLength) common.Must1(io.ReadFull(rand.Reader, salt)) @@ -301,8 +301,8 @@ func (c *clientConn) writeRequest(payload []byte) error { switch headerType { case HeaderTypeClient: payloadLen = len(payload) - case HeaderTypeClientEncrypted: - payloadLen = readTLSChunkEnd(payload) + // case HeaderTypeClientEncrypted: + // payloadLen = readTLSChunkEnd(payload) } variableLengthHeaderLen += payloadLen common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint16(variableLengthHeaderLen))) @@ -331,7 +331,7 @@ func (c *clientConn) writeRequest(payload []byte) error { c.requestSalt = salt if headerType == HeaderTypeClient { c.writer = writer - } else if headerType == HeaderTypeClientEncrypted { + } /* else if headerType == HeaderTypeClientEncrypted { encryptedWriter := NewTLSEncryptedStreamWriter(writer) if payloadLen < len(payload) { _, err = encryptedWriter.Write(payload[payloadLen:]) @@ -340,7 +340,7 @@ func (c *clientConn) writeRequest(payload []byte) error { } } c.writer = encryptedWriter - } + }*/ return nil } @@ -384,7 +384,7 @@ func (c *clientConn) readResponse() error { if err != nil { return err } - if headerType != HeaderTypeServer && headerType != HeaderTypeServerEncrypted { + if headerType != HeaderTypeServer /* && headerType != HeaderTypeServerEncrypted*/ { return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType) } @@ -425,9 +425,9 @@ func (c *clientConn) readResponse() error { } if headerType == HeaderTypeServer { c.reader = reader - } else if headerType == HeaderTypeServerEncrypted { + } /*else if headerType == HeaderTypeServerEncrypted { c.reader = NewTLSEncryptedStreamReader(reader) - } + }*/ return nil } @@ -456,6 +456,21 @@ func (c *clientConn) Write(p []byte) (n int, err error) { return c.writer.Write(p) } +var _ N.VectorisedWriter = (*clientConn)(nil) + +func (c *clientConn) WriteVectorised(buffers []*buf.Buffer) error { + if c.writer != nil { + return c.writer.WriteVectorised(buffers) + } + err := c.writeRequest(buffers[0].Bytes()) + if err != nil { + buf.ReleaseMulti(buffers) + return err + } + buffers[0].Release() + return c.writer.WriteVectorised(buffers[1:]) +} + func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) { if c.writer == nil { return bufio.ReadFrom0(c, r) diff --git a/shadowaead_2022/service.go b/shadowaead_2022/service.go index 245f2b9..d41a4b6 100644 --- a/shadowaead_2022/service.go +++ b/shadowaead_2022/service.go @@ -163,7 +163,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M return E.Cause(err, "read header") } - if headerType != HeaderTypeClient && headerType != HeaderTypeClientEncrypted { + if headerType != HeaderTypeClient /* && headerType != HeaderTypeClientEncrypted */ { return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType) } @@ -224,8 +224,8 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M switch headerType { case HeaderTypeClient: protocolConn.reader = reader - case HeaderTypeClientEncrypted: - protocolConn.reader = NewTLSEncryptedStreamReader(reader) + // case HeaderTypeClientEncrypted: + // protocolConn.reader = NewTLSEncryptedStreamReader(reader) } metadata.Protocol = "shadowsocks" @@ -240,7 +240,7 @@ type serverConn struct { access sync.Mutex headerType byte reader io.Reader - writer io.Writer + writer *shadowaead.Writer requestSalt []byte } @@ -275,9 +275,9 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) { case HeaderTypeClient: headerType = HeaderTypeServer payloadLen = len(payload) - case HeaderTypeClientEncrypted: - headerType = HeaderTypeServerEncrypted - payloadLen = readTLSChunkEnd(payload) + // case HeaderTypeClientEncrypted: + // headerType = HeaderTypeServerEncrypted + // payloadLen = readTLSChunkEnd(payload) } _headerFixedChunk := buf.StackNewSize(1 + 8 + c.keySaltLength + 2) @@ -304,15 +304,15 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) { switch headerType { case HeaderTypeServer: c.writer = writer - case HeaderTypeServerEncrypted: - encryptedWriter := NewTLSEncryptedStreamWriter(writer) - if payloadLen < len(payload) { - _, err = encryptedWriter.Write(payload[payloadLen:]) - if err != nil { - return - } - } - c.writer = encryptedWriter + // case HeaderTypeServerEncrypted: + // encryptedWriter := NewTLSEncryptedStreamWriter(writer) + // if payloadLen < len(payload) { + // _, err = encryptedWriter.Write(payload[payloadLen:]) + // if err != nil { + // return + // } + // } + // c.writer = encryptedWriter } n = len(payload) @@ -336,6 +336,25 @@ func (c *serverConn) Write(p []byte) (n int, err error) { return c.writeResponse(p) } +func (c *serverConn) WriteVectorised(buffers []*buf.Buffer) error { + if c.writer != nil { + return c.writer.WriteVectorised(buffers) + } + c.access.Lock() + if c.writer != nil { + c.access.Unlock() + return c.writer.WriteVectorised(buffers) + } + defer c.access.Unlock() + _, err := c.writeResponse(buffers[0].Bytes()) + if err != nil { + buf.ReleaseMulti(buffers) + return err + } + buffers[0].Release() + return c.writer.WriteVectorised(buffers[1:]) +} + func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) { if c.writer == nil { return bufio.ReadFrom0(c, r) diff --git a/shadowaead_2022/service_multi.go b/shadowaead_2022/service_multi.go index 222dbda..14b5591 100644 --- a/shadowaead_2022/service_multi.go +++ b/shadowaead_2022/service_multi.go @@ -183,7 +183,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta return E.Cause(err, "read header") } - if headerType != HeaderTypeClient && headerType != HeaderTypeClientEncrypted { + if headerType != HeaderTypeClient /*&& headerType != HeaderTypeClientEncrypted*/ { return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType) } @@ -240,8 +240,8 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta switch headerType { case HeaderTypeClient: protocolConn.reader = reader - case HeaderTypeClientEncrypted: - protocolConn.reader = NewTLSEncryptedStreamReader(reader) + // case HeaderTypeClientEncrypted: + // protocolConn.reader = NewTLSEncryptedStreamReader(reader) } metadata.Protocol = "shadowsocks"