Improve read wait interface &

Refactor Authenticator interface to struct &
Update smux &
Update gVisor to 20231204.0 &
Update quic-go to v0.40.1 &
Update wireguard-go &
Add GSO support for TUN/WireGuard &
Fix router pre-start &
Fix bind forwarder to interface for systems stack
This commit is contained in:
世界 2023-12-20 20:00:00 +08:00
parent 35fd9de3ff
commit 89c723e3e4
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
48 changed files with 902 additions and 658 deletions

View file

@ -17,16 +17,16 @@ func (c *NATPacketConn) CreatePacketReadWaiter() (N.PacketReadWaiter, bool) {
type waitNATPacketConn struct {
*NATPacketConn
waiter N.PacketReadWaiter
readWaiter N.PacketReadWaiter
}
func (c *waitNATPacketConn) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
c.waiter.InitializeReadWaiter(newBuffer)
func (c *waitNATPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
return c.readWaiter.InitializeReadWaiter(options)
}
func (c *waitNATPacketConn) WaitReadPacket() (destination M.Socksaddr, err error) {
destination, err = c.waiter.WaitReadPacket()
if socksaddrWithoutPort(destination) == c.origin {
func (c *waitNATPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
buffer, destination, err = c.readWaiter.WaitReadPacket()
if err == nil && socksaddrWithoutPort(destination) == c.origin {
destination = M.Socksaddr{
Addr: c.destination.Addr,
Fqdn: c.destination.Fqdn,

View file

@ -53,7 +53,7 @@ func newMuxConnection0(ctx context.Context, stream net.Conn, metadata M.Metadata
case CommandTCP:
return handler.NewConnection(ctx, stream, metadata)
case CommandUDP:
return handler.NewPacketConnection(ctx, &PacketConn{stream}, metadata)
return handler.NewPacketConnection(ctx, &PacketConn{Conn: stream}, metadata)
default:
return E.New("unknown command ", command)
}

View file

@ -85,9 +85,10 @@ func (c *ClientConn) Upstream() any {
type ClientPacketConn struct {
net.Conn
access sync.Mutex
key [KeyLength]byte
headerWritten bool
access sync.Mutex
key [KeyLength]byte
headerWritten bool
readWaitOptions N.ReadWaitOptions
}
func NewClientPacketConn(conn net.Conn, key [KeyLength]byte) *ClientPacketConn {

View file

@ -0,0 +1,45 @@
package trojan
import (
"encoding/binary"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
)
var _ N.PacketReadWaiter = (*ClientPacketConn)(nil)
func (c *ClientPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}
func (c *ClientPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.Conn)
if err != nil {
return nil, M.Socksaddr{}, E.Cause(err, "read destination")
}
var length uint16
err = binary.Read(c.Conn, binary.BigEndian, &length)
if err != nil {
return nil, M.Socksaddr{}, E.Cause(err, "read chunk length")
}
err = rw.SkipN(c.Conn, 2)
if err != nil {
return nil, M.Socksaddr{}, E.Cause(err, "skip crlf")
}
buffer = c.readWaitOptions.NewPacketBuffer()
_, err = buffer.ReadFullFrom(c.Conn, int(length))
if err != nil {
buffer.Release()
return
}
c.readWaitOptions.PostReturn(buffer)
return
}

View file

@ -105,7 +105,7 @@ func (s *Service[K]) NewConnection(ctx context.Context, conn net.Conn, metadata
case CommandTCP:
return s.handler.NewConnection(ctx, conn, metadata)
case CommandUDP:
return s.handler.NewPacketConnection(ctx, &PacketConn{conn}, metadata)
return s.handler.NewPacketConnection(ctx, &PacketConn{Conn: conn}, metadata)
// case CommandMux:
default:
return HandleMuxConnection(ctx, conn, metadata, s.handler)
@ -122,6 +122,7 @@ func (s *Service[K]) fallback(ctx context.Context, conn net.Conn, metadata M.Met
type PacketConn struct {
net.Conn
readWaitOptions N.ReadWaitOptions
}
func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {

View file

@ -0,0 +1,45 @@
package trojan
import (
"encoding/binary"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
)
var _ N.PacketReadWaiter = (*PacketConn)(nil)
func (c *PacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}
func (c *PacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.Conn)
if err != nil {
return nil, M.Socksaddr{}, E.Cause(err, "read destination")
}
var length uint16
err = binary.Read(c.Conn, binary.BigEndian, &length)
if err != nil {
return nil, M.Socksaddr{}, E.Cause(err, "read chunk length")
}
err = rw.SkipN(c.Conn, 2)
if err != nil {
return nil, M.Socksaddr{}, E.Cause(err, "skip crlf")
}
buffer = c.readWaitOptions.NewPacketBuffer()
_, err = buffer.ReadFullFrom(c.Conn, int(length))
if err != nil {
buffer.Release()
return
}
c.readWaitOptions.PostReturn(buffer)
return
}

View file

@ -12,7 +12,6 @@ import (
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service/pause"
"github.com/sagernet/wireguard-go/conn"
)
@ -22,33 +21,27 @@ type ClientBind struct {
ctx context.Context
errorHandler E.Handler
dialer N.Dialer
reservedForEndpoint map[M.Socksaddr][3]uint8
reservedForEndpoint map[netip.AddrPort][3]uint8
connAccess sync.Mutex
conn *wireConn
done chan struct{}
isConnect bool
connectAddr M.Socksaddr
connectAddr netip.AddrPort
reserved [3]uint8
pauseManager pause.Manager
}
func NewClientBind(ctx context.Context, errorHandler E.Handler, dialer N.Dialer, isConnect bool, connectAddr M.Socksaddr, reserved [3]uint8) *ClientBind {
func NewClientBind(ctx context.Context, errorHandler E.Handler, dialer N.Dialer, isConnect bool, connectAddr netip.AddrPort, reserved [3]uint8) *ClientBind {
return &ClientBind{
ctx: ctx,
errorHandler: errorHandler,
dialer: dialer,
reservedForEndpoint: make(map[M.Socksaddr][3]uint8),
reservedForEndpoint: make(map[netip.AddrPort][3]uint8),
isConnect: isConnect,
connectAddr: connectAddr,
reserved: reserved,
pauseManager: pause.ManagerFromContext(ctx),
}
}
func (c *ClientBind) SetReservedForEndpoint(destination M.Socksaddr, reserved [3]byte) {
c.reservedForEndpoint[destination] = reserved
}
func (c *ClientBind) connect() (*wireConn, error) {
serverConn := c.conn
if serverConn != nil {
@ -71,16 +64,13 @@ func (c *ClientBind) connect() (*wireConn, error) {
}
}
if c.isConnect {
udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, c.connectAddr)
udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, M.SocksaddrFromNetIP(c.connectAddr))
if err != nil {
return nil, err
}
c.conn = &wireConn{
PacketConn: &bufio.UnbindPacketConn{
ExtendedConn: bufio.NewExtendedConn(udpConn),
Addr: c.connectAddr,
},
done: make(chan struct{}),
PacketConn: bufio.NewUnbindPacketConn(udpConn),
done: make(chan struct{}),
}
} else {
udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
@ -116,7 +106,6 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
c.errorHandler.NewError(context.Background(), E.Cause(err, "connect to server"))
err = nil
time.Sleep(time.Second)
c.pauseManager.WaitActive()
return
}
n, addr, err := udpConn.ReadFrom(packets[0])
@ -133,11 +122,9 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
sizes[0] = n
if n > 3 {
b := packets[0]
b[1] = 0
b[2] = 0
b[3] = 0
common.ClearArray(b[1:4])
}
eps[0] = Endpoint(M.SocksaddrFromNet(addr))
eps[0] = Endpoint(M.AddrPortFromNet(addr))
count = 1
return
}
@ -170,18 +157,16 @@ func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
if err != nil {
return err
}
destination := M.Socksaddr(ep.(Endpoint))
destination := netip.AddrPort(ep.(Endpoint))
for _, b := range bufs {
if len(b) > 3 {
reserved, loaded := c.reservedForEndpoint[destination]
if !loaded {
reserved = c.reserved
}
b[1] = reserved[0]
b[2] = reserved[1]
b[3] = reserved[2]
copy(b[1:4], reserved[:])
}
_, err = udpConn.WriteTo(b, destination)
_, err = udpConn.WriteTo(b, M.SocksaddrFromNetIP(destination))
if err != nil {
udpConn.Close()
return err
@ -191,13 +176,21 @@ func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
}
func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
return Endpoint(M.ParseSocksaddr(s)), nil
ap, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
return Endpoint(ap), nil
}
func (c *ClientBind) BatchSize() int {
return 1
}
func (c *ClientBind) SetReservedForEndpoint(destination netip.AddrPort, reserved [3]byte) {
c.reservedForEndpoint[destination] = reserved
}
type wireConn struct {
net.PacketConn
access sync.Mutex

View file

@ -265,7 +265,7 @@ func (ep *wireEndpoint) LinkAddress() tcpip.LinkAddress {
}
func (ep *wireEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return stack.CapabilityNone
return stack.CapabilityRXChecksumOffload
}
func (ep *wireEndpoint) Attach(dispatcher stack.NetworkDispatcher) {

View file

@ -2,6 +2,7 @@ package wireguard
import (
"context"
"errors"
"net"
"net/netip"
"os"
@ -11,6 +12,7 @@ import (
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
wgTun "github.com/sagernet/wireguard-go/tun"
@ -19,16 +21,17 @@ import (
var _ Device = (*SystemDevice)(nil)
type SystemDevice struct {
dialer N.Dialer
device tun.Tun
name string
mtu int
events chan wgTun.Event
addr4 netip.Addr
addr6 netip.Addr
dialer N.Dialer
device tun.Tun
batchDevice tun.LinuxTUN
name string
mtu int
events chan wgTun.Event
addr4 netip.Addr
addr6 netip.Addr
}
func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes []netip.Prefix, mtu uint32) (*SystemDevice, error) {
func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes []netip.Prefix, mtu uint32, gso bool) (*SystemDevice, error) {
var inet4Addresses []netip.Prefix
var inet6Addresses []netip.Prefix
for _, prefixes := range localPrefixes {
@ -46,6 +49,7 @@ func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes
Inet4Address: inet4Addresses,
Inet6Address: inet6Addresses,
MTU: mtu,
GSO: gso,
})
if err != nil {
return nil, err
@ -58,16 +62,25 @@ func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes
if len(inet6Addresses) > 0 {
inet6Address = inet6Addresses[0].Addr()
}
var batchDevice tun.LinuxTUN
if gso {
batchTUN, isBatchTUN := tunInterface.(tun.LinuxTUN)
if !isBatchTUN {
return nil, E.New("GSO is not supported on current platform")
}
batchDevice = batchTUN
}
return &SystemDevice{
dialer: common.Must1(dialer.NewDefault(router, option.DialerOptions{
BindInterface: interfaceName,
})),
device: tunInterface,
name: interfaceName,
mtu: int(mtu),
events: make(chan wgTun.Event),
addr4: inet4Address,
addr6: inet6Address,
device: tunInterface,
batchDevice: batchDevice,
name: interfaceName,
mtu: int(mtu),
events: make(chan wgTun.Event),
addr4: inet4Address,
addr6: inet6Address,
}, nil
}
@ -97,21 +110,31 @@ func (w *SystemDevice) File() *os.File {
}
func (w *SystemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
sizes[0], err = w.device.Read(bufs[0][offset-tun.PacketOffset:])
if err == nil {
count = 1
if w.batchDevice != nil {
count, err = w.batchDevice.BatchRead(bufs, offset, sizes)
} else {
sizes[0], err = w.device.Read(bufs[0][offset:])
if err == nil {
count = 1
} else if errors.Is(err, tun.ErrTooManySegments) {
err = wgTun.ErrTooManySegments
}
}
return
}
func (w *SystemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
for _, b := range bufs {
_, err = w.device.Write(b[offset:])
if err != nil {
return
if w.batchDevice != nil {
return 0, w.batchDevice.BatchWrite(bufs, offset)
} else {
for _, b := range bufs {
_, err = w.device.Write(b[offset:])
if err != nil {
return
}
}
count++
}
// WireGuard will not read count
return
}
@ -136,5 +159,8 @@ func (w *SystemDevice) Close() error {
}
func (w *SystemDevice) BatchSize() int {
if w.batchDevice != nil {
return w.batchDevice.BatchSize()
}
return 1
}

View file

@ -3,13 +3,12 @@ package wireguard
import (
"net/netip"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/wireguard-go/conn"
)
var _ conn.Endpoint = (*Endpoint)(nil)
type Endpoint M.Socksaddr
type Endpoint netip.AddrPort
func (e Endpoint) ClearSrc() {
}
@ -19,16 +18,16 @@ func (e Endpoint) SrcToString() string {
}
func (e Endpoint) DstToString() string {
return (M.Socksaddr)(e).String()
return (netip.AddrPort)(e).String()
}
func (e Endpoint) DstToBytes() []byte {
b, _ := (M.Socksaddr)(e).AddrPort().MarshalBinary()
b, _ := (netip.AddrPort)(e).MarshalBinary()
return b
}
func (e Endpoint) DstIP() netip.Addr {
return (M.Socksaddr)(e).Addr
return (netip.AddrPort)(e).Addr()
}
func (e Endpoint) SrcIP() netip.Addr {

View file

@ -0,0 +1,148 @@
package wireguard
import (
"context"
"encoding/base64"
"encoding/hex"
"net/netip"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
dns "github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
)
type PeerConfig struct {
destination M.Socksaddr
domainStrategy dns.DomainStrategy
Endpoint netip.AddrPort
PublicKey string
PreSharedKey string
AllowedIPs []string
Reserved [3]uint8
}
func (c PeerConfig) GenerateIpcLines() string {
ipcLines := "\npublic_key=" + c.PublicKey
ipcLines += "\nendpoint=" + c.Endpoint.String()
if c.PreSharedKey != "" {
ipcLines += "\npreshared_key=" + c.PreSharedKey
}
for _, allowedIP := range c.AllowedIPs {
ipcLines += "\nallowed_ip=" + allowedIP
}
return ipcLines
}
func ParsePeers(options option.WireGuardOutboundOptions) ([]PeerConfig, error) {
var peers []PeerConfig
if len(options.Peers) > 0 {
for peerIndex, rawPeer := range options.Peers {
peer := PeerConfig{
AllowedIPs: rawPeer.AllowedIPs,
}
destination := rawPeer.ServerOptions.Build()
if destination.IsFqdn() {
peer.destination = destination
peer.domainStrategy = dns.DomainStrategy(options.DomainStrategy)
} else {
peer.Endpoint = destination.AddrPort()
}
{
bytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey)
if err != nil {
return nil, E.Cause(err, "decode public key for peer ", peerIndex)
}
peer.PublicKey = hex.EncodeToString(bytes)
}
if rawPeer.PreSharedKey != "" {
bytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey)
if err != nil {
return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
}
peer.PreSharedKey = hex.EncodeToString(bytes)
}
if len(rawPeer.AllowedIPs) == 0 {
return nil, E.New("missing allowed_ips for peer ", peerIndex)
}
if len(rawPeer.Reserved) > 0 {
if len(rawPeer.Reserved) != 3 {
return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(peer.Reserved))
}
copy(peer.Reserved[:], options.Reserved)
}
peers = append(peers, peer)
}
} else {
peer := PeerConfig{}
var (
addressHas4 bool
addressHas6 bool
)
for _, localAddress := range options.LocalAddress {
if localAddress.Addr().Is4() {
addressHas4 = true
} else {
addressHas6 = true
}
}
if addressHas4 {
peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(netip.IPv4Unspecified(), 0).String())
}
if addressHas6 {
peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(netip.IPv6Unspecified(), 0).String())
}
destination := options.ServerOptions.Build()
if destination.IsFqdn() {
peer.destination = destination
peer.domainStrategy = dns.DomainStrategy(options.DomainStrategy)
} else {
peer.Endpoint = destination.AddrPort()
}
{
bytes, err := base64.StdEncoding.DecodeString(options.PeerPublicKey)
if err != nil {
return nil, E.Cause(err, "decode peer public key")
}
peer.PublicKey = hex.EncodeToString(bytes)
}
if options.PreSharedKey != "" {
bytes, err := base64.StdEncoding.DecodeString(options.PreSharedKey)
if err != nil {
return nil, E.Cause(err, "decode pre shared key")
}
peer.PreSharedKey = hex.EncodeToString(bytes)
}
if len(options.Reserved) > 0 {
if len(options.Reserved) != 3 {
return nil, E.New("invalid reserved value, required 3 bytes, got ", len(peer.Reserved))
}
copy(peer.Reserved[:], options.Reserved)
}
peers = append(peers, peer)
}
return peers, nil
}
func ResolvePeers(ctx context.Context, router adapter.Router, peers []PeerConfig) error {
for peerIndex, peer := range peers {
if peer.Endpoint.IsValid() {
continue
}
destinationAddresses, err := router.Lookup(ctx, peer.destination.Fqdn, peer.domainStrategy)
if err != nil {
if len(peers) == 1 {
return E.Cause(err, "resolve endpoint domain")
} else {
return E.Cause(err, "resolve endpoint domain for peer ", peerIndex)
}
}
if len(destinationAddresses) == 0 {
return E.New("no addresses found for endpoint domain: ", peer.destination.Fqdn)
}
peers[peerIndex].Endpoint = netip.AddrPortFrom(destinationAddresses[0], peer.destination.Port)
}
return nil
}