diff --git a/app/cmd/server.go b/app/cmd/server.go index 3f4bd1e..0f289dc 100644 --- a/app/cmd/server.go +++ b/app/cmd/server.go @@ -52,6 +52,7 @@ type serverConfig struct { QUIC serverConfigQUIC `mapstructure:"quic"` Bandwidth serverConfigBandwidth `mapstructure:"bandwidth"` IgnoreClientBandwidth bool `mapstructure:"ignoreClientBandwidth"` + SpeedTest bool `mapstructure:"speedTest"` DisableUDP bool `mapstructure:"disableUDP"` UDPIdleTimeout time.Duration `mapstructure:"udpIdleTimeout"` Auth serverConfigAuth `mapstructure:"auth"` @@ -528,6 +529,11 @@ func (c *serverConfig) fillOutboundConfig(hyConfig *server.Config) error { return configError{Field: "resolver.type", Err: errors.New("unsupported resolver type")} } + // Speed test + if c.SpeedTest { + uOb = outbounds.NewSpeedtestHandler(uOb) + } + hyConfig.Outbound = &outbounds.PluggableOutboundAdapter{PluggableOutbound: uOb} return nil } diff --git a/app/cmd/server_test.go b/app/cmd/server_test.go index 1c4d2f6..935a998 100644 --- a/app/cmd/server_test.go +++ b/app/cmd/server_test.go @@ -56,6 +56,7 @@ func TestServerConfig(t *testing.T) { Down: "100 mbps", }, IgnoreClientBandwidth: true, + SpeedTest: true, DisableUDP: true, UDPIdleTimeout: 120 * time.Second, Auth: serverConfigAuth{ diff --git a/app/cmd/server_test.yaml b/app/cmd/server_test.yaml index 47d3b19..1ab5d5f 100644 --- a/app/cmd/server_test.yaml +++ b/app/cmd/server_test.yaml @@ -36,6 +36,8 @@ bandwidth: ignoreClientBandwidth: true +speedTest: true + disableUDP: true udpIdleTimeout: 120s diff --git a/app/cmd/speedtest.go b/app/cmd/speedtest.go new file mode 100644 index 0000000..67cf71f --- /dev/null +++ b/app/cmd/speedtest.go @@ -0,0 +1,150 @@ +package cmd + +import ( + "fmt" + "time" + + "github.com/spf13/cobra" + "github.com/spf13/viper" + "go.uber.org/zap" + + "github.com/apernet/hysteria/core/client" + "github.com/apernet/hysteria/extras/outbounds" + "github.com/apernet/hysteria/extras/outbounds/speedtest" +) + +var ( + skipDownload bool + skipUpload bool + dataSize uint32 + useBytes bool + + speedtestAddr = fmt.Sprintf("%s:%d", outbounds.SpeedtestDest, 0) +) + +// speedtestCmd represents the speedtest command +var speedtestCmd = &cobra.Command{ + Use: "speedtest", + Short: "Speed test mode", + Long: "Perform a speed test through the proxy server. The server must have speed test support enabled.", + Run: runSpeedtest, +} + +func init() { + initSpeedtestFlags() + rootCmd.AddCommand(speedtestCmd) +} + +func initSpeedtestFlags() { + speedtestCmd.Flags().BoolVar(&skipDownload, "skip-download", false, "Skip download test") + speedtestCmd.Flags().BoolVar(&skipUpload, "skip-upload", false, "Skip upload test") + speedtestCmd.Flags().Uint32Var(&dataSize, "data-size", 1024*1024*100, "Data size for download and upload tests") + speedtestCmd.Flags().BoolVar(&useBytes, "use-bytes", false, "Use bytes per second instead of bits per second") +} + +func runSpeedtest(cmd *cobra.Command, args []string) { + logger.Info("speed test mode") + + if err := viper.ReadInConfig(); err != nil { + logger.Fatal("failed to read client config", zap.Error(err)) + } + var config clientConfig + if err := viper.Unmarshal(&config); err != nil { + logger.Fatal("failed to parse client config", zap.Error(err)) + } + hyConfig, err := config.Config() + if err != nil { + logger.Fatal("failed to load client config", zap.Error(err)) + } + + c, info, err := client.NewClient(hyConfig) + if err != nil { + logger.Fatal("failed to initialize client", zap.Error(err)) + } + defer c.Close() + logger.Info("connected to server", + zap.Bool("udpEnabled", info.UDPEnabled), + zap.Uint64("tx", info.Tx)) + + if !skipDownload { + runDownloadTest(c) + } + if !skipUpload { + runUploadTest(c) + } +} + +func runDownloadTest(c client.Client) { + logger.Info("performing download test") + downConn, err := c.TCP(speedtestAddr) + if err != nil { + logger.Fatal("failed to connect", zap.Error(err)) + } + defer downConn.Close() + + downClient := &speedtest.Client{Conn: downConn} + currentTotal := uint32(0) + err = downClient.Download(dataSize, func(d time.Duration, b uint32, done bool) { + if !done { + currentTotal += b + logger.Info("downloading", + zap.Uint32("bytes", b), + zap.String("progress", fmt.Sprintf("%.2f%%", float64(currentTotal)/float64(dataSize)*100)), + zap.String("speed", formatSpeed(b, d, useBytes))) + } else { + logger.Info("download complete", + zap.Uint32("bytes", b), + zap.String("speed", formatSpeed(b, d, useBytes))) + } + }) + if err != nil { + logger.Fatal("download test failed", zap.Error(err)) + } + logger.Info("download test complete") +} + +func runUploadTest(c client.Client) { + logger.Info("performing upload test") + upConn, err := c.TCP(speedtestAddr) + if err != nil { + logger.Fatal("failed to connect", zap.Error(err)) + } + defer upConn.Close() + + upClient := &speedtest.Client{Conn: upConn} + currentTotal := uint32(0) + err = upClient.Upload(dataSize, func(d time.Duration, b uint32, done bool) { + if !done { + currentTotal += b + logger.Info("uploading", + zap.Uint32("bytes", b), + zap.String("progress", fmt.Sprintf("%.2f%%", float64(currentTotal)/float64(dataSize)*100)), + zap.String("speed", formatSpeed(b, d, useBytes))) + } else { + logger.Info("upload complete", + zap.Uint32("bytes", b), + zap.String("speed", formatSpeed(b, d, useBytes))) + } + }) + if err != nil { + logger.Fatal("upload test failed", zap.Error(err)) + } + logger.Info("upload test complete") +} + +func formatSpeed(bytes uint32, duration time.Duration, useBytes bool) string { + speed := float64(bytes) / duration.Seconds() + var units []string + if useBytes { + units = []string{"B/s", "KB/s", "MB/s", "GB/s"} + } else { + units = []string{"bps", "Kbps", "Mbps", "Gbps"} + speed *= 8 + } + unitIndex := 0 + for speed > 1024 && unitIndex < len(units)-1 { + speed /= 1024 + unitIndex++ + } + return fmt.Sprintf("%.2f %s", speed, units[unitIndex]) +} diff --git a/extras/outbounds/speedtest.go b/extras/outbounds/speedtest.go new file mode 100644 index 0000000..9e5ab43 --- /dev/null +++ b/extras/outbounds/speedtest.go @@ -0,0 +1,36 @@ +package outbounds + +import ( + "net" + + "github.com/apernet/hysteria/extras/outbounds/speedtest" +) + +const ( + SpeedtestDest = "_SpeedTest" +) + +// speedtestHandler is a PluggableOutbound that handles speed test requests. +// It's used to intercept speed test requests and return a pseudo connection that +// implements the speed test protocol. +type speedtestHandler struct { + Next PluggableOutbound +} + +func NewSpeedtestHandler(next PluggableOutbound) PluggableOutbound { + return &speedtestHandler{ + Next: next, + } +} + +func (s *speedtestHandler) TCP(reqAddr *AddrEx) (net.Conn, error) { + if reqAddr.Host == SpeedtestDest { + return speedtest.NewServerConn(), nil + } else { + return s.Next.TCP(reqAddr) + } +} + +func (s *speedtestHandler) UDP(reqAddr *AddrEx) (UDPConn, error) { + return s.Next.UDP(reqAddr) +} diff --git a/extras/outbounds/speedtest/client.go b/extras/outbounds/speedtest/client.go new file mode 100644 index 0000000..ea4c5a6 --- /dev/null +++ b/extras/outbounds/speedtest/client.go @@ -0,0 +1,125 @@ +package speedtest + +import ( + "fmt" + "io" + "net" + "sync/atomic" + "time" +) + +type Client struct { + Conn net.Conn +} + +// Download requests the server to send l bytes of data. +// The callback function cb is called every second with the time since the last call, +// and the number of bytes received in that time. +func (c *Client) Download(l uint32, cb func(time.Duration, uint32, bool)) error { + err := writeDownloadRequest(c.Conn, l) + if err != nil { + return err + } + ok, msg, err := readDownloadResponse(c.Conn) + if err != nil { + return err + } + if !ok { + return fmt.Errorf("server rejected download request: %s", msg) + } + var counter uint32 + stopChan := make(chan struct{}) + defer close(stopChan) + // Call the callback function every second, + // with the time since the last call and the number of bytes received in that time. + go func() { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + t := time.Now() + for { + select { + case <-stopChan: + return + case <-ticker.C: + cb(time.Since(t), atomic.SwapUint32(&counter, 0), false) + t = time.Now() + } + } + }() + buf := make([]byte, chunkSize) + startTime := time.Now() + remaining := l + for remaining > 0 { + n := remaining + if n > chunkSize { + n = chunkSize + } + rn, err := c.Conn.Read(buf[:n]) + remaining -= uint32(rn) + atomic.AddUint32(&counter, uint32(rn)) + if err != nil && !(remaining == 0 && err == io.EOF) { + return err + } + } + // One last call to the callback function to report the total time and bytes received. + cb(time.Since(startTime), l, true) + return nil +} + +// Upload requests the server to receive l bytes of data. +// The callback function cb is called every second with the time since the last call, +// and the number of bytes sent in that time. +func (c *Client) Upload(l uint32, cb func(time.Duration, uint32, bool)) error { + err := writeUploadRequest(c.Conn, l) + if err != nil { + return err + } + ok, msg, err := readUploadResponse(c.Conn) + if err != nil { + return err + } + if !ok { + return fmt.Errorf("server rejected upload request: %s", msg) + } + var counter uint32 + stopChan := make(chan struct{}) + defer close(stopChan) + // Call the callback function every second, + // with the time since the last call and the number of bytes sent in that time. + go func() { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + t := time.Now() + for { + select { + case <-stopChan: + return + case <-ticker.C: + cb(time.Since(t), atomic.SwapUint32(&counter, 0), false) + t = time.Now() + } + } + }() + buf := make([]byte, chunkSize) + remaining := l + for remaining > 0 { + n := remaining + if n > chunkSize { + n = chunkSize + } + _, err := c.Conn.Write(buf[:n]) + if err != nil { + return err + } + remaining -= n + atomic.AddUint32(&counter, n) + } + // Now we should receive the upload summary from the server. + elapsed, received, err := readUploadSummary(c.Conn) + if err != nil { + return err + } + // One last call to the callback function to report the total time and bytes sent. + cb(elapsed, received, true) + return nil +} diff --git a/extras/outbounds/speedtest/protocol.go b/extras/outbounds/speedtest/protocol.go new file mode 100644 index 0000000..8d1adb8 --- /dev/null +++ b/extras/outbounds/speedtest/protocol.go @@ -0,0 +1,152 @@ +package speedtest + +import ( + "encoding/binary" + "io" + "time" +) + +const ( + typeDownload = 0x1 + typeUpload = 0x2 +) + +// DownloadRequest format: +// 0x1 (byte) +// Request data length (uint32 BE) + +func readDownloadRequest(r io.Reader) (uint32, error) { + var l uint32 + err := binary.Read(r, binary.BigEndian, &l) + return l, err +} + +func writeDownloadRequest(w io.Writer, l uint32) error { + buf := make([]byte, 5) + buf[0] = typeDownload + binary.BigEndian.PutUint32(buf[1:], l) + _, err := w.Write(buf) + return err +} + +// DownloadResponse format: +// Status (byte, 0=ok, 1=error) +// Message length (uint16 BE) +// Message (bytes) + +func readDownloadResponse(r io.Reader) (bool, string, error) { + var status [1]byte + if _, err := io.ReadFull(r, status[:]); err != nil { + return false, "", err + } + var msgLen uint16 + if err := binary.Read(r, binary.BigEndian, &msgLen); err != nil { + return false, "", err + } + // No message is fine + if msgLen == 0 { + return status[0] == 0, "", nil + } + msgBuf := make([]byte, msgLen) + _, err := io.ReadFull(r, msgBuf) + if err != nil { + return false, "", err + } + return status[0] == 0, string(msgBuf), nil +} + +func writeDownloadResponse(w io.Writer, ok bool, msg string) error { + sz := 1 + 2 + len(msg) + buf := make([]byte, sz) + if ok { + buf[0] = 0 + } else { + buf[0] = 1 + } + binary.BigEndian.PutUint16(buf[1:], uint16(len(msg))) + copy(buf[3:], msg) + _, err := w.Write(buf) + return err +} + +// UploadRequest format: +// 0x2 (byte) +// Upload data length (uint32 BE) + +func readUploadRequest(r io.Reader) (uint32, error) { + var l uint32 + err := binary.Read(r, binary.BigEndian, &l) + return l, err +} + +func writeUploadRequest(w io.Writer, l uint32) error { + buf := make([]byte, 5) + buf[0] = typeUpload + binary.BigEndian.PutUint32(buf[1:], l) + _, err := w.Write(buf) + return err +} + +// UploadResponse format: +// Status (byte, 0=ok, 1=error) +// Message length (uint16 BE) +// Message (bytes) + +func readUploadResponse(r io.Reader) (bool, string, error) { + var status [1]byte + if _, err := io.ReadFull(r, status[:]); err != nil { + return false, "", err + } + var msgLen uint16 + if err := binary.Read(r, binary.BigEndian, &msgLen); err != nil { + return false, "", err + } + // No message is fine + if msgLen == 0 { + return status[0] == 0, "", nil + } + msgBuf := make([]byte, msgLen) + _, err := io.ReadFull(r, msgBuf) + if err != nil { + return false, "", err + } + return status[0] == 0, string(msgBuf), nil +} + +func writeUploadResponse(w io.Writer, ok bool, msg string) error { + sz := 1 + 2 + len(msg) + buf := make([]byte, sz) + if ok { + buf[0] = 0 + } else { + buf[0] = 1 + } + binary.BigEndian.PutUint16(buf[1:], uint16(len(msg))) + copy(buf[3:], msg) + _, err := w.Write(buf) + return err +} + +// UploadSummary format: +// Duration (in milliseconds, uint32 BE) +// Received data length (uint32 BE) + +func readUploadSummary(r io.Reader) (time.Duration, uint32, error) { + var duration uint32 + if err := binary.Read(r, binary.BigEndian, &duration); err != nil { + return 0, 0, err + } + var l uint32 + if err := binary.Read(r, binary.BigEndian, &l); err != nil { + return 0, 0, err + } + return time.Duration(duration) * time.Millisecond, l, nil +} + +func writeUploadSummary(w io.Writer, duration time.Duration, l uint32) error { + buf := make([]byte, 8) + binary.BigEndian.PutUint32(buf, uint32(duration/time.Millisecond)) + binary.BigEndian.PutUint32(buf[4:], l) + _, err := w.Write(buf) + return err +} diff --git a/extras/outbounds/speedtest/protocol_test.go b/extras/outbounds/speedtest/protocol_test.go new file mode 100644 index 0000000..1ad23a4 --- /dev/null +++ b/extras/outbounds/speedtest/protocol_test.go @@ -0,0 +1,446 @@ +package speedtest + +import ( + "bytes" + "testing" + "time" +) + +func TestReadDownloadRequest(t *testing.T) { + tests := []struct { + name string + data []byte + want uint32 + wantErr bool + }{ + { + name: "normal", + data: []byte{0x0, 0x1, 0xBD, 0xC2}, + want: 114114, + wantErr: false, + }, + { + name: "normal zero", + data: []byte{0x0, 0x0, 0x0, 0x0}, + want: 0, + wantErr: false, + }, + { + name: "incomplete", + data: []byte{0x0, 0x1, 0x2}, + want: 0, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := bytes.NewReader(tt.data) + got, err := readDownloadRequest(r) + if (err != nil) != tt.wantErr { + t.Errorf("readDownloadRequest() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("readDownloadRequest() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestWriteDownloadRequest(t *testing.T) { + tests := []struct { + name string + l uint32 + wantW string + wantErr bool + }{ + { + name: "normal", + l: 78909912, + wantW: "\x01\x04\xB4\x11\xD8", + wantErr: false, + }, + { + name: "normal zero", + l: 0, + wantW: "\x01\x00\x00\x00\x00", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := &bytes.Buffer{} + err := writeDownloadRequest(w, tt.l) + if (err != nil) != tt.wantErr { + t.Errorf("writeDownloadRequest() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotW := w.String(); gotW != tt.wantW { + t.Errorf("writeDownloadRequest() gotW = %v, want %v", gotW, tt.wantW) + } + }) + } +} + +func TestReadDownloadResponse(t *testing.T) { + tests := []struct { + name string + data []byte + want bool + want1 string + wantErr bool + }{ + { + name: "normal ok", + data: []byte{0x0, 0x0, 0x2, 0x41, 0x42}, + want: true, + want1: "AB", + wantErr: false, + }, + { + name: "normal ok no message", + data: []byte{0x0, 0x0, 0x0, 0x0}, + want: true, + want1: "", + wantErr: false, + }, + { + name: "normal error", + data: []byte{0x1, 0x0, 0x2, 0x43, 0x44}, + want: false, + want1: "CD", + wantErr: false, + }, + { + name: "incomplete", + data: []byte{0x0, 0x99, 0x99, 0x45, 0x46, 0x47}, + want: false, + want1: "", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := bytes.NewReader(tt.data) + got, got1, err := readDownloadResponse(r) + if (err != nil) != tt.wantErr { + t.Errorf("readDownloadResponse() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("readDownloadResponse() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("readDownloadResponse() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func TestWriteDownloadResponse(t *testing.T) { + type args struct { + ok bool + msg string + } + tests := []struct { + name string + args args + wantW string + wantErr bool + }{ + { + name: "normal ok", + args: args{ok: true, msg: "wahaha"}, + wantW: "\x00\x00\x06wahaha", + wantErr: false, + }, + { + name: "normal error", + args: args{ok: false, msg: "bullbull"}, + wantW: "\x01\x00\x08bullbull", + wantErr: false, + }, + { + name: "empty ok", + args: args{ok: true, msg: ""}, + wantW: "\x00\x00\x00", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := &bytes.Buffer{} + err := writeDownloadResponse(w, tt.args.ok, tt.args.msg) + if (err != nil) != tt.wantErr { + t.Errorf("writeDownloadResponse() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotW := w.String(); gotW != tt.wantW { + t.Errorf("writeDownloadResponse() gotW = %v, want %v", gotW, tt.wantW) + } + }) + } +} + +func TestReadUploadRequest(t *testing.T) { + tests := []struct { + name string + data []byte + want uint32 + wantErr bool + }{ + { + name: "normal", + data: []byte{0x0, 0x0, 0x26, 0xEE}, + want: 9966, + wantErr: false, + }, + { + name: "normal zero", + data: []byte{0x0, 0x0, 0x0, 0x0}, + want: 0, + wantErr: false, + }, + { + name: "incomplete", + data: []byte{0x1}, + want: 0, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := bytes.NewReader(tt.data) + got, err := readUploadRequest(r) + if (err != nil) != tt.wantErr { + t.Errorf("readUploadRequest() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("readUploadRequest() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestWriteUploadRequest(t *testing.T) { + tests := []struct { + name string + l uint32 + wantW string + wantErr bool + }{ + { + name: "normal", + l: 2291758882, + wantW: "\x02\x88\x99\x77\x22", + wantErr: false, + }, + { + name: "normal zero", + l: 0, + wantW: "\x02\x00\x00\x00\x00", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := &bytes.Buffer{} + err := writeUploadRequest(w, tt.l) + if (err != nil) != tt.wantErr { + t.Errorf("writeUploadRequest() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotW := w.String(); gotW != tt.wantW { + t.Errorf("writeUploadRequest() gotW = %v, want %v", gotW, tt.wantW) + } + }) + } +} + +func TestReadUploadResponse(t *testing.T) { + tests := []struct { + name string + data []byte + want bool + want1 string + wantErr bool + }{ + { + name: "normal ok", + data: []byte{0x0, 0x0, 0x2, 0x41, 0x42}, + want: true, + want1: "AB", + wantErr: false, + }, + { + name: "normal ok no message", + data: []byte{0x0, 0x0, 0x0}, + want: true, + want1: "", + wantErr: false, + }, + { + name: "normal error", + data: []byte{0x1, 0x0, 0x2, 0x43, 0x44}, + want: false, + want1: "CD", + wantErr: false, + }, + { + name: "incomplete", + data: []byte{0x0, 0x99, 0x99, 0x45, 0x46, 0x47}, + want: false, + want1: "", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := bytes.NewReader(tt.data) + got, got1, err := readUploadResponse(r) + if (err != nil) != tt.wantErr { + t.Errorf("readUploadResponse() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("readUploadResponse() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("readUploadResponse() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func TestWriteUploadResponse(t *testing.T) { + type args struct { + ok bool + msg string + } + tests := []struct { + name string + args args + wantW string + wantErr bool + }{ + { + name: "normal ok", + args: args{ok: true, msg: "lul"}, + wantW: "\x00\x00\x03lul", + wantErr: false, + }, + { + name: "normal error", + args: args{ok: false, msg: "notforu"}, + wantW: "\x01\x00\x07notforu", + wantErr: false, + }, + { + name: "empty ok", + args: args{ok: true, msg: ""}, + wantW: "\x00\x00\x00", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := &bytes.Buffer{} + err := writeUploadResponse(w, tt.args.ok, tt.args.msg) + if (err != nil) != tt.wantErr { + t.Errorf("writeUploadResponse() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotW := w.String(); gotW != tt.wantW { + t.Errorf("writeUploadResponse() gotW = %v, want %v", gotW, tt.wantW) + } + }) + } +} + +func TestReadUploadSummary(t *testing.T) { + tests := []struct { + name string + data []byte + want time.Duration + want1 uint32 + wantErr bool + }{ + { + name: "normal", + data: []byte{0x0, 0x0, 0x14, 0x6E, 0x0, 0x26, 0x25, 0xA0}, + want: 5230 * time.Millisecond, + want1: 2500000, + wantErr: false, + }, + { + name: "zero", + data: []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, + want: 0, + want1: 0, + wantErr: false, + }, + { + name: "incomplete", + data: []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, + want: 0, + want1: 0, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := bytes.NewReader(tt.data) + got, got1, err := readUploadSummary(r) + if (err != nil) != tt.wantErr { + t.Errorf("readUploadSummary() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("readUploadSummary() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("readUploadSummary() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func TestWriteUploadSummary(t *testing.T) { + type args struct { + duration time.Duration + l uint32 + } + tests := []struct { + name string + args args + wantW string + wantErr bool + }{ + { + name: "normal", + args: args{duration: 5230 * time.Millisecond, l: 2500000}, + wantW: "\x00\x00\x14\x6E\x00\x26\x25\xA0", + wantErr: false, + }, + { + name: "zero", + args: args{duration: 0, l: 0}, + wantW: "\x00\x00\x00\x00\x00\x00\x00\x00", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := &bytes.Buffer{} + err := writeUploadSummary(w, tt.args.duration, tt.args.l) + if (err != nil) != tt.wantErr { + t.Errorf("writeUploadSummary() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotW := w.String(); gotW != tt.wantW { + t.Errorf("writeUploadSummary() gotW = %v, want %v", gotW, tt.wantW) + } + }) + } +} diff --git a/extras/outbounds/speedtest/server.go b/extras/outbounds/speedtest/server.go new file mode 100644 index 0000000..d280b57 --- /dev/null +++ b/extras/outbounds/speedtest/server.go @@ -0,0 +1,99 @@ +package speedtest + +import ( + "crypto/rand" + "fmt" + "io" + "net" + "time" +) + +const ( + chunkSize = 64 * 1024 +) + +// NewServerConn creates a new "pseudo" connection that implements the speed test protocol. +// It's called "pseudo" because it's not a real TCP connection - everything is done in memory. +func NewServerConn() net.Conn { + rConn, iConn := net.Pipe() // return conn & internal conn + // Start the server logic + go server(iConn) + return rConn +} + +func server(conn net.Conn) error { + defer conn.Close() + // First byte determines the request type + var typ [1]byte + if _, err := io.ReadFull(conn, typ[:]); err != nil { + return err + } + switch typ[0] { + case typeDownload: + return handleDownload(conn) + case typeUpload: + return handleUpload(conn) + default: + return fmt.Errorf("unknown request type: %d", typ[0]) + } +} + +// handleDownload reads the download request and sends the requested amount of data. +func handleDownload(conn net.Conn) error { + l, err := readDownloadRequest(conn) + if err != nil { + return err + } + err = writeDownloadResponse(conn, true, "OK") + if err != nil { + return err + } + buf := make([]byte, chunkSize) + // Fill the buffer with random data. + // For now, we only do it once and repeat the same data for performance reasons. + _, err = rand.Read(buf) + if err != nil { + return err + } + remaining := l + for remaining > 0 { + n := remaining + if n > chunkSize { + n = chunkSize + } + _, err := conn.Write(buf[:n]) + if err != nil { + return err + } + remaining -= n + } + return nil +} + +// handleUpload reads the upload request, reads & discards the requested amount of data, +// and sends the upload summary. +func handleUpload(conn net.Conn) error { + l, err := readUploadRequest(conn) + if err != nil { + return err + } + err = writeUploadResponse(conn, true, "OK") + if err != nil { + return err + } + buf := make([]byte, chunkSize) + startTime := time.Now() + remaining := l + for remaining > 0 { + n := remaining + if n > chunkSize { + n = chunkSize + } + rn, err := conn.Read(buf[:n]) + remaining -= uint32(rn) + if err != nil && !(remaining == 0 && err == io.EOF) { + return err + } + } + return writeUploadSummary(conn, time.Since(startTime), l) +}