sing/common/rw/copy.go
2022-04-11 12:46:23 +08:00

87 lines
1.8 KiB
Go

package rw
import (
"context"
"io"
"net"
"os"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/task"
)
// ReadFromVar copies everything from reader into the writer stored at
// *writerVar. It delegates to the first io.ReaderFrom found by walking the
// writer's upstream chain through common.WriterWithUpstream.
//
// Each visited writer that implements common.Flusher is flushed before the
// walk moves past it. This way, data it has buffered is written before the
// data copied by ReadFrom.
//
// While every wrapper walked so far reports Replaceable(), *writerVar is
// updated to point at that wrapper's upstream. Later writes through the
// caller's variable then skip the removed wrappers.
//
// It returns os.ErrInvalid if no writer in the chain implements
// io.ReaderFrom.
func ReadFromVar(writerVar *io.Writer, reader io.Reader) (int64, error) {
	writer := *writerVar
	// replaceable is true while every wrapper between the caller's original
	// writer and the current one is replaceable. Only then may the current
	// writer be stored back into *writerVar.
	replaceable := true
	for {
		if w, ok := writer.(io.ReaderFrom); ok {
			return w.ReadFrom(reader)
		}
		// Flush before bypassing this writer so that its buffered bytes are
		// not reordered after the bytes written directly to its upstream.
		if f, ok := writer.(common.Flusher); ok {
			if err := f.Flush(); err != nil {
				return 0, err
			}
		}
		u, ok := writer.(common.WriterWithUpstream)
		if !ok {
			return 0, os.ErrInvalid
		}
		upstream := u.Upstream()
		if upstream == nil {
			// Nothing to delegate to; never store a nil writer in the
			// caller's variable.
			return 0, os.ErrInvalid
		}
		if replaceable && u.Replaceable() {
			// Write through the pointer so the caller's variable drops the
			// replaceable wrapper, not just this function's local copy.
			*writerVar = upstream
		} else {
			replaceable = false
		}
		writer = upstream
	}
}
// CopyConn copies data between conn and dest in both directions, each
// direction in its own task run by task.Run under the caller's ctx. How
// errors and cancellation end the copy is up to task.Run.
//
// When a direction finishes, it calls CloseWrite on its destination and
// CloseRead on its source, defined elsewhere in this package. These are
// presumably half-close helpers; confirm against their definitions.
func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
	// Pass ctx through (not context.Background()) so the caller's
	// cancellation reaches task.Run, matching CopyPacketConn.
	return task.Run(ctx, func() error {
		defer CloseRead(conn)
		defer CloseWrite(dest)
		return common.Error(io.Copy(dest, conn))
	}, func() error {
		defer CloseRead(dest)
		defer CloseWrite(conn)
		return common.Error(io.Copy(conn, dest))
	})
}
func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net.PacketConn) error {
return task.Run(ctx, func() error {
buffer := buf.FullNew()
defer buffer.Release()
for {
n, addr, err := conn.ReadFrom(buffer.FreeBytes())
if err != nil {
return err
}
buffer.Truncate(n)
_, err = outPacketConn.WriteTo(buffer.Bytes(), addr)
if err != nil {
return err
}
buffer.FullReset()
}
}, func() error {
buffer := buf.FullNew()
defer buffer.Release()
for {
n, addr, err := outPacketConn.ReadFrom(buffer.FreeBytes())
if err != nil {
return err
}
buffer.Truncate(n)
_, err = conn.WriteTo(buffer.Bytes(), addr)
if err != nil {
return err
}
buffer.FullReset()
}
})
}