From b6068cea6b2351454927d4d08e44aa1af630dd3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 20 Apr 2023 13:16:31 +0800 Subject: [PATCH] Update wireguard-go --- go.mod | 2 +- go.sum | 4 +- transport/wireguard/client_bind.go | 43 ++++++++----- transport/wireguard/device_stack.go | 94 +++++++++++++++++---------- transport/wireguard/device_system.go | 60 +++++++++++++----- transport/wireguard/server_bind.go | 95 ---------------------------- 6 files changed, 135 insertions(+), 163 deletions(-) delete mode 100644 transport/wireguard/server_bind.go diff --git a/go.mod b/go.mod index 8c7945ca..30602f9c 100644 --- a/go.mod +++ b/go.mod @@ -35,7 +35,7 @@ require ( github.com/sagernet/tfo-go v0.0.0-20230303015439-ffcfd8c41cf9 github.com/sagernet/utls v0.0.0-20230309024959-6732c2ab36f2 github.com/sagernet/websocket v0.0.0-20220913015213-615516348b4e - github.com/sagernet/wireguard-go v0.0.0-20221116151939-c99467f53f2c + github.com/sagernet/wireguard-go v0.0.0-20230420044414-a7bac1754e77 github.com/spf13/cobra v1.7.0 github.com/stretchr/testify v1.8.3 go.etcd.io/bbolt v1.3.7 diff --git a/go.sum b/go.sum index ad57d1b1..d09a89e3 100644 --- a/go.sum +++ b/go.sum @@ -131,8 +131,8 @@ github.com/sagernet/utls v0.0.0-20230309024959-6732c2ab36f2 h1:kDUqhc9Vsk5HJuhfI github.com/sagernet/utls v0.0.0-20230309024959-6732c2ab36f2/go.mod h1:JKQMZq/O2qnZjdrt+B57olmfgEmLtY9iiSIEYtWvoSM= github.com/sagernet/websocket v0.0.0-20220913015213-615516348b4e h1:7uw2njHFGE+VpWamge6o56j2RWk4omF6uLKKxMmcWvs= github.com/sagernet/websocket v0.0.0-20220913015213-615516348b4e/go.mod h1:45TUl8+gH4SIKr4ykREbxKWTxkDlSzFENzctB1dVRRY= -github.com/sagernet/wireguard-go v0.0.0-20221116151939-c99467f53f2c h1:vK2wyt9aWYHHvNLWniwijBu/n4pySypiKRhN32u/JGo= -github.com/sagernet/wireguard-go v0.0.0-20221116151939-c99467f53f2c/go.mod h1:euOmN6O5kk9dQmgSS8Df4psAl3TCjxOz0NW60EWkSaI= +github.com/sagernet/wireguard-go v0.0.0-20230420044414-a7bac1754e77 h1:g6QtRWQ2dKX7EQP++1JLNtw4C2TNxd4/ov8YUpOPOSo= +github.com/sagernet/wireguard-go v0.0.0-20230420044414-a7bac1754e77/go.mod h1:pJDdXzZIwJ+2vmnT0TKzmf8meeum+e2mTDSehw79eE0= github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= diff --git a/transport/wireguard/client_bind.go b/transport/wireguard/client_bind.go index 79c08d47..61f5ab7c 100644 --- a/transport/wireguard/client_bind.go +++ b/transport/wireguard/client_bind.go @@ -101,7 +101,7 @@ func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint1 return []conn.ReceiveFunc{c.receive}, 0, nil } -func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) { +func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint) (count int, err error) { udpConn, err := c.connect() if err != nil { select { @@ -113,22 +113,26 @@ func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) { err = nil return } - n, addr, err := udpConn.ReadFrom(b) + n, addr, err := udpConn.ReadFrom(packets[0]) if err != nil { udpConn.Close() select { case <-c.done: default: c.errorHandler.NewError(context.Background(), E.Cause(err, "read packet")) + err = nil } return } + sizes[0] = n if n > 3 { + b := packets[0] b[1] = 0 b[2] = 0 b[3] = 0 } - ep = Endpoint(M.SocksaddrFromNet(addr)) + eps[0] = Endpoint(M.SocksaddrFromNet(addr)) + count = 1 return } @@ -155,32 +159,39 @@ func (c *ClientBind) SetMark(mark uint32) error { return nil } -func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error { +func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error { udpConn, err := c.connect() if err != nil { return err } destination := M.Socksaddr(ep.(Endpoint)) - if len(b) > 3 { - reserved, loaded := c.reservedForEndpoint[destination] - if !loaded { - reserved = c.reserved + for _, b := range bufs { + if len(b) > 3 { + reserved, loaded := c.reservedForEndpoint[destination] + if !loaded { + reserved = c.reserved + } + b[1] = reserved[0] + b[2] = reserved[1] + b[3] = reserved[2] + } + _, err = udpConn.WriteTo(b, destination) + if err != nil { + udpConn.Close() + return err } - b[1] = reserved[0] - b[2] = reserved[1] - b[3] = reserved[2] } - _, err = udpConn.WriteTo(b, destination) - if err != nil { - udpConn.Close() - } - return err + return nil } func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) { return Endpoint(M.ParseSocksaddr(s)), nil } +func (c *ClientBind) BatchSize() int { + return 1 +} + type wireConn struct { net.PacketConn access sync.Mutex diff --git a/transport/wireguard/device_stack.go b/transport/wireguard/device_stack.go index b2981e36..37d2f677 100644 --- a/transport/wireguard/device_stack.go +++ b/transport/wireguard/device_stack.go @@ -8,10 +8,11 @@ import ( "net/netip" "os" + "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/wireguard-go/tun" + wgTun "github.com/sagernet/wireguard-go/tun" "gvisor.dev/gvisor/pkg/bufferv2" "gvisor.dev/gvisor/pkg/tcpip" @@ -30,14 +31,15 @@ var _ Device = (*StackDevice)(nil) const defaultNIC tcpip.NICID = 1 type StackDevice struct { - stack *stack.Stack - mtu uint32 - events chan tun.Event - outbound chan *stack.PacketBuffer - done chan struct{} - dispatcher stack.NetworkDispatcher - addr4 tcpip.Address - addr6 tcpip.Address + stack *stack.Stack + mtu uint32 + events chan wgTun.Event + outbound chan *stack.PacketBuffer + packetOutbound chan *buf.Buffer + done chan struct{} + dispatcher stack.NetworkDispatcher + addr4 tcpip.Address + addr6 tcpip.Address } func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) { @@ -47,11 +49,12 @@ func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, er HandleLocal: true, }) tunDevice := &StackDevice{ - stack: ipStack, - mtu: mtu, - events: make(chan tun.Event, 1), - outbound: make(chan *stack.PacketBuffer, 256), - done: make(chan struct{}), + stack: ipStack, + mtu: mtu, + events: make(chan wgTun.Event, 1), + outbound: make(chan *stack.PacketBuffer, 256), + packetOutbound: make(chan *buf.Buffer, 256), + done: make(chan struct{}), } err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice)) if err != nil { @@ -144,8 +147,16 @@ func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) return udpConn, nil } +func (w *StackDevice) Inet4Address() netip.Addr { + return M.AddrFromIP(net.IP(w.addr4)) +} + +func (w *StackDevice) Inet6Address() netip.Addr { + return M.AddrFromIP(net.IP(w.addr6)) +} + func (w *StackDevice) Start() error { - w.events <- tun.EventUp + w.events <- wgTun.EventUp return nil } @@ -153,41 +164,52 @@ func (w *StackDevice) File() *os.File { return nil } -func (w *StackDevice) Read(p []byte, offset int) (n int, err error) { +func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) { select { case packetBuffer, ok := <-w.outbound: if !ok { return 0, os.ErrClosed } defer packetBuffer.DecRef() + p := bufs[0] p = p[offset:] + n := 0 for _, slice := range packetBuffer.AsSlices() { n += copy(p[n:], slice) } + sizes[0] = n + count = 1 + return + case packet := <-w.packetOutbound: + defer packet.Release() + sizes[0] = copy(bufs[0][offset:], packet.Bytes()) + count = 1 return case <-w.done: return 0, os.ErrClosed } } -func (w *StackDevice) Write(p []byte, offset int) (n int, err error) { - p = p[offset:] - if len(p) == 0 { - return +func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) { + for _, b := range bufs { + b = b[offset:] + if len(b) == 0 { + continue + } + var networkProtocol tcpip.NetworkProtocolNumber + switch header.IPVersion(b) { + case header.IPv4Version: + networkProtocol = header.IPv4ProtocolNumber + case header.IPv6Version: + networkProtocol = header.IPv6ProtocolNumber + } + packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: bufferv2.MakeWithData(b), + }) + w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer) + packetBuffer.DecRef() + count++ } - var networkProtocol tcpip.NetworkProtocolNumber - switch header.IPVersion(p) { - case header.IPv4Version: - networkProtocol = header.IPv4ProtocolNumber - case header.IPv6Version: - networkProtocol = header.IPv6ProtocolNumber - } - packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Payload: bufferv2.MakeWithData(p), - }) - defer packetBuffer.DecRef() - w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer) - n = len(p) return } @@ -203,7 +225,7 @@ func (w *StackDevice) Name() (string, error) { return "sing-box", nil } -func (w *StackDevice) Events() chan tun.Event { +func (w *StackDevice) Events() <-chan wgTun.Event { return w.events } @@ -222,6 +244,10 @@ func (w *StackDevice) Close() error { return nil } +func (w *StackDevice) BatchSize() int { + return 1 +} + var _ stack.LinkEndpoint = (*wireEndpoint)(nil) type wireEndpoint StackDevice diff --git a/transport/wireguard/device_system.go b/transport/wireguard/device_system.go index d4316422..f2325a9c 100644 --- a/transport/wireguard/device_system.go +++ b/transport/wireguard/device_system.go @@ -23,16 +23,10 @@ type SystemDevice struct { name string mtu int events chan wgTun.Event + addr4 netip.Addr + addr6 netip.Addr } -/*func (w *SystemDevice) NewEndpoint() (stack.LinkEndpoint, error) { - gTun, isGTun := w.device.(tun.GVisorTun) - if !isGTun { - return nil, tun.ErrGVisorUnsupported - } - return gTun.NewEndpoint() -}*/ - func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes []netip.Prefix, mtu uint32) (*SystemDevice, error) { var inet4Addresses []netip.Prefix var inet6Addresses []netip.Prefix @@ -55,11 +49,24 @@ func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes if err != nil { return nil, err } + var inet4Address netip.Addr + var inet6Address netip.Addr + if len(inet4Addresses) > 0 { + inet4Address = inet4Addresses[0].Addr() + } + if len(inet6Addresses) > 0 { + inet6Address = inet6Addresses[0].Addr() + } return &SystemDevice{ - dialer.NewDefault(router, option.DialerOptions{ + dialer: dialer.NewDefault(router, option.DialerOptions{ BindInterface: interfaceName, }), - tunInterface, interfaceName, int(mtu), make(chan wgTun.Event), + device: tunInterface, + name: interfaceName, + mtu: int(mtu), + events: make(chan wgTun.Event), + addr4: inet4Address, + addr6: inet6Address, }, nil } @@ -71,6 +78,14 @@ func (w *SystemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr return w.dialer.ListenPacket(ctx, destination) } +func (w *SystemDevice) Inet4Address() netip.Addr { + return w.addr4 +} + +func (w *SystemDevice) Inet6Address() netip.Addr { + return w.addr6 +} + func (w *SystemDevice) Start() error { w.events <- wgTun.EventUp return nil @@ -80,12 +95,23 @@ func (w *SystemDevice) File() *os.File { return nil } -func (w *SystemDevice) Read(bytes []byte, index int) (int, error) { - return w.device.Read(bytes[index-tun.PacketOffset:]) +func (w *SystemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) { + sizes[0], err = w.device.Read(bufs[0][offset-tun.PacketOffset:]) + if err == nil { + count = 1 + } + return } -func (w *SystemDevice) Write(bytes []byte, index int) (int, error) { - return w.device.Write(bytes[index:]) +func (w *SystemDevice) Write(bufs [][]byte, offset int) (count int, err error) { + for _, b := range bufs { + _, err = w.device.Write(b[offset:]) + if err != nil { + return + } + count++ + } + return } func (w *SystemDevice) Flush() error { @@ -100,10 +126,14 @@ func (w *SystemDevice) Name() (string, error) { return w.name, nil } -func (w *SystemDevice) Events() chan wgTun.Event { +func (w *SystemDevice) Events() <-chan wgTun.Event { return w.events } func (w *SystemDevice) Close() error { return w.device.Close() } + +func (w *SystemDevice) BatchSize() int { + return 1 +} diff --git a/transport/wireguard/server_bind.go b/transport/wireguard/server_bind.go deleted file mode 100644 index 8e61897a..00000000 --- a/transport/wireguard/server_bind.go +++ /dev/null @@ -1,95 +0,0 @@ -package wireguard - -import ( - "io" - - "github.com/sagernet/sing/common/buf" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" - "github.com/sagernet/wireguard-go/conn" -) - -var _ conn.Bind = (*ServerBind)(nil) - -type ServerBind struct { - inbound chan serverPacket - done chan struct{} - writeBack N.PacketWriter -} - -func NewServerBind(writeBack N.PacketWriter) *ServerBind { - return &ServerBind{ - inbound: make(chan serverPacket, 256), - done: make(chan struct{}), - writeBack: writeBack, - } -} - -func (s *ServerBind) Abort() error { - select { - case <-s.done: - return io.ErrClosedPipe - default: - close(s.done) - } - return nil -} - -type serverPacket struct { - buffer *buf.Buffer - source M.Socksaddr -} - -func (s *ServerBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { - fns = []conn.ReceiveFunc{s.receive} - return -} - -func (s *ServerBind) receive(b []byte) (n int, ep conn.Endpoint, err error) { - select { - case packet := <-s.inbound: - defer packet.buffer.Release() - n = copy(b, packet.buffer.Bytes()) - ep = Endpoint(packet.source) - return - case <-s.done: - err = io.ErrClosedPipe - return - } -} - -func (s *ServerBind) WriteIsThreadUnsafe() { -} - -func (s *ServerBind) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - select { - case s.inbound <- serverPacket{ - buffer: buffer, - source: destination, - }: - return nil - case <-s.done: - return io.ErrClosedPipe - } -} - -func (s *ServerBind) Close() error { - return nil -} - -func (s *ServerBind) SetMark(mark uint32) error { - return nil -} - -func (s *ServerBind) Send(b []byte, ep conn.Endpoint) error { - return s.writeBack.WritePacket(buf.As(b), M.Socksaddr(ep.(Endpoint))) -} - -func (s *ServerBind) ParseEndpoint(addr string) (conn.Endpoint, error) { - destination := M.ParseSocksaddr(addr) - if !destination.IsValid() || destination.Port == 0 { - return nil, E.New("invalid endpoint: ", addr) - } - return Endpoint(destination), nil -}