Remove stack buffer usage

This commit is contained in:
世界 2023-07-03 21:22:53 +08:00
parent 513f49a03f
commit 2cedde0fbc
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
7 changed files with 20 additions and 63 deletions

View file

@ -1,7 +1,7 @@
fmt: fmt:
@gofumpt -l -w . @gofumpt -l -w .
@gofmt -s -w . @gofmt -s -w .
@gci write --custom-order -s "standard,prefix(github.com/sagernet/),default" . @gci write --custom-order -s standard -s "prefix(github.com/sagernet/)" -s "default" .
fmt_install: fmt_install:
go install -v mvdan.cc/gofumpt@latest go install -v mvdan.cc/gofumpt@latest

View file

@ -7,7 +7,6 @@ import (
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
@ -50,9 +49,7 @@ func (c *clientConn) Write(b []byte) (n int, err error) {
Network: N.NetworkTCP, Network: N.NetworkTCP,
Destination: c.destination, Destination: c.destination,
} }
_buffer := buf.StackNewSize(streamRequestLen(request) + len(b)) buffer := buf.NewSize(streamRequestLen(request) + len(b))
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release() defer buffer.Release()
EncodeStreamRequest(request, buffer) EncodeStreamRequest(request, buffer)
buffer.Write(b) buffer.Write(b)
@ -64,20 +61,6 @@ func (c *clientConn) Write(b []byte) (n int, err error) {
return len(b), nil return len(b), nil
} }
func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
if !c.requestWritten {
return bufio.ReadFrom0(c, r)
}
return bufio.Copy(c.Conn, r)
}
func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
if !c.responseRead {
return bufio.WriteTo0(c, w)
}
return bufio.Copy(w, c.Conn)
}
func (c *clientConn) LocalAddr() net.Addr { func (c *clientConn) LocalAddr() net.Addr {
return c.Conn.LocalAddr() return c.Conn.LocalAddr()
} }
@ -148,9 +131,7 @@ func (c *clientPacketConn) writeRequest(payload []byte) (n int, err error) {
if len(payload) > 0 { if len(payload) > 0 {
rLen += 2 + len(payload) rLen += 2 + len(payload)
} }
_buffer := buf.StackNewSize(rLen) buffer := buf.NewSize(rLen)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release() defer buffer.Release()
EncodeStreamRequest(request, buffer) EncodeStreamRequest(request, buffer)
if len(payload) > 0 { if len(payload) > 0 {
@ -324,9 +305,7 @@ func (c *clientPacketAddrConn) writeRequest(payload []byte, destination M.Socksa
if len(payload) > 0 { if len(payload) > 0 {
rLen += M.SocksaddrSerializer.AddrPortLen(destination) + 2 + len(payload) rLen += M.SocksaddrSerializer.AddrPortLen(destination) + 2 + len(payload)
} }
_buffer := buf.StackNewSize(rLen) buffer := buf.NewSize(rLen)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release() defer buffer.Release()
EncodeStreamRequest(request, buffer) EncodeStreamRequest(request, buffer)
if len(payload) > 0 { if len(payload) > 0 {

2
go.mod
View file

@ -4,7 +4,7 @@ go 1.18
require ( require (
github.com/hashicorp/yamux v0.1.1 github.com/hashicorp/yamux v0.1.1
github.com/sagernet/sing v0.2.5 github.com/sagernet/sing v0.2.8-0.20230703002104-c68251b6d059
github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37
golang.org/x/net v0.11.0 golang.org/x/net v0.11.0
) )

4
go.sum
View file

@ -1,8 +1,8 @@
github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE= github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE=
github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ= github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ=
github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk= github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk=
github.com/sagernet/sing v0.2.5 h1:N8sUluR8GZvR9DqUiH3FA3vBb4m/EDdOVTYUrDzJvmY= github.com/sagernet/sing v0.2.8-0.20230703002104-c68251b6d059 h1:nqTONy58Gq1mdoGx9GX+GKXdSTwOPTKF/DXK+Wn4B+A=
github.com/sagernet/sing v0.2.5/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w= github.com/sagernet/sing v0.2.8-0.20230703002104-c68251b6d059/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w=
github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 h1:HuE6xSwco/Xed8ajZ+coeYLmioq0Qp1/Z2zczFaV8as= github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 h1:HuE6xSwco/Xed8ajZ+coeYLmioq0Qp1/Z2zczFaV8as=
github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37/go.mod h1:3skNSftZDJWTGVtVaM2jfbce8qHnmH/AGDRe62iNOg0= github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37/go.mod h1:3skNSftZDJWTGVtVaM2jfbce8qHnmH/AGDRe62iNOg0=
golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU= golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU=

View file

@ -66,9 +66,7 @@ func (c *paddingConn) Read(p []byte) (n int, err error) {
if len(p) >= 4 { if len(p) >= 4 {
paddingHdr = p[:4] paddingHdr = p[:4]
} else { } else {
_paddingHdr := make([]byte, 4) paddingHdr = make([]byte, 4)
defer common.KeepAlive(_paddingHdr)
paddingHdr = common.Dup(_paddingHdr)
} }
_, err = io.ReadFull(c.ExtendedConn, paddingHdr) _, err = io.ReadFull(c.ExtendedConn, paddingHdr)
if err != nil { if err != nil {
@ -115,9 +113,7 @@ func (c *paddingConn) Write(p []byte) (n int, err error) {
func (c *paddingConn) write(p []byte) (n int, err error) { func (c *paddingConn) write(p []byte) (n int, err error) {
if c.writePadding < kFirstPaddings { if c.writePadding < kFirstPaddings {
paddingLen := 256 + rand.Intn(512) paddingLen := 256 + rand.Intn(512)
_buffer := buf.StackNewSize(4 + len(p) + paddingLen) buffer := buf.NewSize(4 + len(p) + paddingLen)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release() defer buffer.Release()
header := buffer.Extend(4) header := buffer.Extend(4)
binary.BigEndian.PutUint16(header[:2], uint16(len(p))) binary.BigEndian.PutUint16(header[:2], uint16(len(p)))
@ -160,9 +156,7 @@ func (c *paddingConn) ReadBuffer(buffer *buf.Buffer) error {
if len(p) >= 4 { if len(p) >= 4 {
paddingHdr = p[:4] paddingHdr = p[:4]
} else { } else {
_paddingHdr := make([]byte, 4) paddingHdr = make([]byte, 4)
defer common.KeepAlive(_paddingHdr)
paddingHdr = common.Dup(_paddingHdr)
} }
_, err := io.ReadFull(c.ExtendedConn, paddingHdr) _, err := io.ReadFull(c.ExtendedConn, paddingHdr)
if err != nil { if err != nil {

View file

@ -1,7 +1,6 @@
package mux package mux
import ( import (
"io"
"net" "net"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
@ -47,13 +46,6 @@ func (c *protocolConn) Write(p []byte) (n int, err error) {
return n, err return n, err
} }
func (c *protocolConn) ReadFrom(r io.Reader) (n int64, err error) {
if !c.requestWritten {
return bufio.ReadFrom0(c, r)
}
return bufio.Copy(c.Conn, r)
}
func (c *protocolConn) Upstream() any { func (c *protocolConn) Upstream() any {
return c.Conn return c.Conn
} }

View file

@ -21,24 +21,20 @@ type serverConn struct {
func (c *serverConn) HandshakeFailure(err error) error { func (c *serverConn) HandshakeFailure(err error) error {
errMessage := err.Error() errMessage := err.Error()
_buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) buffer := buf.NewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage))
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release() defer buffer.Release()
common.Must( common.Must(
buffer.WriteByte(statusError), buffer.WriteByte(statusError),
rw.WriteVString(_buffer, errMessage), rw.WriteVString(buffer, errMessage),
) )
return c.ExtendedConn.WriteBuffer(buffer) return common.Error(c.ExtendedConn.Write(buffer.Bytes()))
} }
func (c *serverConn) Write(b []byte) (n int, err error) { func (c *serverConn) Write(b []byte) (n int, err error) {
if c.responseWritten { if c.responseWritten {
return c.ExtendedConn.Write(b) return c.ExtendedConn.Write(b)
} }
_buffer := buf.StackNewSize(1 + len(b)) buffer := buf.NewSize(1 + len(b))
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release() defer buffer.Release()
common.Must( common.Must(
buffer.WriteByte(statusSuccess), buffer.WriteByte(statusSuccess),
@ -89,15 +85,13 @@ type serverPacketConn struct {
func (c *serverPacketConn) HandshakeFailure(err error) error { func (c *serverPacketConn) HandshakeFailure(err error) error {
errMessage := err.Error() errMessage := err.Error()
_buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) buffer := buf.NewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage))
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release() defer buffer.Release()
common.Must( common.Must(
buffer.WriteByte(statusError), buffer.WriteByte(statusError),
rw.WriteVString(_buffer, errMessage), rw.WriteVString(buffer, errMessage),
) )
return c.ExtendedConn.WriteBuffer(buffer) return common.Error(c.ExtendedConn.Write(buffer.Bytes()))
} }
func (c *serverPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { func (c *serverPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
@ -178,15 +172,13 @@ type serverPacketAddrConn struct {
func (c *serverPacketAddrConn) HandshakeFailure(err error) error { func (c *serverPacketAddrConn) HandshakeFailure(err error) error {
errMessage := err.Error() errMessage := err.Error()
_buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) buffer := buf.NewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage))
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release() defer buffer.Release()
common.Must( common.Must(
buffer.WriteByte(statusError), buffer.WriteByte(statusError),
rw.WriteVString(_buffer, errMessage), rw.WriteVString(buffer, errMessage),
) )
return c.ExtendedConn.WriteBuffer(buffer) return common.Error(c.ExtendedConn.Write(buffer.Bytes()))
} }
func (c *serverPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { func (c *serverPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {