From 2cedde0fbc906a7047075e5ebd9cc603adfcec56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 3 Jul 2023 21:22:53 +0800 Subject: [PATCH] Remove stack buffer usage --- Makefile | 2 +- client_conn.go | 27 +++------------------------ go.mod | 2 +- go.sum | 4 ++-- padding.go | 12 +++--------- protocol_conn.go | 8 -------- server_conn.go | 28 ++++++++++------------------ 7 files changed, 20 insertions(+), 63 deletions(-) diff --git a/Makefile b/Makefile index a837f60..e47456e 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ fmt: @gofumpt -l -w . @gofmt -s -w . - @gci write --custom-order -s "standard,prefix(github.com/sagernet/),default" . + @gci write --custom-order -s standard -s "prefix(github.com/sagernet/)" -s "default" . fmt_install: go install -v mvdan.cc/gofumpt@latest diff --git a/client_conn.go b/client_conn.go index 2304a9c..9a82621 100644 --- a/client_conn.go +++ b/client_conn.go @@ -7,7 +7,6 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -50,9 +49,7 @@ func (c *clientConn) Write(b []byte) (n int, err error) { Network: N.NetworkTCP, Destination: c.destination, } - _buffer := buf.StackNewSize(streamRequestLen(request) + len(b)) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) + buffer := buf.NewSize(streamRequestLen(request) + len(b)) defer buffer.Release() EncodeStreamRequest(request, buffer) buffer.Write(b) @@ -64,20 +61,6 @@ func (c *clientConn) Write(b []byte) (n int, err error) { return len(b), nil } -func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) { - if !c.requestWritten { - return bufio.ReadFrom0(c, r) - } - return bufio.Copy(c.Conn, r) -} - -func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) { - if !c.responseRead { - return bufio.WriteTo0(c, w) - } - return bufio.Copy(w, c.Conn) -} - func (c *clientConn) LocalAddr() net.Addr { return c.Conn.LocalAddr() } @@ -148,9 +131,7 @@ func (c *clientPacketConn) writeRequest(payload []byte) (n int, err error) { if len(payload) > 0 { rLen += 2 + len(payload) } - _buffer := buf.StackNewSize(rLen) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) + buffer := buf.NewSize(rLen) defer buffer.Release() EncodeStreamRequest(request, buffer) if len(payload) > 0 { @@ -324,9 +305,7 @@ func (c *clientPacketAddrConn) writeRequest(payload []byte, destination M.Socksa if len(payload) > 0 { rLen += M.SocksaddrSerializer.AddrPortLen(destination) + 2 + len(payload) } - _buffer := buf.StackNewSize(rLen) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) + buffer := buf.NewSize(rLen) defer buffer.Release() EncodeStreamRequest(request, buffer) if len(payload) > 0 { diff --git a/go.mod b/go.mod index b9db591..b081bd7 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.18 require ( github.com/hashicorp/yamux v0.1.1 - github.com/sagernet/sing v0.2.5 + github.com/sagernet/sing v0.2.8-0.20230703002104-c68251b6d059 github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 golang.org/x/net v0.11.0 ) diff --git a/go.sum b/go.sum index 2761143..b004200 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,8 @@ github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE= github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ= github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk= -github.com/sagernet/sing v0.2.5 h1:N8sUluR8GZvR9DqUiH3FA3vBb4m/EDdOVTYUrDzJvmY= -github.com/sagernet/sing v0.2.5/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w= +github.com/sagernet/sing v0.2.8-0.20230703002104-c68251b6d059 h1:nqTONy58Gq1mdoGx9GX+GKXdSTwOPTKF/DXK+Wn4B+A= +github.com/sagernet/sing v0.2.8-0.20230703002104-c68251b6d059/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w= github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 h1:HuE6xSwco/Xed8ajZ+coeYLmioq0Qp1/Z2zczFaV8as= github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37/go.mod h1:3skNSftZDJWTGVtVaM2jfbce8qHnmH/AGDRe62iNOg0= golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU= diff --git a/padding.go b/padding.go index 850bf25..d902aae 100644 --- a/padding.go +++ b/padding.go @@ -66,9 +66,7 @@ func (c *paddingConn) Read(p []byte) (n int, err error) { if len(p) >= 4 { paddingHdr = p[:4] } else { - _paddingHdr := make([]byte, 4) - defer common.KeepAlive(_paddingHdr) - paddingHdr = common.Dup(_paddingHdr) + paddingHdr = make([]byte, 4) } _, err = io.ReadFull(c.ExtendedConn, paddingHdr) if err != nil { @@ -115,9 +113,7 @@ func (c *paddingConn) Write(p []byte) (n int, err error) { func (c *paddingConn) write(p []byte) (n int, err error) { if c.writePadding < kFirstPaddings { paddingLen := 256 + rand.Intn(512) - _buffer := buf.StackNewSize(4 + len(p) + paddingLen) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) + buffer := buf.NewSize(4 + len(p) + paddingLen) defer buffer.Release() header := buffer.Extend(4) binary.BigEndian.PutUint16(header[:2], uint16(len(p))) @@ -160,9 +156,7 @@ func (c *paddingConn) ReadBuffer(buffer *buf.Buffer) error { if len(p) >= 4 { paddingHdr = p[:4] } else { - _paddingHdr := make([]byte, 4) - defer common.KeepAlive(_paddingHdr) - paddingHdr = common.Dup(_paddingHdr) + paddingHdr = make([]byte, 4) } _, err := io.ReadFull(c.ExtendedConn, paddingHdr) if err != nil { diff --git a/protocol_conn.go b/protocol_conn.go index ec3d2a4..aaf3ffe 100644 --- a/protocol_conn.go +++ b/protocol_conn.go @@ -1,7 +1,6 @@ package mux import ( - "io" "net" "github.com/sagernet/sing/common/buf" @@ -47,13 +46,6 @@ func (c *protocolConn) Write(p []byte) (n int, err error) { return n, err } -func (c *protocolConn) ReadFrom(r io.Reader) (n int64, err error) { - if !c.requestWritten { - return bufio.ReadFrom0(c, r) - } - return bufio.Copy(c.Conn, r) -} - func (c *protocolConn) Upstream() any { return c.Conn } diff --git a/server_conn.go b/server_conn.go index a37b64a..fa7b4e6 100644 --- a/server_conn.go +++ b/server_conn.go @@ -21,24 +21,20 @@ type serverConn struct { func (c *serverConn) HandshakeFailure(err error) error { errMessage := err.Error() - _buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) + buffer := buf.NewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) defer buffer.Release() common.Must( buffer.WriteByte(statusError), - rw.WriteVString(_buffer, errMessage), + rw.WriteVString(buffer, errMessage), ) - return c.ExtendedConn.WriteBuffer(buffer) + return common.Error(c.ExtendedConn.Write(buffer.Bytes())) } func (c *serverConn) Write(b []byte) (n int, err error) { if c.responseWritten { return c.ExtendedConn.Write(b) } - _buffer := buf.StackNewSize(1 + len(b)) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) + buffer := buf.NewSize(1 + len(b)) defer buffer.Release() common.Must( buffer.WriteByte(statusSuccess), @@ -89,15 +85,13 @@ type serverPacketConn struct { func (c *serverPacketConn) HandshakeFailure(err error) error { errMessage := err.Error() - _buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) + buffer := buf.NewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) defer buffer.Release() common.Must( buffer.WriteByte(statusError), - rw.WriteVString(_buffer, errMessage), + rw.WriteVString(buffer, errMessage), ) - return c.ExtendedConn.WriteBuffer(buffer) + return common.Error(c.ExtendedConn.Write(buffer.Bytes())) } func (c *serverPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { @@ -178,15 +172,13 @@ type serverPacketAddrConn struct { func (c *serverPacketAddrConn) HandshakeFailure(err error) error { errMessage := err.Error() - _buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) + buffer := buf.NewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) defer buffer.Release() common.Must( buffer.WriteByte(statusError), - rw.WriteVString(_buffer, errMessage), + rw.WriteVString(buffer, errMessage), ) - return c.ExtendedConn.WriteBuffer(buffer) + return common.Error(c.ExtendedConn.Write(buffer.Bytes())) } func (c *serverPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {