// hysteria/app/internal/tproxy/udp_linux.go
//
// 140 lines
// 3 KiB
// Go

package tproxy
import (
"errors"
"net"
"time"
"github.com/apernet/go-tproxy"
"github.com/apernet/hysteria/core/v2/client"
)
const (
	// udpBufferSize is the size, in bytes, of each buffer that datagrams read
	// from local sockets are copied into (4 KiB).
	udpBufferSize = 4 << 10
	// defaultTimeout is the per-pair idle timeout used when UDPTProxy.Timeout
	// is left at zero.
	defaultTimeout = time.Minute
)
// UDPTProxy is a transparent (TProxy) UDP proxy that forwards the datagrams
// it intercepts through a Hysteria client, one UDP session per
// source/destination pair.
type UDPTProxy struct {
	// HyClient is used to open a Hysteria UDP session (via UDP()) for each
	// newly seen src/dst pair.
	HyClient client.Client
	// Timeout is the per-pair idle timeout: the read deadline of the pair's
	// local socket is pushed this far ahead before every local read and after
	// every datagram written back to the local side. Zero means
	// defaultTimeout.
	Timeout time.Duration
	// EventLogger, if non-nil, receives Connect and Error events for each
	// pair. A nil EventLogger disables event reporting.
	EventLogger UDPEventLogger
}
// UDPEventLogger receives lifecycle events for the src/dst pairs handled by
// UDPTProxy. In both methods, addr is the local source address of the pair
// and reqAddr is its original (intercepted) destination.
type UDPEventLogger interface {
	// Connect is called when the first datagram of a new pair is received,
	// before the pair's Hysteria UDP session is set up.
	Connect(addr, reqAddr net.Addr)
	// Error is called at most once per pair. It receives either the error
	// that aborted setup or the error that ended forwarding. err is nil when
	// forwarding ended because of an idle (read deadline) timeout.
	Error(addr, reqAddr net.Addr, err error)
}
// ListenAndServe opens a TProxy UDP listener on laddr and handles intercepted
// datagrams until reading from the listener fails. That read error is
// returned; the listener is closed on return.
//
// Only the first datagram of each source/destination pair arrives here:
// handling it (newPair) creates a dedicated TProxy socket for the pair, and
// the pair's later datagrams are delivered to that socket instead.
func (r *UDPTProxy) ListenAndServe(laddr *net.UDPAddr) error {
	listener, err := tproxy.ListenUDP("udp", laddr)
	if err != nil {
		return err
	}
	defer listener.Close()

	pkt := make([]byte, udpBufferSize)
	for {
		n, src, dst, readErr := tproxy.ReadFromUDP(listener, pkt)
		if readErr != nil {
			return readErr
		}
		// newPair hands the packet to the Hysteria session synchronously
		// before returning, so pkt is reused for the next read.
		// NOTE(review): assumes HyUDPConn.Send does not retain the slice.
		r.newPair(src, dst, pkt[:n])
	}
}
// newPair handles the first datagram of a new source/destination pair.
//
// It creates a TProxy UDP socket for the pair (tproxy.DialUDP with dstAddr and
// srcAddr), so that the pair's later datagrams are delivered to it. It then
// opens a Hysteria UDP session, relays initPkt to dstAddr through that
// session, and starts forwarding in a background goroutine.
//
// Any setup failure closes whatever was opened and is reported through
// EventLogger.Error. Once forwarding has started, the background goroutine
// closes both ends and reports how the session ended, with a nil error for
// an idle timeout.
func (r *UDPTProxy) newPair(srcAddr, dstAddr *net.UDPAddr, initPkt []byte) {
	if r.EventLogger != nil {
		r.EventLogger.Connect(srcAddr, dstAddr)
	}

	// fail reports a setup error. It is called only after the resources
	// opened so far have been closed.
	fail := func(setupErr error) {
		if r.EventLogger != nil {
			r.EventLogger.Error(srcAddr, dstAddr, setupErr)
		}
	}

	conn, err := tproxy.DialUDP("udp", dstAddr, srcAddr)
	if err != nil {
		fail(err)
		return
	}

	hyConn, err := r.HyClient.UDP()
	if err != nil {
		_ = conn.Close()
		fail(err)
		return
	}

	dst := dstAddr.String()

	// Relay the packet that created the pair before handing off to forwarding.
	if sendErr := hyConn.Send(initPkt, dst); sendErr != nil {
		_ = conn.Close()
		_ = hyConn.Close()
		fail(sendErr)
		return
	}

	go func() {
		fwdErr := r.forwarding(conn, hyConn, dst)
		_ = conn.Close()
		_ = hyConn.Close()
		if r.EventLogger == nil {
			return
		}
		// An idle timeout (read deadline exceeded) is the normal way a pair
		// ends, so it is reported as a nil error.
		var netErr net.Error
		if errors.As(fwdErr, &netErr) && netErr.Timeout() {
			fwdErr = nil
		}
		r.EventLogger.Error(srcAddr, dstAddr, fwdErr)
	}()
}
// forwarding relays datagrams in both directions between the local TProxy
// socket conn and the Hysteria UDP session hyConn. Datagrams read locally are
// sent to dst; datagrams received from the session are written to conn.
//
// It blocks until either direction fails and returns that first error. It
// does not close conn or hyConn itself; the caller is expected to close them,
// which ends the other direction.
func (r *UDPTProxy) forwarding(conn *net.UDPConn, hyConn client.HyUDPConn, dst string) error {
	// Capacity 2: each direction sends at most one error, so the direction
	// that fails second never blocks, even though nobody reads its error.
	errs := make(chan error, 2)

	// Remote -> local. Every datagram written back to the local side also
	// pushes the idle deadline forward.
	go func() {
		for {
			pkt, _, err := hyConn.Receive()
			if err == nil {
				_, err = conn.Write(pkt)
			}
			if err != nil {
				errs <- err
				return
			}
			_ = r.updateConnDeadline(conn)
		}
	}()

	// Local -> remote. The deadline is refreshed before each read, so Read
	// fails with a timeout once the pair has been idle for the configured
	// period.
	go func() {
		buf := make([]byte, udpBufferSize)
		for {
			_ = r.updateConnDeadline(conn)
			n, readErr := conn.Read(buf)
			if n > 0 {
				if sendErr := hyConn.Send(buf[:n], dst); sendErr != nil {
					errs <- sendErr
					return
				}
			}
			if readErr != nil {
				errs <- readErr
				return
			}
		}
	}()

	return <-errs
}
// updateConnDeadline moves the read deadline of conn to now plus the pair's
// idle timeout, and returns any error from SetReadDeadline.
//
// The idle timeout is r.Timeout when it is positive and defaultTimeout
// otherwise. A negative r.Timeout would put the deadline in the past, so
// every read would fail immediately and each pair would be torn down as soon
// as it was created. It is therefore treated like an unset (zero) value.
func (r *UDPTProxy) updateConnDeadline(conn *net.UDPConn) error {
	timeout := r.Timeout
	if timeout <= 0 {
		timeout = defaultTimeout
	}
	return conn.SetReadDeadline(time.Now().Add(timeout))
}