From f61b272cbf3732ac7d8307ee787963ba78ca5945 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= <i@sekai.icu>
Date: Wed, 20 Mar 2024 10:46:54 +0800
Subject: [PATCH] Fix WireGuard client bind

---
 transport/wireguard/client_bind.go | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/transport/wireguard/client_bind.go b/transport/wireguard/client_bind.go
index d39d1bec..39adce25 100644
--- a/transport/wireguard/client_bind.go
+++ b/transport/wireguard/client_bind.go
@@ -36,6 +36,7 @@ func NewClientBind(ctx context.Context, errorHandler E.Handler, dialer N.Dialer,
 		errorHandler:        errorHandler,
 		dialer:              dialer,
 		reservedForEndpoint: make(map[netip.AddrPort][3]uint8),
+		done:                make(chan struct{}),
 		isConnect:           isConnect,
 		connectAddr:         connectAddr,
 		reserved:            reserved,
@@ -88,8 +89,7 @@ func (c *ClientBind) connect() (*wireConn, error) {
 func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
 	select {
 	case <-c.done:
-		err = net.ErrClosed
-		return
+		c.done = make(chan struct{})
 	default:
 	}
 	return []conn.ReceiveFunc{c.receive}, 0, nil
@@ -129,16 +129,8 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
 	return
 }
 
-func (c *ClientBind) Reset() {
-	common.Close(common.PtrOrNil(c.conn))
-}
-
 func (c *ClientBind) Close() error {
 	common.Close(common.PtrOrNil(c.conn))
-	if c.done == nil {
-		c.done = make(chan struct{})
-		return nil
-	}
 	select {
 	case <-c.done:
 	default:
@@ -165,7 +157,7 @@ func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
 			}
 			copy(b[1:4], reserved[:])
 		}
-		_, err = udpConn.WriteTo(b, M.SocksaddrFromNetIP(destination))
+		_, err = udpConn.WriteToUDPAddrPort(b, destination)
 		if err != nil {
 			udpConn.Close()
 			return err
@@ -192,10 +184,18 @@ func (c *ClientBind) SetReservedForEndpoint(destination netip.AddrPort, reserved
 
 type wireConn struct {
 	net.PacketConn
+	conn   net.Conn
 	access sync.Mutex
 	done   chan struct{}
 }
 
+func (w *wireConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
+	if w.conn != nil {
+		return w.conn.Write(b)
+	}
+	return w.PacketConn.WriteTo(b, M.SocksaddrFromNetIP(addr).UDPAddr())
+}
+
 func (w *wireConn) Close() error {
 	w.access.Lock()
 	defer w.access.Unlock()