Fix copy with src buffer

This commit is contained in:
世界 2022-08-21 08:40:38 +08:00
parent b5294f57ba
commit 5945fd0457
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
5 changed files with 82 additions and 13 deletions

View file

@ -164,6 +164,12 @@ func (c *CachedPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksad
return c.PacketConn.ReadPacket(buffer)
}
func (c *CachedPacketConn) ReadCachedPacket() (destination M.Socksaddr, buffer *buf.Buffer) {
buffer = c.buffer
c.buffer = nil
return c.destination, buffer
}
func (c *CachedPacketConn) Upstream() any {
return c.PacketConn
}

View file

@ -215,7 +215,7 @@ func (r *ExtendedConnWrapper) WriteTo(w io.Writer) (n int64, err error) {
return Copy(w, r.reader)
}
func (w *ExtendedConnWrapper) UpstreamReader() io.Reader {
func (w *ExtendedConnWrapper) UpstreamReader() any {
return w.reader
}
@ -223,7 +223,7 @@ func (w *ExtendedConnWrapper) ReaderReplaceable() bool {
return true
}
func (w *ExtendedConnWrapper) UpstreamWriter() io.Writer {
func (w *ExtendedConnWrapper) UpstreamWriter() any {
return w.writer
}

View file

@ -57,11 +57,11 @@ func Copy(dst io.Writer, src io.Reader) (n int64, err error) {
}
func CopyExtended(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
unsafeSrc, srcUnsafe := common.Cast[N.ThreadSafeReader](src)
safeSrc := N.IsSafeReader(src)
headroom := N.CalculateFrontHeadroom(dst) + N.CalculateRearHeadroom(dst)
if srcUnsafe {
if safeSrc != nil {
if headroom == 0 {
return CopyExtendedWithSrcBuffer(dst, unsafeSrc)
return CopyExtendedWithSrcBuffer(dst, safeSrc)
}
}
if N.IsUnsafeWriter(dst) {
@ -189,13 +189,15 @@ func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
}
func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
unsafeSrc, srcUnsafe := common.Cast[N.ThreadSafePacketReader](src)
src = N.UnwrapPacketReader(src)
dst = N.UnwrapPacketWriter(dst)
safeSrc := N.IsSafePacketReader(src)
frontHeadroom := N.CalculateFrontHeadroom(dst)
rearHeadroom := N.CalculateRearHeadroom(dst)
headroom := frontHeadroom + rearHeadroom
if srcUnsafe {
if safeSrc != nil {
if headroom == 0 {
return CopyPacketWithSrcBuffer(dst, unsafeSrc)
return CopyPacketWithSrcBuffer(dst, safeSrc)
}
}
if N.IsUnsafeWriter(dst) {

View file

@ -77,12 +77,16 @@ type CachedReader interface {
ReadCached() *buf.Buffer
}
type CachedPacketReader interface {
ReadCachedPacket() (destination M.Socksaddr, buffer *buf.Buffer)
}
type WithUpstreamReader interface {
UpstreamReader() io.Reader
UpstreamReader() any
}
type WithUpstreamWriter interface {
UpstreamWriter() io.Writer
UpstreamWriter() any
}
type ReaderWithUpstream interface {
@ -98,7 +102,7 @@ func UnwrapReader(reader io.Reader) io.Reader {
return reader
}
if u, ok := reader.(WithUpstreamReader); ok {
return UnwrapReader(u.UpstreamReader())
return UnwrapReader(u.UpstreamReader().(io.Reader))
}
if u, ok := reader.(common.WithUpstream); ok {
return UnwrapReader(u.Upstream().(io.Reader))
@ -106,15 +110,41 @@ func UnwrapReader(reader io.Reader) io.Reader {
panic("bad reader")
}
func UnwrapPacketReader(reader PacketReader) PacketReader {
if u, ok := reader.(ReaderWithUpstream); !ok || !u.ReaderReplaceable() {
return reader
}
if u, ok := reader.(WithUpstreamReader); ok {
return UnwrapPacketReader(u.UpstreamReader().(PacketReader))
}
if u, ok := reader.(common.WithUpstream); ok {
return UnwrapPacketReader(u.Upstream().(PacketReader))
}
panic("bad reader")
}
func UnwrapWriter(writer io.Writer) io.Writer {
if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() {
return writer
}
if u, ok := writer.(WithUpstreamWriter); ok {
return UnwrapWriter(u.UpstreamWriter())
return UnwrapWriter(u.UpstreamWriter().(io.Writer))
}
if u, ok := writer.(common.WithUpstream); ok {
return UnwrapWriter(u.Upstream().(io.Writer))
}
panic("bad writer")
}
func UnwrapPacketWriter(writer PacketWriter) PacketWriter {
if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() {
return writer
}
if u, ok := writer.(WithUpstreamWriter); ok {
return UnwrapPacketWriter(u.UpstreamWriter().(PacketWriter))
}
if u, ok := writer.(common.WithUpstream); ok {
return UnwrapPacketWriter(u.Upstream().(PacketWriter))
}
panic("bad writer")
}

View file

@ -23,6 +23,38 @@ func IsUnsafeWriter(writer any) bool {
return isUnsafe
}
func IsSafeReader(reader any) ThreadSafeReader {
if safeReader, isSafe := reader.(ThreadSafeReader); isSafe {
return safeReader
}
if upstream, hasUpstream := reader.(ReaderWithUpstream); !hasUpstream || !upstream.ReaderReplaceable() {
return nil
}
if upstream, hasUpstream := reader.(common.WithUpstream); hasUpstream {
return IsSafeReader(upstream.Upstream())
}
if upstream, hasUpstream := reader.(WithUpstreamReader); hasUpstream {
return IsSafeReader(upstream.UpstreamReader())
}
return nil
}
func IsSafePacketReader(reader any) ThreadSafePacketReader {
if safeReader, isSafe := reader.(ThreadSafePacketReader); isSafe {
return safeReader
}
if upstream, hasUpstream := reader.(ReaderWithUpstream); !hasUpstream || !upstream.ReaderReplaceable() {
return nil
}
if upstream, hasUpstream := reader.(common.WithUpstream); hasUpstream {
return IsSafePacketReader(upstream.Upstream())
}
if upstream, hasUpstream := reader.(WithUpstreamReader); hasUpstream {
return IsSafePacketReader(upstream.UpstreamReader())
}
return nil
}
type FrontHeadroom interface {
FrontHeadroom() int
}
@ -53,7 +85,6 @@ func CalculateRearHeadroom(writer any) int {
if upstream, hasUpstream := writer.(common.WithUpstream); hasUpstream {
headroom += CalculateRearHeadroom(upstream.Upstream())
}
if upstream, hasUpstream := writer.(WithUpstreamWriter); hasUpstream {
headroom += CalculateRearHeadroom(upstream.UpstreamWriter())
}