sing/common/uot/conn.go
2023-03-15 13:10:16 +08:00

144 lines
3.3 KiB
Go

package uot
import (
"encoding/binary"
"io"
"net"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type Conn struct {
net.Conn
isConnect bool
destination M.Socksaddr
writer N.VectorisedWriter
}
func NewConn(conn net.Conn, isConnect bool, destination M.Socksaddr) *Conn {
uConn := &Conn{
Conn: conn,
isConnect: isConnect,
destination: destination,
}
uConn.writer, _ = bufio.CreateVectorisedWriter(conn)
return uConn
}
func (c *Conn) Read(p []byte) (n int, err error) {
n, _, err = c.ReadFrom(p)
return
}
func (c *Conn) Write(p []byte) (n int, err error) {
return c.WriteTo(p, c.destination)
}
func (c *Conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
var destination M.Socksaddr
if c.isConnect {
destination = c.destination
} else {
destination, err = AddrParser.ReadAddrPort(c.Conn)
if err != nil {
return
}
}
var length uint16
err = binary.Read(c.Conn, binary.BigEndian, &length)
if err != nil {
return
}
if len(p) < int(length) {
err = E.Cause(io.ErrShortBuffer, "UoT read")
return
}
n, err = c.Conn.Read(p[:length])
if err == nil {
addr = destination.UDPAddr()
}
return
}
func (c *Conn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
destination := M.SocksaddrFromNet(addr)
var bufferLen int
if !c.isConnect {
bufferLen += AddrParser.AddrPortLen(destination)
}
bufferLen += 2
if c.writer == nil {
bufferLen += len(p)
}
_buffer := buf.StackNewSize(bufferLen)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
if !c.isConnect {
common.Must(AddrParser.WriteAddrPort(buffer, destination))
}
common.Must(binary.Write(buffer, binary.BigEndian, uint16(len(p))))
if c.writer == nil {
common.Must1(buffer.Write(p))
return c.Conn.Write(buffer.Bytes())
}
err = c.writer.WriteVectorised([]*buf.Buffer{buffer, buf.As(p)})
if err == nil {
n = len(p)
}
return
}
func (c *Conn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
if c.isConnect {
destination = c.destination
} else {
destination, err = AddrParser.ReadAddrPort(c.Conn)
if err != nil {
return
}
}
var length uint16
err = binary.Read(c.Conn, binary.BigEndian, &length)
if err != nil {
return
}
_, err = buffer.ReadFullFrom(c.Conn, int(length))
if err != nil {
return M.Socksaddr{}, E.Cause(err, "UoT read")
}
return
}
func (c *Conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
var headerLen int
if !c.isConnect {
headerLen += AddrParser.AddrPortLen(destination)
}
headerLen += 2
if c.writer == nil {
headerLen += buffer.Len()
}
_header := buf.StackNewSize(headerLen)
defer common.KeepAlive(_header)
header := common.Dup(_header)
defer header.Release()
if !c.isConnect {
common.Must(AddrParser.WriteAddrPort(header, destination))
}
common.Must(binary.Write(header, binary.BigEndian, uint16(buffer.Len())))
if c.writer == nil {
common.Must1(header.Write(buffer.Bytes()))
return common.Error(c.Conn.Write(header.Bytes()))
}
return c.writer.WriteVectorised([]*buf.Buffer{header, buffer})
}
func (c *Conn) Upstream() any {
return c.Conn
}