mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-04 20:37:40 +03:00
Fix copy with src buffer
This commit is contained in:
parent
b5294f57ba
commit
5945fd0457
5 changed files with 82 additions and 13 deletions
|
@ -164,6 +164,12 @@ func (c *CachedPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksad
|
|||
return c.PacketConn.ReadPacket(buffer)
|
||||
}
|
||||
|
||||
func (c *CachedPacketConn) ReadCachedPacket() (destination M.Socksaddr, buffer *buf.Buffer) {
|
||||
buffer = c.buffer
|
||||
c.buffer = nil
|
||||
return c.destination, buffer
|
||||
}
|
||||
|
||||
func (c *CachedPacketConn) Upstream() any {
|
||||
return c.PacketConn
|
||||
}
|
||||
|
|
|
@ -215,7 +215,7 @@ func (r *ExtendedConnWrapper) WriteTo(w io.Writer) (n int64, err error) {
|
|||
return Copy(w, r.reader)
|
||||
}
|
||||
|
||||
func (w *ExtendedConnWrapper) UpstreamReader() io.Reader {
|
||||
func (w *ExtendedConnWrapper) UpstreamReader() any {
|
||||
return w.reader
|
||||
}
|
||||
|
||||
|
@ -223,7 +223,7 @@ func (w *ExtendedConnWrapper) ReaderReplaceable() bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func (w *ExtendedConnWrapper) UpstreamWriter() io.Writer {
|
||||
func (w *ExtendedConnWrapper) UpstreamWriter() any {
|
||||
return w.writer
|
||||
}
|
||||
|
||||
|
|
|
@ -57,11 +57,11 @@ func Copy(dst io.Writer, src io.Reader) (n int64, err error) {
|
|||
}
|
||||
|
||||
func CopyExtended(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
|
||||
unsafeSrc, srcUnsafe := common.Cast[N.ThreadSafeReader](src)
|
||||
safeSrc := N.IsSafeReader(src)
|
||||
headroom := N.CalculateFrontHeadroom(dst) + N.CalculateRearHeadroom(dst)
|
||||
if srcUnsafe {
|
||||
if safeSrc != nil {
|
||||
if headroom == 0 {
|
||||
return CopyExtendedWithSrcBuffer(dst, unsafeSrc)
|
||||
return CopyExtendedWithSrcBuffer(dst, safeSrc)
|
||||
}
|
||||
}
|
||||
if N.IsUnsafeWriter(dst) {
|
||||
|
@ -189,13 +189,15 @@ func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
|
|||
}
|
||||
|
||||
func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
|
||||
unsafeSrc, srcUnsafe := common.Cast[N.ThreadSafePacketReader](src)
|
||||
src = N.UnwrapPacketReader(src)
|
||||
dst = N.UnwrapPacketWriter(dst)
|
||||
safeSrc := N.IsSafePacketReader(src)
|
||||
frontHeadroom := N.CalculateFrontHeadroom(dst)
|
||||
rearHeadroom := N.CalculateRearHeadroom(dst)
|
||||
headroom := frontHeadroom + rearHeadroom
|
||||
if srcUnsafe {
|
||||
if safeSrc != nil {
|
||||
if headroom == 0 {
|
||||
return CopyPacketWithSrcBuffer(dst, unsafeSrc)
|
||||
return CopyPacketWithSrcBuffer(dst, safeSrc)
|
||||
}
|
||||
}
|
||||
if N.IsUnsafeWriter(dst) {
|
||||
|
|
|
@ -77,12 +77,16 @@ type CachedReader interface {
|
|||
ReadCached() *buf.Buffer
|
||||
}
|
||||
|
||||
type CachedPacketReader interface {
|
||||
ReadCachedPacket() (destination M.Socksaddr, buffer *buf.Buffer)
|
||||
}
|
||||
|
||||
type WithUpstreamReader interface {
|
||||
UpstreamReader() io.Reader
|
||||
UpstreamReader() any
|
||||
}
|
||||
|
||||
type WithUpstreamWriter interface {
|
||||
UpstreamWriter() io.Writer
|
||||
UpstreamWriter() any
|
||||
}
|
||||
|
||||
type ReaderWithUpstream interface {
|
||||
|
@ -98,7 +102,7 @@ func UnwrapReader(reader io.Reader) io.Reader {
|
|||
return reader
|
||||
}
|
||||
if u, ok := reader.(WithUpstreamReader); ok {
|
||||
return UnwrapReader(u.UpstreamReader())
|
||||
return UnwrapReader(u.UpstreamReader().(io.Reader))
|
||||
}
|
||||
if u, ok := reader.(common.WithUpstream); ok {
|
||||
return UnwrapReader(u.Upstream().(io.Reader))
|
||||
|
@ -106,15 +110,41 @@ func UnwrapReader(reader io.Reader) io.Reader {
|
|||
panic("bad reader")
|
||||
}
|
||||
|
||||
func UnwrapPacketReader(reader PacketReader) PacketReader {
|
||||
if u, ok := reader.(ReaderWithUpstream); !ok || !u.ReaderReplaceable() {
|
||||
return reader
|
||||
}
|
||||
if u, ok := reader.(WithUpstreamReader); ok {
|
||||
return UnwrapPacketReader(u.UpstreamReader().(PacketReader))
|
||||
}
|
||||
if u, ok := reader.(common.WithUpstream); ok {
|
||||
return UnwrapPacketReader(u.Upstream().(PacketReader))
|
||||
}
|
||||
panic("bad reader")
|
||||
}
|
||||
|
||||
func UnwrapWriter(writer io.Writer) io.Writer {
|
||||
if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() {
|
||||
return writer
|
||||
}
|
||||
if u, ok := writer.(WithUpstreamWriter); ok {
|
||||
return UnwrapWriter(u.UpstreamWriter())
|
||||
return UnwrapWriter(u.UpstreamWriter().(io.Writer))
|
||||
}
|
||||
if u, ok := writer.(common.WithUpstream); ok {
|
||||
return UnwrapWriter(u.Upstream().(io.Writer))
|
||||
}
|
||||
panic("bad writer")
|
||||
}
|
||||
|
||||
func UnwrapPacketWriter(writer PacketWriter) PacketWriter {
|
||||
if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() {
|
||||
return writer
|
||||
}
|
||||
if u, ok := writer.(WithUpstreamWriter); ok {
|
||||
return UnwrapPacketWriter(u.UpstreamWriter().(PacketWriter))
|
||||
}
|
||||
if u, ok := writer.(common.WithUpstream); ok {
|
||||
return UnwrapPacketWriter(u.Upstream().(PacketWriter))
|
||||
}
|
||||
panic("bad writer")
|
||||
}
|
||||
|
|
|
@ -23,6 +23,38 @@ func IsUnsafeWriter(writer any) bool {
|
|||
return isUnsafe
|
||||
}
|
||||
|
||||
func IsSafeReader(reader any) ThreadSafeReader {
|
||||
if safeReader, isSafe := reader.(ThreadSafeReader); isSafe {
|
||||
return safeReader
|
||||
}
|
||||
if upstream, hasUpstream := reader.(ReaderWithUpstream); !hasUpstream || !upstream.ReaderReplaceable() {
|
||||
return nil
|
||||
}
|
||||
if upstream, hasUpstream := reader.(common.WithUpstream); hasUpstream {
|
||||
return IsSafeReader(upstream.Upstream())
|
||||
}
|
||||
if upstream, hasUpstream := reader.(WithUpstreamReader); hasUpstream {
|
||||
return IsSafeReader(upstream.UpstreamReader())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func IsSafePacketReader(reader any) ThreadSafePacketReader {
|
||||
if safeReader, isSafe := reader.(ThreadSafePacketReader); isSafe {
|
||||
return safeReader
|
||||
}
|
||||
if upstream, hasUpstream := reader.(ReaderWithUpstream); !hasUpstream || !upstream.ReaderReplaceable() {
|
||||
return nil
|
||||
}
|
||||
if upstream, hasUpstream := reader.(common.WithUpstream); hasUpstream {
|
||||
return IsSafePacketReader(upstream.Upstream())
|
||||
}
|
||||
if upstream, hasUpstream := reader.(WithUpstreamReader); hasUpstream {
|
||||
return IsSafePacketReader(upstream.UpstreamReader())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type FrontHeadroom interface {
|
||||
FrontHeadroom() int
|
||||
}
|
||||
|
@ -53,7 +85,6 @@ func CalculateRearHeadroom(writer any) int {
|
|||
if upstream, hasUpstream := writer.(common.WithUpstream); hasUpstream {
|
||||
headroom += CalculateRearHeadroom(upstream.Upstream())
|
||||
}
|
||||
|
||||
if upstream, hasUpstream := writer.(WithUpstreamWriter); hasUpstream {
|
||||
headroom += CalculateRearHeadroom(upstream.UpstreamWriter())
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue