package cmd import ( "errors" "fmt" "time" "github.com/spf13/cobra" "github.com/spf13/viper" "go.uber.org/zap" "github.com/apernet/hysteria/core/v2/client" hyErrors "github.com/apernet/hysteria/core/v2/errors" "github.com/apernet/hysteria/extras/v2/outbounds" "github.com/apernet/hysteria/extras/v2/outbounds/speedtest" ) var ( skipDownload bool skipUpload bool dataSize uint32 useBytes bool speedtestAddr = fmt.Sprintf("%s:%d", outbounds.SpeedtestDest, 0) ) // speedtestCmd represents the speedtest command var speedtestCmd = &cobra.Command{ Use: "speedtest", Short: "Speed test mode", Long: "Perform a speed test through the proxy server. The server must have speed test support enabled.", Run: runSpeedtest, } func init() { initSpeedtestFlags() rootCmd.AddCommand(speedtestCmd) } func initSpeedtestFlags() { speedtestCmd.Flags().BoolVar(&skipDownload, "skip-download", false, "Skip download test") speedtestCmd.Flags().BoolVar(&skipUpload, "skip-upload", false, "Skip upload test") speedtestCmd.Flags().Uint32Var(&dataSize, "data-size", 1024*1024*100, "Data size for download and upload tests") speedtestCmd.Flags().BoolVar(&useBytes, "use-bytes", false, "Use bytes per second instead of bits per second") } func runSpeedtest(cmd *cobra.Command, args []string) { logger.Info("speed test mode") if err := viper.ReadInConfig(); err != nil { logger.Fatal("failed to read client config", zap.Error(err)) } var config clientConfig if err := viper.Unmarshal(&config); err != nil { logger.Fatal("failed to parse client config", zap.Error(err)) } hyConfig, err := config.Config() if err != nil { logger.Fatal("failed to load client config", zap.Error(err)) } c, info, err := client.NewClient(hyConfig) if err != nil { logger.Fatal("failed to initialize client", zap.Error(err)) } defer c.Close() logger.Info("connected to server", zap.Bool("udpEnabled", info.UDPEnabled), zap.Uint64("tx", info.Tx)) if !skipDownload { runDownloadTest(c) } if !skipUpload { runUploadTest(c) } } func runDownloadTest(c client.Client) { logger.Info("performing download test") downConn, err := c.TCP(speedtestAddr) if err != nil { if errors.As(err, &hyErrors.DialError{}) { logger.Fatal("failed to connect (server may not support speed test)", zap.Error(err)) } else { logger.Fatal("failed to connect", zap.Error(err)) } } defer downConn.Close() downClient := &speedtest.Client{Conn: downConn} currentTotal := uint32(0) err = downClient.Download(dataSize, func(d time.Duration, b uint32, done bool) { if !done { currentTotal += b logger.Info("downloading", zap.Uint32("bytes", b), zap.String("progress", fmt.Sprintf("%.2f%%", float64(currentTotal)/float64(dataSize)*100)), zap.String("speed", formatSpeed(b, d, useBytes))) } else { logger.Info("download complete", zap.Uint32("bytes", b), zap.String("speed", formatSpeed(b, d, useBytes))) } }) if err != nil { logger.Fatal("download test failed", zap.Error(err)) } logger.Info("download test complete") } func runUploadTest(c client.Client) { logger.Info("performing upload test") upConn, err := c.TCP(speedtestAddr) if err != nil { if errors.As(err, &hyErrors.DialError{}) { logger.Fatal("failed to connect (server may not support speed test)", zap.Error(err)) } else { logger.Fatal("failed to connect", zap.Error(err)) } } defer upConn.Close() upClient := &speedtest.Client{Conn: upConn} currentTotal := uint32(0) err = upClient.Upload(dataSize, func(d time.Duration, b uint32, done bool) { if !done { currentTotal += b logger.Info("uploading", zap.Uint32("bytes", b), zap.String("progress", fmt.Sprintf("%.2f%%", float64(currentTotal)/float64(dataSize)*100)), zap.String("speed", formatSpeed(b, d, useBytes))) } else { logger.Info("upload complete", zap.Uint32("bytes", b), zap.String("speed", formatSpeed(b, d, useBytes))) } }) if err != nil { logger.Fatal("upload test failed", zap.Error(err)) } logger.Info("upload test complete") } func formatSpeed(bytes uint32, duration time.Duration, useBytes bool) string { speed := float64(bytes) / duration.Seconds() var units []string if useBytes { units = []string{"B/s", "KB/s", "MB/s", "GB/s"} } else { units = []string{"bps", "Kbps", "Mbps", "Gbps"} speed *= 8 } unitIndex := 0 for speed > 1024 && unitIndex < len(units)-1 { speed /= 1024 unitIndex++ } return fmt.Sprintf("%.2f %s", speed, units[unitIndex]) }