Fix copy with src buffer

This commit is contained in:
世界 2022-08-21 08:40:38 +08:00
parent b5294f57ba
commit 5945fd0457
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
5 changed files with 82 additions and 13 deletions

View file

@ -164,6 +164,12 @@ func (c *CachedPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksad
return c.PacketConn.ReadPacket(buffer) return c.PacketConn.ReadPacket(buffer)
} }
func (c *CachedPacketConn) ReadCachedPacket() (destination M.Socksaddr, buffer *buf.Buffer) {
buffer = c.buffer
c.buffer = nil
return c.destination, buffer
}
func (c *CachedPacketConn) Upstream() any { func (c *CachedPacketConn) Upstream() any {
return c.PacketConn return c.PacketConn
} }

View file

@ -215,7 +215,7 @@ func (r *ExtendedConnWrapper) WriteTo(w io.Writer) (n int64, err error) {
return Copy(w, r.reader) return Copy(w, r.reader)
} }
func (w *ExtendedConnWrapper) UpstreamReader() io.Reader { func (w *ExtendedConnWrapper) UpstreamReader() any {
return w.reader return w.reader
} }
@ -223,7 +223,7 @@ func (w *ExtendedConnWrapper) ReaderReplaceable() bool {
return true return true
} }
func (w *ExtendedConnWrapper) UpstreamWriter() io.Writer { func (w *ExtendedConnWrapper) UpstreamWriter() any {
return w.writer return w.writer
} }

View file

@ -57,11 +57,11 @@ func Copy(dst io.Writer, src io.Reader) (n int64, err error) {
} }
func CopyExtended(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) { func CopyExtended(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
unsafeSrc, srcUnsafe := common.Cast[N.ThreadSafeReader](src) safeSrc := N.IsSafeReader(src)
headroom := N.CalculateFrontHeadroom(dst) + N.CalculateRearHeadroom(dst) headroom := N.CalculateFrontHeadroom(dst) + N.CalculateRearHeadroom(dst)
if srcUnsafe { if safeSrc != nil {
if headroom == 0 { if headroom == 0 {
return CopyExtendedWithSrcBuffer(dst, unsafeSrc) return CopyExtendedWithSrcBuffer(dst, safeSrc)
} }
} }
if N.IsUnsafeWriter(dst) { if N.IsUnsafeWriter(dst) {
@ -189,13 +189,15 @@ func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
} }
func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) { func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
unsafeSrc, srcUnsafe := common.Cast[N.ThreadSafePacketReader](src) src = N.UnwrapPacketReader(src)
dst = N.UnwrapPacketWriter(dst)
safeSrc := N.IsSafePacketReader(src)
frontHeadroom := N.CalculateFrontHeadroom(dst) frontHeadroom := N.CalculateFrontHeadroom(dst)
rearHeadroom := N.CalculateRearHeadroom(dst) rearHeadroom := N.CalculateRearHeadroom(dst)
headroom := frontHeadroom + rearHeadroom headroom := frontHeadroom + rearHeadroom
if srcUnsafe { if safeSrc != nil {
if headroom == 0 { if headroom == 0 {
return CopyPacketWithSrcBuffer(dst, unsafeSrc) return CopyPacketWithSrcBuffer(dst, safeSrc)
} }
} }
if N.IsUnsafeWriter(dst) { if N.IsUnsafeWriter(dst) {

View file

@ -77,12 +77,16 @@ type CachedReader interface {
ReadCached() *buf.Buffer ReadCached() *buf.Buffer
} }
type CachedPacketReader interface {
ReadCachedPacket() (destination M.Socksaddr, buffer *buf.Buffer)
}
type WithUpstreamReader interface { type WithUpstreamReader interface {
UpstreamReader() io.Reader UpstreamReader() any
} }
type WithUpstreamWriter interface { type WithUpstreamWriter interface {
UpstreamWriter() io.Writer UpstreamWriter() any
} }
type ReaderWithUpstream interface { type ReaderWithUpstream interface {
@ -98,7 +102,7 @@ func UnwrapReader(reader io.Reader) io.Reader {
return reader return reader
} }
if u, ok := reader.(WithUpstreamReader); ok { if u, ok := reader.(WithUpstreamReader); ok {
return UnwrapReader(u.UpstreamReader()) return UnwrapReader(u.UpstreamReader().(io.Reader))
} }
if u, ok := reader.(common.WithUpstream); ok { if u, ok := reader.(common.WithUpstream); ok {
return UnwrapReader(u.Upstream().(io.Reader)) return UnwrapReader(u.Upstream().(io.Reader))
@ -106,15 +110,41 @@ func UnwrapReader(reader io.Reader) io.Reader {
panic("bad reader") panic("bad reader")
} }
func UnwrapPacketReader(reader PacketReader) PacketReader {
if u, ok := reader.(ReaderWithUpstream); !ok || !u.ReaderReplaceable() {
return reader
}
if u, ok := reader.(WithUpstreamReader); ok {
return UnwrapPacketReader(u.UpstreamReader().(PacketReader))
}
if u, ok := reader.(common.WithUpstream); ok {
return UnwrapPacketReader(u.Upstream().(PacketReader))
}
panic("bad reader")
}
func UnwrapWriter(writer io.Writer) io.Writer { func UnwrapWriter(writer io.Writer) io.Writer {
if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() { if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() {
return writer return writer
} }
if u, ok := writer.(WithUpstreamWriter); ok { if u, ok := writer.(WithUpstreamWriter); ok {
return UnwrapWriter(u.UpstreamWriter()) return UnwrapWriter(u.UpstreamWriter().(io.Writer))
} }
if u, ok := writer.(common.WithUpstream); ok { if u, ok := writer.(common.WithUpstream); ok {
return UnwrapWriter(u.Upstream().(io.Writer)) return UnwrapWriter(u.Upstream().(io.Writer))
} }
panic("bad writer") panic("bad writer")
} }
func UnwrapPacketWriter(writer PacketWriter) PacketWriter {
if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() {
return writer
}
if u, ok := writer.(WithUpstreamWriter); ok {
return UnwrapPacketWriter(u.UpstreamWriter().(PacketWriter))
}
if u, ok := writer.(common.WithUpstream); ok {
return UnwrapPacketWriter(u.Upstream().(PacketWriter))
}
panic("bad writer")
}

View file

@ -23,6 +23,38 @@ func IsUnsafeWriter(writer any) bool {
return isUnsafe return isUnsafe
} }
func IsSafeReader(reader any) ThreadSafeReader {
if safeReader, isSafe := reader.(ThreadSafeReader); isSafe {
return safeReader
}
if upstream, hasUpstream := reader.(ReaderWithUpstream); !hasUpstream || !upstream.ReaderReplaceable() {
return nil
}
if upstream, hasUpstream := reader.(common.WithUpstream); hasUpstream {
return IsSafeReader(upstream.Upstream())
}
if upstream, hasUpstream := reader.(WithUpstreamReader); hasUpstream {
return IsSafeReader(upstream.UpstreamReader())
}
return nil
}
func IsSafePacketReader(reader any) ThreadSafePacketReader {
if safeReader, isSafe := reader.(ThreadSafePacketReader); isSafe {
return safeReader
}
if upstream, hasUpstream := reader.(ReaderWithUpstream); !hasUpstream || !upstream.ReaderReplaceable() {
return nil
}
if upstream, hasUpstream := reader.(common.WithUpstream); hasUpstream {
return IsSafePacketReader(upstream.Upstream())
}
if upstream, hasUpstream := reader.(WithUpstreamReader); hasUpstream {
return IsSafePacketReader(upstream.UpstreamReader())
}
return nil
}
type FrontHeadroom interface { type FrontHeadroom interface {
FrontHeadroom() int FrontHeadroom() int
} }
@ -53,7 +85,6 @@ func CalculateRearHeadroom(writer any) int {
if upstream, hasUpstream := writer.(common.WithUpstream); hasUpstream { if upstream, hasUpstream := writer.(common.WithUpstream); hasUpstream {
headroom += CalculateRearHeadroom(upstream.Upstream()) headroom += CalculateRearHeadroom(upstream.Upstream())
} }
if upstream, hasUpstream := writer.(WithUpstreamWriter); hasUpstream { if upstream, hasUpstream := writer.(WithUpstreamWriter); hasUpstream {
headroom += CalculateRearHeadroom(upstream.UpstreamWriter()) headroom += CalculateRearHeadroom(upstream.UpstreamWriter())
} }