package tproxy import ( "errors" "net" "time" "github.com/apernet/go-tproxy" "github.com/apernet/hysteria/core/v2/client" ) const ( udpBufferSize = 4096 defaultTimeout = 60 * time.Second ) type UDPTProxy struct { HyClient client.Client Timeout time.Duration EventLogger UDPEventLogger } type UDPEventLogger interface { Connect(addr, reqAddr net.Addr) Error(addr, reqAddr net.Addr, err error) } func (r *UDPTProxy) ListenAndServe(laddr *net.UDPAddr) error { conn, err := tproxy.ListenUDP("udp", laddr) if err != nil { return err } defer conn.Close() buf := make([]byte, udpBufferSize) for { // We will only get the first packet of each src/dst pair here, // because newPair will create a TProxy connection and take over // the src/dst pair. Later packets will be sent there instead of here. n, srcAddr, dstAddr, err := tproxy.ReadFromUDP(conn, buf) if err != nil { return err } r.newPair(srcAddr, dstAddr, buf[:n]) } } func (r *UDPTProxy) newPair(srcAddr, dstAddr *net.UDPAddr, initPkt []byte) { if r.EventLogger != nil { r.EventLogger.Connect(srcAddr, dstAddr) } var closeErr error defer func() { // If closeErr is nil, it means we at least successfully sent the first packet // and started forwarding, in which case we don't call the error logger. if r.EventLogger != nil && closeErr != nil { r.EventLogger.Error(srcAddr, dstAddr, closeErr) } }() conn, err := tproxy.DialUDP("udp", dstAddr, srcAddr) if err != nil { closeErr = err return } hyConn, err := r.HyClient.UDP() if err != nil { _ = conn.Close() closeErr = err return } // Send the first packet err = hyConn.Send(initPkt, dstAddr.String()) if err != nil { _ = conn.Close() _ = hyConn.Close() closeErr = err return } // Start forwarding go func() { err := r.forwarding(conn, hyConn, dstAddr.String()) _ = conn.Close() _ = hyConn.Close() if r.EventLogger != nil { var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { // We don't consider deadline exceeded (timeout) an error err = nil } r.EventLogger.Error(srcAddr, dstAddr, err) } }() } func (r *UDPTProxy) forwarding(conn *net.UDPConn, hyConn client.HyUDPConn, dst string) error { errChan := make(chan error, 2) // Local <- Remote go func() { for { bs, _, err := hyConn.Receive() if err != nil { errChan <- err return } _, err = conn.Write(bs) if err != nil { errChan <- err return } _ = r.updateConnDeadline(conn) } }() // Local -> Remote go func() { buf := make([]byte, udpBufferSize) for { _ = r.updateConnDeadline(conn) n, err := conn.Read(buf) if n > 0 { err := hyConn.Send(buf[:n], dst) if err != nil { errChan <- err return } } if err != nil { errChan <- err return } } }() return <-errChan } func (r *UDPTProxy) updateConnDeadline(conn *net.UDPConn) error { if r.Timeout == 0 { return conn.SetReadDeadline(time.Now().Add(defaultTimeout)) } else { return conn.SetReadDeadline(time.Now().Add(r.Timeout)) } }