hysteria/extras/obfs/conn.go
2023-08-30 20:02:18 -07:00

121 lines
2.8 KiB
Go

package obfs
import (
"net"
"sync"
"syscall"
"time"
)
const udpBufferSize = 2048 // QUIC packets are at most 1500 bytes long, so 2k should be more than enough
// Obfuscator is the interface that wraps the Obfuscate and Deobfuscate methods.
// Both methods return the number of bytes written to out.
// If a packet is not valid, the methods should return 0.
type Obfuscator interface {
Obfuscate(in, out []byte) int
Deobfuscate(in, out []byte) int
}
var _ net.PacketConn = (*obfsPacketConn)(nil)
type obfsPacketConn struct {
Conn net.PacketConn
Obfs Obfuscator
readBuf []byte
readMutex sync.Mutex
writeBuf []byte
writeMutex sync.Mutex
}
// obfsPacketConnUDP is a special case of obfsPacketConn that uses a UDPConn
// as the underlying connection. We pass additional methods to quic-go to
// enable UDP-specific optimizations.
type obfsPacketConnUDP struct {
*obfsPacketConn
UDPConn *net.UDPConn
}
// WrapPacketConn enables obfuscation on a net.PacketConn.
// The obfuscation is transparent to the caller - the n bytes returned by
// ReadFrom and WriteTo are the number of original bytes, not after
// obfuscation/deobfuscation.
func WrapPacketConn(conn net.PacketConn, obfs Obfuscator) net.PacketConn {
opc := &obfsPacketConn{
Conn: conn,
Obfs: obfs,
readBuf: make([]byte, udpBufferSize),
writeBuf: make([]byte, udpBufferSize),
}
if udpConn, ok := conn.(*net.UDPConn); ok {
return &obfsPacketConnUDP{
obfsPacketConn: opc,
UDPConn: udpConn,
}
} else {
return opc
}
}
func (c *obfsPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
for {
c.readMutex.Lock()
n, addr, err = c.Conn.ReadFrom(c.readBuf)
if n <= 0 {
c.readMutex.Unlock()
return
}
n = c.Obfs.Deobfuscate(c.readBuf[:n], p)
c.readMutex.Unlock()
if n > 0 || err != nil {
return
}
// Invalid packet, try again
}
}
func (c *obfsPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
c.writeMutex.Lock()
nn := c.Obfs.Obfuscate(p, c.writeBuf)
_, err = c.Conn.WriteTo(c.writeBuf[:nn], addr)
c.writeMutex.Unlock()
if err == nil {
n = len(p)
}
return
}
func (c *obfsPacketConn) Close() error {
return c.Conn.Close()
}
func (c *obfsPacketConn) LocalAddr() net.Addr {
return c.Conn.LocalAddr()
}
func (c *obfsPacketConn) SetDeadline(t time.Time) error {
return c.Conn.SetDeadline(t)
}
func (c *obfsPacketConn) SetReadDeadline(t time.Time) error {
return c.Conn.SetReadDeadline(t)
}
func (c *obfsPacketConn) SetWriteDeadline(t time.Time) error {
return c.Conn.SetWriteDeadline(t)
}
// UDP-specific methods below
func (c *obfsPacketConnUDP) SetReadBuffer(bytes int) error {
return c.UDPConn.SetReadBuffer(bytes)
}
func (c *obfsPacketConnUDP) SetWriteBuffer(bytes int) error {
return c.UDPConn.SetWriteBuffer(bytes)
}
func (c *obfsPacketConnUDP) SyscallConn() (syscall.RawConn, error) {
return c.UDPConn.SyscallConn()
}