diff --git a/go.mod b/go.mod index 4acdd79..9f309fc 100644 --- a/go.mod +++ b/go.mod @@ -4,9 +4,9 @@ go 1.20 require ( github.com/gofrs/uuid/v5 v5.3.0 - github.com/sagernet/quic-go v0.48.1-beta.1 - github.com/sagernet/sing v0.6.0-alpha.17 - golang.org/x/crypto v0.29.0 + github.com/sagernet/quic-go v0.48.2-beta.1 + github.com/sagernet/sing v0.6.0-beta.9 + golang.org/x/crypto v0.31.0 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 ) @@ -17,7 +17,7 @@ require ( github.com/quic-go/qpack v0.4.0 // indirect github.com/quic-go/qtls-go1-20 v0.4.1 // indirect golang.org/x/net v0.30.0 // indirect - golang.org/x/sys v0.27.0 // indirect - golang.org/x/text v0.20.0 // indirect + golang.org/x/sys v0.28.0 // indirect + golang.org/x/text v0.21.0 // indirect golang.org/x/tools v0.24.0 // indirect ) diff --git a/go.sum b/go.sum index cf11c9b..72200f5 100644 --- a/go.sum +++ b/go.sum @@ -19,23 +19,23 @@ github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= github.com/quic-go/qtls-go1-20 v0.4.1 h1:D33340mCNDAIKBqXuAvexTNMUByrYmFYVfKfDN5nfFs= github.com/quic-go/qtls-go1-20 v0.4.1/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= -github.com/sagernet/quic-go v0.48.1-beta.1 h1:ElPaV5yzlXIKZpqFMAcUGax6vddi3zt4AEpT94Z0vwo= -github.com/sagernet/quic-go v0.48.1-beta.1/go.mod h1:1WgdDIVD1Gybp40JTWketeSfKA/+or9YMLaG5VeTk4k= -github.com/sagernet/sing v0.6.0-alpha.17 h1:y//jVrBjJMW6tRpA/ElT7+Snp3DHEJvO60D+DByg/Es= -github.com/sagernet/sing v0.6.0-alpha.17/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= +github.com/sagernet/quic-go v0.48.2-beta.1 h1:W0plrLWa1XtOWDTdX3CJwxmQuxkya12nN5BRGZ87kEg= +github.com/sagernet/quic-go v0.48.2-beta.1/go.mod h1:1WgdDIVD1Gybp40JTWketeSfKA/+or9YMLaG5VeTk4k= +github.com/sagernet/sing v0.6.0-beta.9 h1:P8lKa5hN53fRNAVCIKy5cWd6/kLO5c4slhdsfehSmHs= +github.com/sagernet/sing v0.6.0-beta.9/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ= -golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= -golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= -golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= -golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= diff --git a/hysteria2/client.go b/hysteria2/client.go index 981f123..c5b54f6 100644 --- a/hysteria2/client.go +++ b/hysteria2/client.go @@ -3,11 +3,14 @@ package hysteria2 import ( "context" "io" + "math" "net" "net/http" "net/url" "os" "runtime" + "strconv" + "strings" "sync" "time" @@ -36,6 +39,8 @@ type ClientOptions struct { Logger logger.Logger BrutalDebug bool ServerAddress M.Socksaddr + ServerPorts []string + HopInterval time.Duration SendBPS uint64 ReceiveBPS uint64 SalamanderPassword string @@ -50,6 +55,8 @@ type Client struct { logger logger.Logger brutalDebug bool serverAddr M.Socksaddr + serverPorts []uint16 + hopInterval time.Duration sendBPS uint64 receiveBPS uint64 salamanderPassword string @@ -76,12 +83,22 @@ func NewClient(options ClientOptions) (*Client, error) { if len(options.TLSConfig.NextProtos()) == 0 { options.TLSConfig.SetNextProtos([]string{http3.NextProtoH3}) } + var serverPorts []uint16 + if len(options.ServerPorts) > 0 { + var err error + serverPorts, err = parsePorts(options.ServerPorts) + if err != nil { + return nil, err + } + } return &Client{ ctx: options.Context, dialer: options.Dialer, logger: options.Logger, brutalDebug: options.BrutalDebug, serverAddr: options.ServerAddress, + serverPorts: serverPorts, + hopInterval: options.HopInterval, sendBPS: options.SendBPS, receiveBPS: options.ReceiveBPS, salamanderPassword: options.SalamanderPassword, @@ -92,6 +109,38 @@ func NewClient(options ClientOptions) (*Client, error) { }, nil } +func parsePorts(serverPorts []string) ([]uint16, error) { + var portList []uint16 + for _, portRange := range serverPorts { + if !strings.Contains(portRange, ":") { + return nil, E.New("bad port range: ", portRange) + } + subIndex := strings.Index(portRange, ":") + var ( + start, end uint64 + err error + ) + if subIndex > 0 { + start, err = strconv.ParseUint(portRange[:subIndex], 10, 16) + if err != nil { + return nil, E.Cause(err, E.Cause(err, "bad port range: ", portRange)) + } + } + if subIndex == len(portRange)-1 { + end = math.MaxUint16 + } else { + end, err = strconv.ParseUint(portRange[subIndex+1:], 10, 16) + if err != nil { + return nil, E.Cause(err, E.Cause(err, "bad port range: ", portRange)) + } + } + for i := start; i <= end; i++ { + portList = append(portList, uint16(i)) + } + } + return portList, nil +} + func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) { conn := c.conn if conn != nil && conn.active() { @@ -111,19 +160,34 @@ func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) { } func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) { - udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.serverAddr) + dialFunc := func(serverAddr M.Socksaddr) (net.PacketConn, error) { + udpConn, err := c.dialer.DialContext(c.ctx, "udp", serverAddr) + if err != nil { + return nil, err + } + var packetConn net.PacketConn + packetConn = bufio.NewUnbindPacketConn(udpConn) + if c.salamanderPassword != "" { + packetConn = NewSalamanderConn(packetConn, []byte(c.salamanderPassword)) + } + return packetConn, nil + } + var ( + packetConn net.PacketConn + err error + ) + if len(c.serverPorts) == 0 { + packetConn, err = dialFunc(c.serverAddr) + } else { + packetConn, err = NewHopPacketConn(dialFunc, c.serverAddr, c.serverPorts, c.hopInterval) + } if err != nil { return nil, err } - var packetConn net.PacketConn - packetConn = bufio.NewUnbindPacketConn(udpConn) - if c.salamanderPassword != "" { - packetConn = NewSalamanderConn(packetConn, []byte(c.salamanderPassword)) - } var quicConn quic.EarlyConnection http3Transport, err := qtls.CreateTransport(packetConn, &quicConn, c.serverAddr, c.tlsConfig, c.quicConfig) if err != nil { - udpConn.Close() + packetConn.Close() return nil, err } request := &http.Request{ @@ -141,14 +205,14 @@ func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) { if quicConn != nil { quicConn.CloseWithError(0, "") } - udpConn.Close() + packetConn.Close() return nil, err } if response.StatusCode != protocol.StatusAuthOK { if quicConn != nil { quicConn.CloseWithError(0, "") } - udpConn.Close() + packetConn.Close() return nil, E.New("authentication failed, status code: ", response.StatusCode) } response.Body.Close() @@ -172,7 +236,7 @@ func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) { } conn := &clientQUICConnection{ quicConn: quicConn, - rawConn: udpConn, + rawConn: packetConn, connDone: make(chan struct{}), udpDisabled: !authResponse.UDPEnabled, udpConnMap: make(map[uint32]*udpPacketConn), diff --git a/hysteria2/hop.go b/hysteria2/hop.go new file mode 100644 index 0000000..8714b38 --- /dev/null +++ b/hysteria2/hop.go @@ -0,0 +1,269 @@ +package hysteria2 + +import ( + "errors" + "math/rand" + "net" + "os" + "sync" + "syscall" + "time" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" +) + +const ( + packetQueueSize = 1024 + udpBufferSize = 2048 + defaultHopInterval = 30 * time.Second +) + +type HopPacketConn struct { + dialFunc func(M.Socksaddr) (net.PacketConn, error) + destination M.Socksaddr + ports []uint16 + interval time.Duration + access sync.Mutex + prevConn net.PacketConn + currentConn net.PacketConn + portIndex int + readBufferSize int + writeBufferSize int + packetChan chan *buf.Buffer + errChan chan error + doneChan chan struct{} + done bool +} + +func NewHopPacketConn( + dialFunc func(M.Socksaddr) (net.PacketConn, error), + destination M.Socksaddr, + ports []uint16, + interval time.Duration, +) (*HopPacketConn, error) { + if interval == 0 { + interval = defaultHopInterval + } + hopConn := &HopPacketConn{ + dialFunc: dialFunc, + destination: destination, + ports: ports, + interval: interval, + packetChan: make(chan *buf.Buffer, packetQueueSize), + errChan: make(chan error), + doneChan: make(chan struct{}), + } + currentConn, err := dialFunc(hopConn.nextAddr()) + if err != nil { + return nil, err + } + hopConn.currentConn = currentConn + go hopConn.recvLoop(currentConn) + go hopConn.hopLoop() + return hopConn, nil +} + +func (c *HopPacketConn) nextAddr() M.Socksaddr { + c.portIndex = rand.Intn(len(c.ports)) + return M.Socksaddr{ + Addr: c.destination.Addr, + Fqdn: c.destination.Fqdn, + Port: c.ports[c.portIndex], + } +} + +func (c *HopPacketConn) recvLoop(conn net.PacketConn) { + for { + buffer := buf.NewSize(udpBufferSize) + n, _, err := conn.ReadFrom(buffer.FreeBytes()) + if err != nil { + buffer.Release() + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + // Only pass through timeout errors here, not permanent errors + // like connection closed. Connection close is normal as we close + // the old connection to exit this loop every time we hop. + c.errChan <- netErr + } + return + } + buffer.Truncate(n) + select { + case c.packetChan <- buffer: + default: + buffer.Release() + } + } +} + +func (c *HopPacketConn) hopLoop() { + ticker := time.NewTicker(c.interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + c.hop() + case <-c.doneChan: + return + } + } +} + +func (c *HopPacketConn) hop() { + c.access.Lock() + defer c.access.Unlock() + if c.done { + return + } + nextAddr := c.nextAddr() + newConn, err := c.dialFunc(nextAddr) + if err != nil { + return + } + if c.prevConn != nil { + c.prevConn.Close() + } + c.prevConn = c.currentConn + c.currentConn = newConn + if c.readBufferSize > 0 { + _ = trySetReadBuffer(newConn, c.readBufferSize) + } + if c.writeBufferSize > 0 { + _ = trySetWriteBuffer(newConn, c.writeBufferSize) + } + go c.recvLoop(newConn) +} + +func (c *HopPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + for { + select { + case packet := <-c.packetChan: + n = copy(b, packet.Bytes()) + packet.Release() + return n, (*hopFakeAddr)(nil), nil + case err = <-c.errChan: + return 0, nil, err + case <-c.doneChan: + return 0, nil, net.ErrClosed + } + } +} + +func (c *HopPacketConn) WriteTo(b []byte, _ net.Addr) (n int, err error) { + c.access.Lock() + defer c.access.Unlock() + if c.done { + return 0, net.ErrClosed + } + return c.currentConn.WriteTo(b, (*hopFakeAddr)(nil)) +} + +func (c *HopPacketConn) Close() error { + c.access.Lock() + defer c.access.Unlock() + if c.done { + return nil + } + if c.prevConn != nil { + _ = c.prevConn.Close() + } + err := c.currentConn.Close() + close(c.doneChan) + c.done = true + return err +} + +func (c *HopPacketConn) LocalAddr() net.Addr { + c.access.Lock() + defer c.access.Unlock() + return c.currentConn.LocalAddr() +} + +func (c *HopPacketConn) SetDeadline(t time.Time) error { + c.access.Lock() + defer c.access.Unlock() + if c.prevConn != nil { + _ = c.prevConn.SetDeadline(t) + } + return c.currentConn.SetDeadline(t) +} + +func (c *HopPacketConn) SetReadDeadline(t time.Time) error { + c.access.Lock() + defer c.access.Unlock() + if c.prevConn != nil { + _ = c.prevConn.SetReadDeadline(t) + } + return c.currentConn.SetReadDeadline(t) +} + +func (c *HopPacketConn) SetWriteDeadline(t time.Time) error { + c.access.Lock() + defer c.access.Unlock() + if c.prevConn != nil { + _ = c.prevConn.SetWriteDeadline(t) + } + return c.currentConn.SetWriteDeadline(t) +} + +func (c *HopPacketConn) SetReadBuffer(bytes int) error { + c.access.Lock() + defer c.access.Unlock() + c.readBufferSize = bytes + if c.prevConn != nil { + _ = trySetReadBuffer(c.prevConn, bytes) + } + return trySetReadBuffer(c.currentConn, bytes) +} + +func (c *HopPacketConn) SetWriteBuffer(bytes int) error { + c.access.Lock() + defer c.access.Unlock() + c.writeBufferSize = bytes + if c.prevConn != nil { + _ = trySetWriteBuffer(c.prevConn, bytes) + } + return trySetWriteBuffer(c.currentConn, bytes) +} + +func (c *HopPacketConn) SyscallConn() (syscall.RawConn, error) { + c.access.Lock() + defer c.access.Unlock() + rawConn, isRawConn := common.Cast[syscall.Conn](c.currentConn) + if !isRawConn { + return nil, os.ErrInvalid + } + return rawConn.SyscallConn() +} + +func trySetReadBuffer(pc any, bytes int) error { + udpConn, isUDPConn := common.Cast[interface { + SetReadBuffer(bytes int) error + }](pc) + if !isUDPConn { + return nil + } + return udpConn.SetReadBuffer(bytes) +} + +func trySetWriteBuffer(pc any, bytes int) error { + udpConn, isUDPConn := common.Cast[interface { + SetWriteBuffer(bytes int) error + }](pc) + if !isUDPConn { + return nil + } + return udpConn.SetWriteBuffer(bytes) +} + +type hopFakeAddr struct{} + +func (a *hopFakeAddr) Network() string { + return "udphop" +} + +func (a *hopFakeAddr) String() string { + return "" +} diff --git a/hysteria2/salamander.go b/hysteria2/salamander.go index c057d1f..b83cac6 100644 --- a/hysteria2/salamander.go +++ b/hysteria2/salamander.go @@ -69,6 +69,10 @@ func (s *SalamanderPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err erro return len(p), nil } +func (s *SalamanderPacketConn) Upstream() any { + return s.PacketConn +} + type VectorisedSalamanderPacketConn struct { SalamanderPacketConn writer N.VectorisedPacketWriter