mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-03 20:47:38 +03:00
140 lines
3 KiB
Go
140 lines
3 KiB
Go
package tproxy
|
|
|
|
import (
|
|
"errors"
|
|
"net"
|
|
"time"
|
|
|
|
"github.com/apernet/go-tproxy"
|
|
"github.com/apernet/hysteria/core/v2/client"
|
|
)
|
|
|
|
const (
	// udpBufferSize is the size of the buffers used to read datagrams from
	// the local TProxy sockets.
	udpBufferSize = 4096

	// defaultTimeout is the per-pair idle timeout used when
	// UDPTProxy.Timeout is not set (zero).
	defaultTimeout = time.Minute
)
|
|
|
|
type UDPTProxy struct {
|
|
HyClient client.Client
|
|
Timeout time.Duration
|
|
EventLogger UDPEventLogger
|
|
}
|
|
|
|
// UDPEventLogger receives lifecycle events for the UDP source/destination
// pairs handled by UDPTProxy.
type UDPEventLogger interface {
	// Connect is called when a new pair is first seen. UDPTProxy passes the
	// client (source) address as addr and the original destination as
	// reqAddr.
	Connect(addr, reqAddr net.Addr)

	// Error is called once per pair when it fails to set up or when it ends.
	// UDPTProxy passes a nil err when the pair ended by idle timeout.
	Error(addr, reqAddr net.Addr, err error)
}
|
|
|
|
func (r *UDPTProxy) ListenAndServe(laddr *net.UDPAddr) error {
|
|
conn, err := tproxy.ListenUDP("udp", laddr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer conn.Close()
|
|
buf := make([]byte, udpBufferSize)
|
|
for {
|
|
// We will only get the first packet of each src/dst pair here,
|
|
// because newPair will create a TProxy connection and take over
|
|
// the src/dst pair. Later packets will be sent there instead of here.
|
|
n, srcAddr, dstAddr, err := tproxy.ReadFromUDP(conn, buf)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
r.newPair(srcAddr, dstAddr, buf[:n])
|
|
}
|
|
}
|
|
|
|
func (r *UDPTProxy) newPair(srcAddr, dstAddr *net.UDPAddr, initPkt []byte) {
|
|
if r.EventLogger != nil {
|
|
r.EventLogger.Connect(srcAddr, dstAddr)
|
|
}
|
|
var closeErr error
|
|
defer func() {
|
|
// If closeErr is nil, it means we at least successfully sent the first packet
|
|
// and started forwarding, in which case we don't call the error logger.
|
|
if r.EventLogger != nil && closeErr != nil {
|
|
r.EventLogger.Error(srcAddr, dstAddr, closeErr)
|
|
}
|
|
}()
|
|
conn, err := tproxy.DialUDP("udp", dstAddr, srcAddr)
|
|
if err != nil {
|
|
closeErr = err
|
|
return
|
|
}
|
|
hyConn, err := r.HyClient.UDP()
|
|
if err != nil {
|
|
_ = conn.Close()
|
|
closeErr = err
|
|
return
|
|
}
|
|
// Send the first packet
|
|
err = hyConn.Send(initPkt, dstAddr.String())
|
|
if err != nil {
|
|
_ = conn.Close()
|
|
_ = hyConn.Close()
|
|
closeErr = err
|
|
return
|
|
}
|
|
// Start forwarding
|
|
go func() {
|
|
err := r.forwarding(conn, hyConn, dstAddr.String())
|
|
_ = conn.Close()
|
|
_ = hyConn.Close()
|
|
if r.EventLogger != nil {
|
|
var netErr net.Error
|
|
if errors.As(err, &netErr) && netErr.Timeout() {
|
|
// We don't consider deadline exceeded (timeout) an error
|
|
err = nil
|
|
}
|
|
r.EventLogger.Error(srcAddr, dstAddr, err)
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (r *UDPTProxy) forwarding(conn *net.UDPConn, hyConn client.HyUDPConn, dst string) error {
|
|
errChan := make(chan error, 2)
|
|
// Local <- Remote
|
|
go func() {
|
|
for {
|
|
bs, _, err := hyConn.Receive()
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
_, err = conn.Write(bs)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
_ = r.updateConnDeadline(conn)
|
|
}
|
|
}()
|
|
// Local -> Remote
|
|
go func() {
|
|
buf := make([]byte, udpBufferSize)
|
|
for {
|
|
_ = r.updateConnDeadline(conn)
|
|
n, err := conn.Read(buf)
|
|
if n > 0 {
|
|
err := hyConn.Send(buf[:n], dst)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
}
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
return <-errChan
|
|
}
|
|
|
|
func (r *UDPTProxy) updateConnDeadline(conn *net.UDPConn) error {
|
|
if r.Timeout == 0 {
|
|
return conn.SetReadDeadline(time.Now().Add(defaultTimeout))
|
|
} else {
|
|
return conn.SetReadDeadline(time.Now().Add(r.Timeout))
|
|
}
|
|
}
|