diff --git a/common/bufio/conn.go b/common/bufio/conn.go index f30e089..b8a5c03 100644 --- a/common/bufio/conn.go +++ b/common/bufio/conn.go @@ -532,6 +532,7 @@ type ExtendedWriterWrapper struct { } func (w *ExtendedWriterWrapper) WriteBuffer(buffer *buf.Buffer) error { + defer buffer.Release() return common.Error(w.Write(buffer.Bytes())) } diff --git a/common/bufio/io.go b/common/bufio/io.go index 217cc97..0dea32b 100644 --- a/common/bufio/io.go +++ b/common/bufio/io.go @@ -22,7 +22,6 @@ func ReadFrom(reader N.PacketReader, buffer *buf.Buffer) (n int, addr net.Addr, } func Write(writer N.ExtendedWriter, buffer *buf.Buffer) (n int, err error) { - defer buffer.Release() dataLen := buffer.Len() err = writer.WriteBuffer(buffer) if err == nil { @@ -32,7 +31,6 @@ func Write(writer N.ExtendedWriter, buffer *buf.Buffer) (n int, err error) { } func WriteTo(writer N.PacketWriter, buffer *buf.Buffer, addr net.Addr) (n int, err error) { - defer buffer.Release() dataLen := buffer.Len() err = writer.WritePacket(buffer, M.SocksaddrFromNet(addr)) if err == nil { diff --git a/common/bufio/limit.go b/common/bufio/limit.go deleted file mode 100644 index 98e3e42..0000000 --- a/common/bufio/limit.go +++ /dev/null @@ -1,50 +0,0 @@ -package bufio - -import ( - "io" - - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - N "github.com/sagernet/sing/common/network" -) - -type LimitedWriter struct { - upstream N.ExtendedWriter - maxChunkLength int -} - -func NewLimitedWriter(writer io.Writer, maxChunkLength int) *LimitedWriter { - return &LimitedWriter{ - upstream: NewExtendedWriter(writer), - maxChunkLength: maxChunkLength, - } -} - -func (w *LimitedWriter) Write(p []byte) (n int, err error) { - for pLen := len(p); pLen > 0; { - var data []byte - if pLen > w.maxChunkLength { - data = p[:w.maxChunkLength] - p = p[w.maxChunkLength:] - pLen -= w.maxChunkLength - } else { - data = p - pLen = 0 - } - var writeN int - writeN, err = w.upstream.Write(data) - if err != nil { - return - } - n += writeN - } - return -} - -func (w *LimitedWriter) WriteBuffer(buffer *buf.Buffer) error { - if buffer.Len() <= w.maxChunkLength { - return w.upstream.WriteBuffer(buffer) - } - defer buffer.Release() - return common.Error(w.Write(buffer.Bytes())) -} diff --git a/common/bufio/once.go b/common/bufio/once.go index 6e048d6..7c37e31 100644 --- a/common/bufio/once.go +++ b/common/bufio/once.go @@ -22,7 +22,7 @@ func CopyOnce(dst io.Writer, src io.Reader) (n int64, err error) { func CopyExtendedOnce(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) { var buffer *buf.Buffer - if _, unsafe := common.Cast[N.ThreadUnsafeWriter](dst); unsafe { + if N.IsUnsafeWriter(dst) { buffer = buf.New() } else { _buffer := buf.StackNew() diff --git a/common/network/thread.go b/common/network/thread.go index 3ef0e17..dca2073 100644 --- a/common/network/thread.go +++ b/common/network/thread.go @@ -22,6 +22,11 @@ type HeadroomWriter interface { Headroom() int } +func IsUnsafeWriter(writer any) bool { + _, isUnsafe := common.Cast[ThreadUnsafeWriter](writer) + return isUnsafe +} + func CalculateHeadroom(writer any) int { var headroom int if headroomWriter, needHeadroom := writer.(HeadroomWriter); needHeadroom { diff --git a/common/uot/server.go b/common/uot/server.go index efd0c6e..b8605bb 100644 --- a/common/uot/server.go +++ b/common/uot/server.go @@ -44,6 +44,7 @@ func (c *ServerConn) RemoteAddr() net.Addr { return pipeAddr{} } +//warn:unsafe func (c *ServerConn) loopInput() { _buffer := buf.StackNew() defer common.KeepAlive(_buffer) @@ -79,6 +80,7 @@ func (c *ServerConn) loopInput() { c.Close() } +//warn:unsafe func (c *ServerConn) loopOutput() { _buffer := buf.StackNew() defer common.KeepAlive(_buffer) diff --git a/go.mod b/go.mod index 5e43a31..228feda 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/sagernet/sing go 1.18 -require golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 +require golang.org/x/sys v0.0.0-20220731174439-a90be440212d diff --git a/go.sum b/go.sum index 28ab7f7..b7c2d11 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,2 @@ -golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ= -golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220731174439-a90be440212d h1:Sv5ogFZatcgIMMtBSTTAgMYsicp25MXBubjXNDKwm80= +golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/protocol/http/handshake.go b/protocol/http/handshake.go index 516e3be..1b59574 100644 --- a/protocol/http/handshake.go +++ b/protocol/http/handshake.go @@ -63,10 +63,7 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read var requestConn net.Conn if reader.Buffered() > 0 { - _buffer := buf.StackNewSize(reader.Buffered()) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) - defer buffer.Release() + buffer := buf.NewSize(reader.Buffered()) _, err = buffer.ReadFullFrom(reader, reader.Buffered()) if err != nil { return err diff --git a/protocol/socks/packet.go b/protocol/socks/packet.go index c228c34..c95c164 100644 --- a/protocol/socks/packet.go +++ b/protocol/socks/packet.go @@ -43,6 +43,7 @@ func (c *AssociatePacketConn) RemoteAddr() net.Addr { return c.addr } +//warn:unsafe func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { buffer := buf.With(p) n, _, err = bufio.ReadFrom(c.PacketConn, buffer) @@ -59,6 +60,7 @@ func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err erro return } +//warn:unsafe func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { destination := M.SocksaddrFromNet(addr) _buffer := buf.StackNewSize(3 + M.SocksaddrSerializer.AddrPortLen(destination) + len(p)) @@ -98,7 +100,6 @@ func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error } func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - defer buffer.Release() header := buf.With(buffer.ExtendHeader(3 + M.SocksaddrSerializer.AddrPortLen(destination))) common.Must(header.WriteZeroN(3)) common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))