Add vectorised interface

This commit is contained in:
世界 2022-08-12 16:27:14 +08:00
parent 7e47fd1a99
commit 484a11603b
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
6 changed files with 105 additions and 37 deletions

2
go.mod
View file

@ -3,7 +3,7 @@ module github.com/sagernet/sing-shadowsocks
go 1.18
require (
github.com/sagernet/sing v0.0.0-20220811135544-169983a8d773
github.com/sagernet/sing v0.0.0-20220812082120-05f9836bff8f
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d
lukechampine.com/blake3 v1.1.7
)

4
go.sum
View file

@ -1,8 +1,8 @@
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.0.12 h1:p9dKCg8i4gmOxtv35DvrYoWqYzQrvEVdjQ762Y0OqZE=
github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
github.com/sagernet/sing v0.0.0-20220811135544-169983a8d773 h1:n88c8oBC6GWbEW2+F4HkVAji+puSW3GYGrbnXmuVYsw=
github.com/sagernet/sing v0.0.0-20220811135544-169983a8d773/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY=
github.com/sagernet/sing v0.0.0-20220812082120-05f9836bff8f h1:ekLjKIYjtkZNRN1c1IoNcpAsVZNKtO+Qe5cuHOwX0EI=
github.com/sagernet/sing v0.0.0-20220812082120-05f9836bff8f/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/sys v0.0.0-20220731174439-a90be440212d h1:Sv5ogFZatcgIMMtBSTTAgMYsicp25MXBubjXNDKwm80=

View file

@ -353,6 +353,40 @@ func (w *Writer) Write(p []byte) (n int, err error) {
return
}
func (w *Writer) WriteVectorised(buffers []*buf.Buffer) error {
defer buf.ReleaseMulti(buffers)
var index int
var err error
for _, buffer := range buffers {
pLen := buffer.Len()
if pLen > w.maxPacketSize {
_, err = w.Write(buffer.Bytes())
if err != nil {
return err
}
} else {
if cap(w.buffer) < index+PacketLengthBufferSize+pLen+2*Overhead {
_, err = w.upstream.Write(w.buffer[:index])
index = 0
if err != nil {
return err
}
}
binary.BigEndian.PutUint16(w.buffer[index:index+PacketLengthBufferSize], uint16(pLen))
w.cipher.Seal(w.buffer[index:index], w.nonce, w.buffer[index:index+PacketLengthBufferSize], nil)
increaseNonce(w.nonce)
offset := index + Overhead + PacketLengthBufferSize
w.cipher.Seal(w.buffer[offset:offset], w.nonce, buffer.Bytes(), nil)
increaseNonce(w.nonce)
index = offset + pLen + Overhead
}
}
if index > 0 {
_, err = w.upstream.Write(w.buffer[:index])
}
return err
}
func (w *Writer) Buffer() *buf.Buffer {
return buf.With(w.buffer)
}

View file

@ -41,8 +41,8 @@ const (
RequestHeaderFixedChunkLength = 1 + 8 + 2
PacketMinimalHeaderSize = 30
HeaderTypeClientEncrypted = 10
HeaderTypeServerEncrypted = 11
// HeaderTypeClientEncrypted = 10
// HeaderTypeServerEncrypted = 11
)
var (
@ -223,7 +223,7 @@ type clientConn struct {
destination M.Socksaddr
requestSalt []byte
reader io.Reader
writer io.Writer
writer *shadowaead.Writer
}
func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) error {
@ -259,11 +259,11 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte)
func (c *clientConn) writeRequest(payload []byte) error {
var headerType byte
if c.encryptedProtocolExtension && isTLSHandshake(payload) {
headerType = HeaderTypeClientEncrypted
} else {
headerType = HeaderTypeClient
}
//if c.encryptedProtocolExtension && isTLSHandshake(payload) {
// headerType = HeaderTypeClientEncrypted
//} else {
headerType = HeaderTypeClient
//}
salt := make([]byte, c.keySaltLength)
common.Must1(io.ReadFull(rand.Reader, salt))
@ -301,8 +301,8 @@ func (c *clientConn) writeRequest(payload []byte) error {
switch headerType {
case HeaderTypeClient:
payloadLen = len(payload)
case HeaderTypeClientEncrypted:
payloadLen = readTLSChunkEnd(payload)
// case HeaderTypeClientEncrypted:
// payloadLen = readTLSChunkEnd(payload)
}
variableLengthHeaderLen += payloadLen
common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint16(variableLengthHeaderLen)))
@ -331,7 +331,7 @@ func (c *clientConn) writeRequest(payload []byte) error {
c.requestSalt = salt
if headerType == HeaderTypeClient {
c.writer = writer
} else if headerType == HeaderTypeClientEncrypted {
} /* else if headerType == HeaderTypeClientEncrypted {
encryptedWriter := NewTLSEncryptedStreamWriter(writer)
if payloadLen < len(payload) {
_, err = encryptedWriter.Write(payload[payloadLen:])
@ -340,7 +340,7 @@ func (c *clientConn) writeRequest(payload []byte) error {
}
}
c.writer = encryptedWriter
}
}*/
return nil
}
@ -384,7 +384,7 @@ func (c *clientConn) readResponse() error {
if err != nil {
return err
}
if headerType != HeaderTypeServer && headerType != HeaderTypeServerEncrypted {
if headerType != HeaderTypeServer /* && headerType != HeaderTypeServerEncrypted*/ {
return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType)
}
@ -425,9 +425,9 @@ func (c *clientConn) readResponse() error {
}
if headerType == HeaderTypeServer {
c.reader = reader
} else if headerType == HeaderTypeServerEncrypted {
} /*else if headerType == HeaderTypeServerEncrypted {
c.reader = NewTLSEncryptedStreamReader(reader)
}
}*/
return nil
}
@ -456,6 +456,21 @@ func (c *clientConn) Write(p []byte) (n int, err error) {
return c.writer.Write(p)
}
var _ N.VectorisedWriter = (*clientConn)(nil)
func (c *clientConn) WriteVectorised(buffers []*buf.Buffer) error {
if c.writer != nil {
return c.writer.WriteVectorised(buffers)
}
err := c.writeRequest(buffers[0].Bytes())
if err != nil {
buf.ReleaseMulti(buffers)
return err
}
buffers[0].Release()
return c.writer.WriteVectorised(buffers[1:])
}
func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writer == nil {
return bufio.ReadFrom0(c, r)

View file

@ -163,7 +163,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
return E.Cause(err, "read header")
}
if headerType != HeaderTypeClient && headerType != HeaderTypeClientEncrypted {
if headerType != HeaderTypeClient /* && headerType != HeaderTypeClientEncrypted */ {
return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
}
@ -224,8 +224,8 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
switch headerType {
case HeaderTypeClient:
protocolConn.reader = reader
case HeaderTypeClientEncrypted:
protocolConn.reader = NewTLSEncryptedStreamReader(reader)
// case HeaderTypeClientEncrypted:
// protocolConn.reader = NewTLSEncryptedStreamReader(reader)
}
metadata.Protocol = "shadowsocks"
@ -240,7 +240,7 @@ type serverConn struct {
access sync.Mutex
headerType byte
reader io.Reader
writer io.Writer
writer *shadowaead.Writer
requestSalt []byte
}
@ -275,9 +275,9 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
case HeaderTypeClient:
headerType = HeaderTypeServer
payloadLen = len(payload)
case HeaderTypeClientEncrypted:
headerType = HeaderTypeServerEncrypted
payloadLen = readTLSChunkEnd(payload)
// case HeaderTypeClientEncrypted:
// headerType = HeaderTypeServerEncrypted
// payloadLen = readTLSChunkEnd(payload)
}
_headerFixedChunk := buf.StackNewSize(1 + 8 + c.keySaltLength + 2)
@ -304,15 +304,15 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
switch headerType {
case HeaderTypeServer:
c.writer = writer
case HeaderTypeServerEncrypted:
encryptedWriter := NewTLSEncryptedStreamWriter(writer)
if payloadLen < len(payload) {
_, err = encryptedWriter.Write(payload[payloadLen:])
if err != nil {
return
}
}
c.writer = encryptedWriter
// case HeaderTypeServerEncrypted:
// encryptedWriter := NewTLSEncryptedStreamWriter(writer)
// if payloadLen < len(payload) {
// _, err = encryptedWriter.Write(payload[payloadLen:])
// if err != nil {
// return
// }
// }
// c.writer = encryptedWriter
}
n = len(payload)
@ -336,6 +336,25 @@ func (c *serverConn) Write(p []byte) (n int, err error) {
return c.writeResponse(p)
}
func (c *serverConn) WriteVectorised(buffers []*buf.Buffer) error {
if c.writer != nil {
return c.writer.WriteVectorised(buffers)
}
c.access.Lock()
if c.writer != nil {
c.access.Unlock()
return c.writer.WriteVectorised(buffers)
}
defer c.access.Unlock()
_, err := c.writeResponse(buffers[0].Bytes())
if err != nil {
buf.ReleaseMulti(buffers)
return err
}
buffers[0].Release()
return c.writer.WriteVectorised(buffers[1:])
}
func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writer == nil {
return bufio.ReadFrom0(c, r)

View file

@ -183,7 +183,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
return E.Cause(err, "read header")
}
if headerType != HeaderTypeClient && headerType != HeaderTypeClientEncrypted {
if headerType != HeaderTypeClient /*&& headerType != HeaderTypeClientEncrypted*/ {
return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
}
@ -240,8 +240,8 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
switch headerType {
case HeaderTypeClient:
protocolConn.reader = reader
case HeaderTypeClientEncrypted:
protocolConn.reader = NewTLSEncryptedStreamReader(reader)
// case HeaderTypeClientEncrypted:
// protocolConn.reader = NewTLSEncryptedStreamReader(reader)
}
metadata.Protocol = "shadowsocks"