sing/common/bufio/conn.go
2022-06-28 22:16:48 +08:00

499 lines
11 KiB
Go

package bufio
import (
"context"
"io"
"net"
"os"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/task"
)
// readOnlyReader embeds an io.Reader and adds a WriteTo built on Copy.
// Copy wraps its source in this type when needReadFromWrapper reports true,
// just before handing the source to the destination's ReadFrom.
type readOnlyReader struct {
	io.Reader
}

// WriteTo copies everything from the embedded reader into w by delegating
// to Copy, and returns Copy's byte count and error unchanged.
func (r *readOnlyReader) WriteTo(w io.Writer) (n int64, err error) {
	n, err = Copy(w, r.Reader)
	return n, err
}
// needReadFromWrapper reports whether src should be wrapped in a
// readOnlyReader before being passed to dst.ReadFrom.
//
// It returns true only when dst is a *net.TCPConn and src is none of
// *net.TCPConn, *net.UnixConn or *os.File.
// NOTE(review): presumably those are the source types for which
// TCPConn.ReadFrom has its own fast path - confirm against the net package.
func needReadFromWrapper(dst io.ReaderFrom, src io.Reader) bool {
	switch dst.(type) {
	case *net.TCPConn:
		// Only TCP destinations are candidates for wrapping.
	default:
		return false
	}
	switch src.(type) {
	case *net.TCPConn, *net.UnixConn, *os.File:
		return false
	}
	return true
}
// Copy transfers data from src to dst and returns the number of bytes
// copied together with the error that ended the transfer.
//
// Both ends are first passed through N.UnwrapReader / N.UnwrapWriter.
// The strategy is then chosen in this order:
//  1. src implements io.WriterTo: src.WriteTo(dst).
//  2. dst implements io.ReaderFrom: dst.ReadFrom(src), with src wrapped in
//     a readOnlyReader when needReadFromWrapper asks for it.
//  3. otherwise: CopyExtended over the extended reader/writer adapters.
func Copy(dst io.Writer, src io.Reader) (n int64, err error) {
	src = N.UnwrapReader(src)
	dst = N.UnwrapWriter(dst)

	if writerTo, isWriterTo := src.(io.WriterTo); isWriterTo {
		return writerTo.WriteTo(dst)
	}

	readerFrom, isReaderFrom := dst.(io.ReaderFrom)
	if !isReaderFrom {
		return CopyExtended(NewExtendedWriter(dst), NewExtendedReader(src))
	}
	if needReadFromWrapper(readerFrom, src) {
		src = &readOnlyReader{src}
	}
	return readerFrom.ReadFrom(src)
}
// CopyExtended copies from src to dst using the buffer-based extended API.
//
// If src can be cast to N.ThreadSafeReader the copy is delegated to
// CopyExtendedWithSrcBuffer. Otherwise, if dst can be cast to
// N.ThreadUnsafeWriter it is delegated to CopyExtendedWithPool. In the
// remaining case a single buffer obtained from buf.StackNew is reused for
// the whole transfer through CopyExtendedBuffer.
func CopyExtended(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
	safeSrc, srcIsThreadSafe := common.Cast[N.ThreadSafeReader](src)
	_, dstIsThreadUnsafe := common.Cast[N.ThreadUnsafeWriter](dst)
	switch {
	case srcIsThreadSafe:
		return CopyExtendedWithSrcBuffer(dst, safeSrc)
	case dstIsThreadUnsafe:
		return CopyExtendedWithPool(dst, src)
	}

	// NOTE(review): StackNew + Dup + KeepAlive looks like a stack-allocation
	// idiom of the buf/common packages - confirm there before reordering.
	// The deferred calls run in reverse order: Release first, then KeepAlive.
	stackBuffer := buf.StackNew()
	defer common.KeepAlive(stackBuffer)
	buffer := common.Dup(stackBuffer)
	defer buffer.Release()
	return CopyExtendedBuffer(dst, src, buffer)
}
// CopyExtendedBuffer copies from src to dst, reusing the caller-supplied
// buffer for every iteration. It loops until a read or write fails and
// returns the bytes successfully written plus that error.
//
// The buffer's reference count is raised for the duration of the loop and
// lowered again on return.
// NOTE(review): presumably the extra reference keeps a Release issued by
// dst.WriteBuffer from recycling the buffer - confirm in package buf.
func CopyExtendedBuffer(dst N.ExtendedWriter, src N.ExtendedReader, buffer *buf.Buffer) (n int64, err error) {
	buffer.IncRef()
	defer buffer.DecRef()
	for {
		buffer.Reset()
		if err = src.ReadBuffer(buffer); err != nil {
			return
		}
		// Capture the length before the write; the writer receives the buffer itself.
		size := buffer.Len()
		if err = dst.WriteBuffer(buffer); err != nil {
			return
		}
		n += int64(size)
	}
}
// CopyExtendedWithSrcBuffer copies from src to dst using buffers handed out
// by the source itself (src.ReadBufferThreadSafe). It loops until a read or
// write fails and returns the bytes successfully written plus that error.
//
// A buffer is released here only when the write fails; on success it has
// been handed to dst.WriteBuffer and is not touched again by this function.
func CopyExtendedWithSrcBuffer(dst N.ExtendedWriter, src N.ThreadSafeReader) (n int64, err error) {
	for {
		buffer, readErr := src.ReadBufferThreadSafe()
		if readErr != nil {
			return n, readErr
		}
		size := buffer.Len()
		if writeErr := dst.WriteBuffer(buffer); writeErr != nil {
			buffer.Release()
			return n, writeErr
		}
		n += int64(size)
	}
}
// CopyExtendedWithPool copies from src to dst, obtaining a fresh buffer from
// buf.New for every iteration. It loops until a read or write fails and
// returns the bytes successfully written plus that error.
//
// The buffer is released here on either failure; after a successful write
// it has been handed to dst.WriteBuffer and is not touched again.
func CopyExtendedWithPool(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
	for {
		buffer := buf.New()
		if err = src.ReadBuffer(buffer); err != nil {
			buffer.Release()
			return
		}
		size := buffer.Len()
		if err = dst.WriteBuffer(buffer); err != nil {
			buffer.Release()
			return
		}
		n += int64(size)
	}
}
// CopyConn relays data in both directions between conn and dest through
// task.Run and returns its error. Both connections are closed when CopyConn
// returns.
//
// When one direction finishes, its source is closed for reading and its
// sink for writing, so each direction shuts down its own half.
func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
	defer common.Close(conn, dest)

	// conn -> dest
	forward := func() error {
		defer rw.CloseRead(conn)
		defer rw.CloseWrite(dest)
		return common.Error(Copy(dest, conn))
	}
	// dest -> conn
	backward := func() error {
		defer rw.CloseRead(dest)
		defer rw.CloseWrite(conn)
		return common.Error(Copy(conn, dest))
	}
	return task.Run(ctx, forward, backward)
}
// CopyPacket forwards packets from src to dst until a read or write fails,
// returning the total payload bytes written and that error.
//
// If src can be cast to N.ThreadSafePacketReader the work is delegated to
// CopyPacketWithSrcBuffer. Otherwise, if dst can be cast to
// N.ThreadUnsafeWriter it is delegated to CopyPacketWithPool. In the
// remaining case one buffer from buf.StackNewPacket is reused for every
// packet.
func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
	safeSrc, srcIsThreadSafe := common.Cast[N.ThreadSafePacketReader](src)
	_, dstIsThreadUnsafe := common.Cast[N.ThreadUnsafeWriter](dst)
	switch {
	case srcIsThreadSafe:
		return CopyPacketWithSrcBuffer(dst, safeSrc)
	case dstIsThreadUnsafe:
		return CopyPacketWithPool(dst, src)
	}

	// NOTE(review): StackNewPacket + Dup + KeepAlive looks like a
	// stack-allocation idiom of the buf/common packages - confirm there.
	// Deferred calls run in reverse: DecRef, Release, KeepAlive.
	stackBuffer := buf.StackNewPacket()
	defer common.KeepAlive(stackBuffer)
	buffer := common.Dup(stackBuffer)
	defer buffer.Release()
	buffer.IncRef()
	defer buffer.DecRef()

	for {
		buffer.Reset()
		destination, readErr := src.ReadPacket(buffer)
		if readErr != nil {
			return n, readErr
		}
		size := buffer.Len()
		if err = dst.WritePacket(buffer, destination); err != nil {
			return
		}
		n += int64(size)
	}
}
// CopyPacketTimeout behaves like CopyPacket but sets a read deadline of
// now+timeout on src before every read, so the copy ends with an error once
// setting the deadline or the read itself fails.
//
// If src can be cast to N.ThreadSafePacketReader the work is delegated to
// CopyPacketWithSrcBufferTimeout. Otherwise, if dst can be cast to
// N.ThreadUnsafeWriter it is delegated to CopyPacketWithPoolTimeout. In the
// remaining case one buffer from buf.StackNewPacket is reused for every
// packet.
func CopyPacketTimeout(dst N.PacketWriter, src N.TimeoutPacketReader, timeout time.Duration) (n int64, err error) {
	safeSrc, srcIsThreadSafe := common.Cast[N.ThreadSafePacketReader](src)
	_, dstIsThreadUnsafe := common.Cast[N.ThreadUnsafeWriter](dst)
	switch {
	case srcIsThreadSafe:
		return CopyPacketWithSrcBufferTimeout(dst, safeSrc, src, timeout)
	case dstIsThreadUnsafe:
		return CopyPacketWithPoolTimeout(dst, src, timeout)
	}

	// NOTE(review): StackNewPacket + Dup + KeepAlive looks like a
	// stack-allocation idiom of the buf/common packages - confirm there.
	// Deferred calls run in reverse: DecRef, Release, KeepAlive.
	stackBuffer := buf.StackNewPacket()
	defer common.KeepAlive(stackBuffer)
	buffer := common.Dup(stackBuffer)
	defer buffer.Release()
	buffer.IncRef()
	defer buffer.DecRef()

	for {
		buffer.Reset()
		if err = src.SetReadDeadline(time.Now().Add(timeout)); err != nil {
			return
		}
		destination, readErr := src.ReadPacket(buffer)
		if readErr != nil {
			return n, readErr
		}
		size := buffer.Len()
		if err = dst.WritePacket(buffer, destination); err != nil {
			return
		}
		n += int64(size)
	}
}
// CopyPacketWithSrcBuffer forwards packets from src to dest using buffers
// handed out by the source (src.ReadPacketThreadSafe). It loops until a
// read or write fails and returns the payload bytes written plus that error.
//
// A buffer is released here only when the write fails; on success it has
// been handed to dest.WritePacket and is not touched again.
func CopyPacketWithSrcBuffer(dest N.PacketWriter, src N.ThreadSafePacketReader) (n int64, err error) {
	for {
		buffer, destination, readErr := src.ReadPacketThreadSafe()
		if readErr != nil {
			return n, readErr
		}
		size := buffer.Len()
		if writeErr := dest.WritePacket(buffer, destination); writeErr != nil {
			buffer.Release()
			return n, writeErr
		}
		n += int64(size)
	}
}
// CopyPacketWithSrcBufferTimeout is CopyPacketWithSrcBuffer with a read
// deadline: before every read it calls tSrc.SetReadDeadline(now+timeout).
// src supplies the packets and tSrc the deadline control.
//
// The loop ends on the first failure of SetReadDeadline, the read or the
// write; the buffer is released here only when the write fails.
func CopyPacketWithSrcBufferTimeout(dest N.PacketWriter, src N.ThreadSafePacketReader, tSrc N.TimeoutPacketReader, timeout time.Duration) (n int64, err error) {
	for {
		if err = tSrc.SetReadDeadline(time.Now().Add(timeout)); err != nil {
			return
		}
		buffer, destination, readErr := src.ReadPacketThreadSafe()
		if readErr != nil {
			return n, readErr
		}
		size := buffer.Len()
		if writeErr := dest.WritePacket(buffer, destination); writeErr != nil {
			buffer.Release()
			return n, writeErr
		}
		n += int64(size)
	}
}
// CopyPacketWithPool forwards packets from src to dest, taking a fresh
// buffer from buf.NewPacket for every packet. It loops until a read or
// write fails and returns the payload bytes written plus that error.
//
// The buffer is released here on either failure; after a successful write
// it has been handed to dest.WritePacket and is not touched again.
func CopyPacketWithPool(dest N.PacketWriter, src N.PacketReader) (n int64, err error) {
	for {
		buffer := buf.NewPacket()
		destination, readErr := src.ReadPacket(buffer)
		if readErr != nil {
			buffer.Release()
			return n, readErr
		}
		size := buffer.Len()
		if writeErr := dest.WritePacket(buffer, destination); writeErr != nil {
			buffer.Release()
			return n, writeErr
		}
		n += int64(size)
	}
}
// CopyPacketWithPoolTimeout forwards packets from src to dest, taking a
// fresh buffer from buf.NewPacket for every packet and setting a read
// deadline of now+timeout on src before every read.
//
// It loops until SetReadDeadline, the read or the write fails and returns
// the payload bytes written plus that error. The buffer is released here on
// a read or write failure; after a successful write it has been handed to
// dest.WritePacket and is not touched again.
func CopyPacketWithPoolTimeout(dest N.PacketWriter, src N.TimeoutPacketReader, timeout time.Duration) (n int64, err error) {
	var destination M.Socksaddr
	for {
		// Arm the deadline before taking a buffer: previously the buffer was
		// allocated first and never released when SetReadDeadline failed.
		err = src.SetReadDeadline(time.Now().Add(timeout))
		if err != nil {
			return
		}
		buffer := buf.NewPacket()
		destination, err = src.ReadPacket(buffer)
		if err != nil {
			buffer.Release()
			return
		}
		dataLen := buffer.Len()
		err = dest.WritePacket(buffer, destination)
		if err != nil {
			buffer.Release()
			return
		}
		n += int64(dataLen)
	}
}
// CopyPacketConn relays packets in both directions between conn and dest
// through task.Any and returns its error. Both connections are closed when
// CopyPacketConn returns.
func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error {
	defer common.Close(conn, dest)

	// conn -> dest
	forward := func() error {
		return common.Error(CopyPacket(dest, conn))
	}
	// dest -> conn
	backward := func() error {
		return common.Error(CopyPacket(conn, dest))
	}
	return task.Any(ctx, forward, backward)
}
// CopyPacketConnTimeout relays packets in both directions between conn and
// dest like CopyPacketConn, but each direction uses CopyPacketTimeout with
// the given read timeout. Both connections are closed on return.
func CopyPacketConnTimeout(ctx context.Context, conn N.PacketConn, dest N.PacketConn, timeout time.Duration) error {
	defer common.Close(conn, dest)

	// conn -> dest
	forward := func() error {
		return common.Error(CopyPacketTimeout(dest, conn, timeout))
	}
	// dest -> conn
	backward := func() error {
		return common.Error(CopyPacketTimeout(conn, dest, timeout))
	}
	return task.Any(ctx, forward, backward)
}
// NewPacketConn returns conn as an N.NetPacketConn.
//
// A conn that already implements N.NetPacketConn is returned unchanged, a
// *net.UDPConn is wrapped in ExtendedUDPConn, and anything else is wrapped
// in ExtendedPacketConn.
func NewPacketConn(conn net.PacketConn) N.NetPacketConn {
	switch typed := conn.(type) {
	case N.NetPacketConn:
		return typed
	case *net.UDPConn:
		return &ExtendedUDPConn{typed}
	default:
		return &ExtendedPacketConn{conn}
	}
}
// ExtendedUDPConn adds buffer-based ReadPacket / WritePacket methods to a
// *net.UDPConn.
type ExtendedUDPConn struct {
	*net.UDPConn
}

// ReadPacket reads one datagram into the free space of buffer and returns
// the sender address. On error the zero M.Socksaddr is returned and the
// buffer is left as the read left it.
func (w *ExtendedUDPConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
	n, sender, err := w.UDPConn.ReadFromUDPAddrPort(buffer.FreeBytes())
	if err != nil {
		return M.Socksaddr{}, err
	}
	buffer.Truncate(n)
	return M.SocksaddrFromNetIP(sender), nil
}

// WritePacket sends the contents of buffer to destination and releases the
// buffer before returning, whether or not the send succeeded.
//
// A domain-name destination is resolved with net.ResolveUDPAddr first; any
// other destination is used directly via destination.UDPAddr().
func (w *ExtendedUDPConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
	defer buffer.Release()
	if !destination.Family().IsFqdn() {
		return common.Error(w.UDPConn.WriteToUDP(buffer.Bytes(), destination.UDPAddr()))
	}
	resolved, err := net.ResolveUDPAddr("udp", destination.String())
	if err != nil {
		return err
	}
	return common.Error(w.UDPConn.WriteTo(buffer.Bytes(), resolved))
}

// Upstream returns the wrapped *net.UDPConn.
func (w *ExtendedUDPConn) Upstream() any {
	return w.UDPConn
}
// ExtendedPacketConn adds buffer-based ReadPacket / WritePacket methods to
// an arbitrary net.PacketConn.
type ExtendedPacketConn struct {
	net.PacketConn
}

// ReadPacket reads one packet into buffer via buffer.ReadPacketFrom and
// returns the sender address. On error the zero M.Socksaddr is returned.
func (w *ExtendedPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
	_, sender, err := buffer.ReadPacketFrom(w)
	if err != nil {
		return M.Socksaddr{}, err
	}
	return M.SocksaddrFromNet(sender), nil
}

// WritePacket sends the contents of buffer to destination.UDPAddr() and
// releases the buffer before returning, whether or not the send succeeded.
func (w *ExtendedPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
	defer buffer.Release()
	payload := buffer.Bytes()
	return common.Error(w.WriteTo(payload, destination.UDPAddr()))
}

// Upstream returns the wrapped net.PacketConn.
func (w *ExtendedPacketConn) Upstream() any {
	return w.PacketConn
}
// BindPacketConn pairs a net.PacketConn with one fixed peer address and
// adds Read, Write and RemoteAddr methods on top of it.
type BindPacketConn struct {
	net.PacketConn
	// Addr is the target of every Write and the value reported by RemoteAddr.
	Addr net.Addr
}

// Read receives one packet through ReadFrom and discards the sender address.
func (c *BindPacketConn) Read(b []byte) (n int, err error) {
	n, _, err = c.PacketConn.ReadFrom(b)
	return n, err
}

// Write sends b to the bound address Addr.
func (c *BindPacketConn) Write(b []byte) (n int, err error) {
	n, err = c.PacketConn.WriteTo(b, c.Addr)
	return n, err
}

// RemoteAddr returns the bound address Addr.
func (c *BindPacketConn) RemoteAddr() net.Addr {
	return c.Addr
}

// Upstream returns the wrapped net.PacketConn.
func (c *BindPacketConn) Upstream() any {
	return c.PacketConn
}
// UnbindPacketConn exposes an N.ExtendedConn through packet-style methods.
// Every packet read is attributed to Addr, and the destination passed to
// the write methods is ignored.
type UnbindPacketConn struct {
	N.ExtendedConn
	// Addr is reported as the source of every packet read.
	Addr M.Socksaddr
}

// ReadFrom reads from the underlying connection into p. On success the
// returned address is c.Addr.UDPAddr(); on error it is nil.
func (c *UnbindPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
	n, err = c.ExtendedConn.Read(p)
	if err != nil {
		return n, nil, err
	}
	return n, c.Addr.UDPAddr(), nil
}

// WriteTo writes p to the underlying connection; the address is ignored.
func (c *UnbindPacketConn) WriteTo(p []byte, _ net.Addr) (n int, err error) {
	n, err = c.ExtendedConn.Write(p)
	return n, err
}

// ReadPacket reads into buffer from the underlying connection. On success
// the returned destination is c.Addr; on error it is the zero value.
func (c *UnbindPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
	if err = c.ExtendedConn.ReadBuffer(buffer); err != nil {
		return
	}
	return c.Addr, nil
}

// WritePacket writes buffer to the underlying connection; the destination
// is ignored.
func (c *UnbindPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error {
	return c.ExtendedConn.WriteBuffer(buffer)
}

// Upstream returns the wrapped N.ExtendedConn.
func (c *UnbindPacketConn) Upstream() any {
	return c.ExtendedConn
}
// NewUnbindPacketConn wraps conn in an UnbindPacketConn whose Addr is the
// connection's remote address, converting conn with NewExtendedConn first.
func NewUnbindPacketConn(conn net.Conn) N.NetPacketConn {
	extended := NewExtendedConn(conn)
	remote := M.SocksaddrFromNet(conn.RemoteAddr())
	return &UnbindPacketConn{
		ExtendedConn: extended,
		Addr:         remote,
	}
}
// ExtendedReaderWrapper adapts a plain io.Reader to the buffer-based
// ReadBuffer API.
type ExtendedReaderWrapper struct {
	io.Reader
}

// ReadBuffer performs a single Read into the free space of buffer and marks
// the n bytes received with buffer.Truncate(n).
//
// The io.Reader contract allows Read to return n > 0 together with a
// non-nil error (for example io.EOF alongside the final chunk). Those bytes
// are delivered first and nil is returned; the error is returned only when
// no data came with it. Previously any error caused the bytes read in the
// same call to be dropped.
// NOTE(review): this relies on the reader reporting the error again on the
// next Read, as io.Reader implementations conventionally do.
func (r *ExtendedReaderWrapper) ReadBuffer(buffer *buf.Buffer) error {
	n, err := r.Read(buffer.FreeBytes())
	if n == 0 && err != nil {
		return err
	}
	buffer.Truncate(n)
	return nil
}

// Upstream returns the wrapped io.Reader.
func (r *ExtendedReaderWrapper) Upstream() any {
	return r.Reader
}

// ReaderReplaceable always reports true.
// NOTE(review): presumably this tells unwrapping helpers that the wrapper
// can be bypassed in favour of Upstream - confirm in package network.
func (r *ExtendedReaderWrapper) ReaderReplaceable() bool {
	return true
}
// NewExtendedReader returns reader as an N.ExtendedReader: unchanged when
// it already implements the interface, otherwise wrapped in an
// ExtendedReaderWrapper.
func NewExtendedReader(reader io.Reader) N.ExtendedReader {
	extended, ok := reader.(N.ExtendedReader)
	if !ok {
		extended = &ExtendedReaderWrapper{reader}
	}
	return extended
}
// ExtendedWriterWrapper adapts a plain io.Writer to the buffer-based
// WriteBuffer API.
type ExtendedWriterWrapper struct {
	io.Writer
}

// WriteBuffer writes the contents of buffer to the wrapped writer and
// releases the buffer before returning, whether or not the write succeeded.
//
// Releasing here matches the packet writers in this file
// (ExtendedUDPConn.WritePacket, ExtendedPacketConn.WritePacket). The copy
// loops release a buffer themselves only when the write fails, so without
// this a successfully written buffer was never released.
func (w *ExtendedWriterWrapper) WriteBuffer(buffer *buf.Buffer) error {
	defer buffer.Release()
	return common.Error(w.Write(buffer.Bytes()))
}

// Upstream returns the wrapped io.Writer.
func (w *ExtendedWriterWrapper) Upstream() any {
	return w.Writer
}

// WriterReplaceable always reports true.
//
// The method belongs on the writer wrapper; it was previously declared on
// *ExtendedReaderWrapper by mistake, so ExtendedWriterWrapper never had it.
// NOTE(review): presumably this tells unwrapping helpers that the wrapper
// can be bypassed in favour of Upstream - confirm in package network.
func (w *ExtendedWriterWrapper) WriterReplaceable() bool {
	return true
}
// NewExtendedWriter returns writer as an N.ExtendedWriter: unchanged when
// it already implements the interface, otherwise wrapped in an
// ExtendedWriterWrapper.
func NewExtendedWriter(writer io.Writer) N.ExtendedWriter {
	extended, ok := writer.(N.ExtendedWriter)
	if !ok {
		extended = &ExtendedWriterWrapper{writer}
	}
	return extended
}
// ExtendedConnWrapper combines a net.Conn with an extended reader and an
// extended writer so the connection offers ReadBuffer and WriteBuffer.
type ExtendedConnWrapper struct {
	net.Conn
	// reader serves ReadBuffer.
	reader N.ExtendedReader
	// writer serves WriteBuffer.
	writer N.ExtendedWriter
}

// ReadBuffer reads into buffer through the stored extended reader.
func (w *ExtendedConnWrapper) ReadBuffer(buffer *buf.Buffer) error {
	err := w.reader.ReadBuffer(buffer)
	return err
}

// WriteBuffer writes buffer through the stored extended writer.
func (w *ExtendedConnWrapper) WriteBuffer(buffer *buf.Buffer) error {
	err := w.writer.WriteBuffer(buffer)
	return err
}
// NewExtendedConn returns conn as an N.ExtendedConn: unchanged when it
// already implements the interface, otherwise wrapped in an
// ExtendedConnWrapper whose reader and writer are built from conn with
// NewExtendedReader and NewExtendedWriter.
func NewExtendedConn(conn net.Conn) N.ExtendedConn {
	if extended, ok := conn.(N.ExtendedConn); ok {
		return extended
	}
	wrapper := &ExtendedConnWrapper{
		Conn:   conn,
		reader: NewExtendedReader(conn),
		writer: NewExtendedWriter(conn),
	}
	return wrapper
}