diff --git a/common/bufio/fallback.go b/common/bufio/fallback.go index 4ea87cf..bd4ab46 100644 --- a/common/bufio/fallback.go +++ b/common/bufio/fallback.go @@ -3,6 +3,7 @@ package bufio import ( "net" + "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -12,13 +13,17 @@ var _ N.NetPacketConn = (*FallbackPacketConn)(nil) type FallbackPacketConn struct { N.PacketConn + writer N.NetPacketWriter } func NewNetPacketConn(conn N.PacketConn) N.NetPacketConn { if packetConn, loaded := conn.(N.NetPacketConn); loaded { return packetConn } - return &FallbackPacketConn{PacketConn: conn} + return &FallbackPacketConn{ + PacketConn: conn, + writer: NewNetPacketWriter(conn), + } } func (c *FallbackPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { @@ -36,11 +41,7 @@ func (c *FallbackPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error } func (c *FallbackPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - err = c.WritePacket(buf.As(p), M.SocksaddrFromNet(addr)) - if err == nil { - n = len(p) - } - return + return c.writer.WriteTo(p, addr) } func (c *FallbackPacketConn) ReaderReplaceable() bool { @@ -54,3 +55,50 @@ func (c *FallbackPacketConn) WriterReplaceable() bool { func (c *FallbackPacketConn) Upstream() any { return c.PacketConn } + +func (c *FallbackPacketConn) UpstreamWriter() any { + return c.writer +} + +var _ N.NetPacketWriter = (*FallbackPacketWriter)(nil) + +type FallbackPacketWriter struct { + N.PacketWriter + frontHeadroom int + rearHeadroom int +} + +func NewNetPacketWriter(writer N.PacketWriter) N.NetPacketWriter { + if packetWriter, loaded := writer.(N.NetPacketWriter); loaded { + return packetWriter + } + return &FallbackPacketWriter{ + PacketWriter: writer, + frontHeadroom: N.CalculateFrontHeadroom(writer), + rearHeadroom: N.CalculateRearHeadroom(writer), + } +} + +func (c *FallbackPacketWriter) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if c.frontHeadroom > 0 || c.rearHeadroom > 0 { + buffer := buf.NewSize(len(p) + c.frontHeadroom + c.rearHeadroom) + buffer.Resize(c.frontHeadroom, 0) + common.Must1(buffer.Write(p)) + err = c.PacketWriter.WritePacket(buffer, M.SocksaddrFromNet(addr)) + } else { + err = c.PacketWriter.WritePacket(buf.As(p), M.SocksaddrFromNet(addr)) + } + if err != nil { + return + } + n = len(p) + return +} + +func (c *FallbackPacketWriter) WriterReplaceable() bool { + return true +} + +func (c *FallbackPacketWriter) Upstream() any { + return c.PacketWriter +} diff --git a/common/bufio/io.go b/common/bufio/io.go index 1e5d89b..a25a7cc 100644 --- a/common/bufio/io.go +++ b/common/bufio/io.go @@ -37,13 +37,7 @@ func WriteBuffer(writer N.ExtendedWriter, buffer *buf.Buffer) (n int, err error) frontHeadroom := N.CalculateFrontHeadroom(writer) rearHeadroom := N.CalculateRearHeadroom(writer) if frontHeadroom > buffer.Start() || rearHeadroom > buffer.FreeLen() { - bufferSize := N.CalculateMTU(nil, writer) - if bufferSize > 0 { - bufferSize += frontHeadroom + rearHeadroom - } else { - bufferSize = buf.BufferSize - } - newBuffer := buf.NewSize(bufferSize) + newBuffer := buf.NewSize(buffer.Len() + frontHeadroom + rearHeadroom) newBuffer.Resize(frontHeadroom, 0) common.Must1(newBuffer.Write(buffer.Bytes())) buffer.Release() @@ -69,13 +63,7 @@ func WritePacketBuffer(writer N.PacketWriter, buffer *buf.Buffer, destination M. frontHeadroom := N.CalculateFrontHeadroom(writer) rearHeadroom := N.CalculateRearHeadroom(writer) if frontHeadroom > buffer.Start() || rearHeadroom > buffer.FreeLen() { - bufferSize := N.CalculateMTU(nil, writer) - if bufferSize > 0 { - bufferSize += frontHeadroom + rearHeadroom - } else { - bufferSize = buf.BufferSize - } - newBuffer := buf.NewSize(bufferSize) + newBuffer := buf.NewSize(buffer.Len() + frontHeadroom + rearHeadroom) newBuffer.Resize(frontHeadroom, 0) common.Must1(newBuffer.Write(buffer.Bytes())) buffer.Release()