feat: built-in speed test client & server

This commit is contained in:
Toby 2024-03-09 20:38:30 -08:00
parent 84d72ef0b3
commit a0bd58063b
9 changed files with 1017 additions and 0 deletions

View file

@ -52,6 +52,7 @@ type serverConfig struct {
QUIC serverConfigQUIC `mapstructure:"quic"` QUIC serverConfigQUIC `mapstructure:"quic"`
Bandwidth serverConfigBandwidth `mapstructure:"bandwidth"` Bandwidth serverConfigBandwidth `mapstructure:"bandwidth"`
IgnoreClientBandwidth bool `mapstructure:"ignoreClientBandwidth"` IgnoreClientBandwidth bool `mapstructure:"ignoreClientBandwidth"`
SpeedTest bool `mapstructure:"speedTest"`
DisableUDP bool `mapstructure:"disableUDP"` DisableUDP bool `mapstructure:"disableUDP"`
UDPIdleTimeout time.Duration `mapstructure:"udpIdleTimeout"` UDPIdleTimeout time.Duration `mapstructure:"udpIdleTimeout"`
Auth serverConfigAuth `mapstructure:"auth"` Auth serverConfigAuth `mapstructure:"auth"`
@ -528,6 +529,11 @@ func (c *serverConfig) fillOutboundConfig(hyConfig *server.Config) error {
return configError{Field: "resolver.type", Err: errors.New("unsupported resolver type")} return configError{Field: "resolver.type", Err: errors.New("unsupported resolver type")}
} }
// Speed test
if c.SpeedTest {
uOb = outbounds.NewSpeedtestHandler(uOb)
}
hyConfig.Outbound = &outbounds.PluggableOutboundAdapter{PluggableOutbound: uOb} hyConfig.Outbound = &outbounds.PluggableOutboundAdapter{PluggableOutbound: uOb}
return nil return nil
} }

View file

@ -56,6 +56,7 @@ func TestServerConfig(t *testing.T) {
Down: "100 mbps", Down: "100 mbps",
}, },
IgnoreClientBandwidth: true, IgnoreClientBandwidth: true,
SpeedTest: true,
DisableUDP: true, DisableUDP: true,
UDPIdleTimeout: 120 * time.Second, UDPIdleTimeout: 120 * time.Second,
Auth: serverConfigAuth{ Auth: serverConfigAuth{

View file

@ -36,6 +36,8 @@ bandwidth:
ignoreClientBandwidth: true ignoreClientBandwidth: true
speedTest: true
disableUDP: true disableUDP: true
udpIdleTimeout: 120s udpIdleTimeout: 120s

150
app/cmd/speedtest.go Normal file
View file

@ -0,0 +1,150 @@
package cmd
import (
"fmt"
"time"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"go.uber.org/zap"
"github.com/apernet/hysteria/core/client"
"github.com/apernet/hysteria/extras/outbounds"
"github.com/apernet/hysteria/extras/outbounds/speedtest"
)
var (
skipDownload bool
skipUpload bool
dataSize uint32
useBytes bool
speedtestAddr = fmt.Sprintf("%s:%d", outbounds.SpeedtestDest, 0)
)
// speedtestCmd represents the speedtest command
var speedtestCmd = &cobra.Command{
Use: "speedtest",
Short: "Speed test mode",
Long: "Perform a speed test through the proxy server. The server must have speed test support enabled.",
Run: runSpeedtest,
}
func init() {
initSpeedtestFlags()
rootCmd.AddCommand(speedtestCmd)
}
func initSpeedtestFlags() {
speedtestCmd.Flags().BoolVar(&skipDownload, "skip-download", false, "Skip download test")
speedtestCmd.Flags().BoolVar(&skipUpload, "skip-upload", false, "Skip upload test")
speedtestCmd.Flags().Uint32Var(&dataSize, "data-size", 1024*1024*100, "Data size for download and upload tests")
speedtestCmd.Flags().BoolVar(&useBytes, "use-bytes", false, "Use bytes per second instead of bits per second")
}
func runSpeedtest(cmd *cobra.Command, args []string) {
logger.Info("speed test mode")
if err := viper.ReadInConfig(); err != nil {
logger.Fatal("failed to read client config", zap.Error(err))
}
var config clientConfig
if err := viper.Unmarshal(&config); err != nil {
logger.Fatal("failed to parse client config", zap.Error(err))
}
hyConfig, err := config.Config()
if err != nil {
logger.Fatal("failed to load client config", zap.Error(err))
}
c, info, err := client.NewClient(hyConfig)
if err != nil {
logger.Fatal("failed to initialize client", zap.Error(err))
}
defer c.Close()
logger.Info("connected to server",
zap.Bool("udpEnabled", info.UDPEnabled),
zap.Uint64("tx", info.Tx))
if !skipDownload {
runDownloadTest(c)
}
if !skipUpload {
runUploadTest(c)
}
}
func runDownloadTest(c client.Client) {
logger.Info("performing download test")
downConn, err := c.TCP(speedtestAddr)
if err != nil {
logger.Fatal("failed to connect", zap.Error(err))
}
defer downConn.Close()
downClient := &speedtest.Client{Conn: downConn}
currentTotal := uint32(0)
err = downClient.Download(dataSize, func(d time.Duration, b uint32, done bool) {
if !done {
currentTotal += b
logger.Info("downloading",
zap.Uint32("bytes", b),
zap.String("progress", fmt.Sprintf("%.2f%%", float64(currentTotal)/float64(dataSize)*100)),
zap.String("speed", formatSpeed(b, d, useBytes)))
} else {
logger.Info("download complete",
zap.Uint32("bytes", b),
zap.String("speed", formatSpeed(b, d, useBytes)))
}
})
if err != nil {
logger.Fatal("download test failed", zap.Error(err))
}
logger.Info("download test complete")
}
func runUploadTest(c client.Client) {
logger.Info("performing upload test")
upConn, err := c.TCP(speedtestAddr)
if err != nil {
logger.Fatal("failed to connect", zap.Error(err))
}
defer upConn.Close()
upClient := &speedtest.Client{Conn: upConn}
currentTotal := uint32(0)
err = upClient.Upload(dataSize, func(d time.Duration, b uint32, done bool) {
if !done {
currentTotal += b
logger.Info("uploading",
zap.Uint32("bytes", b),
zap.String("progress", fmt.Sprintf("%.2f%%", float64(currentTotal)/float64(dataSize)*100)),
zap.String("speed", formatSpeed(b, d, useBytes)))
} else {
logger.Info("upload complete",
zap.Uint32("bytes", b),
zap.String("speed", formatSpeed(b, d, useBytes)))
}
})
if err != nil {
logger.Fatal("upload test failed", zap.Error(err))
}
logger.Info("upload test complete")
}
func formatSpeed(bytes uint32, duration time.Duration, useBytes bool) string {
speed := float64(bytes) / duration.Seconds()
var units []string
if useBytes {
units = []string{"B/s", "KB/s", "MB/s", "GB/s"}
} else {
units = []string{"bps", "Kbps", "Mbps", "Gbps"}
speed *= 8
}
unitIndex := 0
for speed > 1024 && unitIndex < len(units)-1 {
speed /= 1024
unitIndex++
}
return fmt.Sprintf("%.2f %s", speed, units[unitIndex])
}

View file

@ -0,0 +1,36 @@
package outbounds
import (
"net"
"github.com/apernet/hysteria/extras/outbounds/speedtest"
)
const (
SpeedtestDest = "_SpeedTest"
)
// speedtestHandler is a PluggableOutbound that handles speed test requests.
// It's used to intercept speed test requests and return a pseudo connection that
// implements the speed test protocol.
type speedtestHandler struct {
Next PluggableOutbound
}
func NewSpeedtestHandler(next PluggableOutbound) PluggableOutbound {
return &speedtestHandler{
Next: next,
}
}
func (s *speedtestHandler) TCP(reqAddr *AddrEx) (net.Conn, error) {
if reqAddr.Host == SpeedtestDest {
return speedtest.NewServerConn(), nil
} else {
return s.Next.TCP(reqAddr)
}
}
func (s *speedtestHandler) UDP(reqAddr *AddrEx) (UDPConn, error) {
return s.Next.UDP(reqAddr)
}

View file

@ -0,0 +1,125 @@
package speedtest
import (
"fmt"
"io"
"net"
"sync/atomic"
"time"
)
type Client struct {
Conn net.Conn
}
// Download requests the server to send l bytes of data.
// The callback function cb is called every second with the time since the last call,
// and the number of bytes received in that time.
func (c *Client) Download(l uint32, cb func(time.Duration, uint32, bool)) error {
err := writeDownloadRequest(c.Conn, l)
if err != nil {
return err
}
ok, msg, err := readDownloadResponse(c.Conn)
if err != nil {
return err
}
if !ok {
return fmt.Errorf("server rejected download request: %s", msg)
}
var counter uint32
stopChan := make(chan struct{})
defer close(stopChan)
// Call the callback function every second,
// with the time since the last call and the number of bytes received in that time.
go func() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
t := time.Now()
for {
select {
case <-stopChan:
return
case <-ticker.C:
cb(time.Since(t), atomic.SwapUint32(&counter, 0), false)
t = time.Now()
}
}
}()
buf := make([]byte, chunkSize)
startTime := time.Now()
remaining := l
for remaining > 0 {
n := remaining
if n > chunkSize {
n = chunkSize
}
rn, err := c.Conn.Read(buf[:n])
remaining -= uint32(rn)
atomic.AddUint32(&counter, uint32(rn))
if err != nil && !(remaining == 0 && err == io.EOF) {
return err
}
}
// One last call to the callback function to report the total time and bytes received.
cb(time.Since(startTime), l, true)
return nil
}
// Upload requests the server to receive l bytes of data.
// The callback function cb is called every second with the time since the last call,
// and the number of bytes sent in that time.
func (c *Client) Upload(l uint32, cb func(time.Duration, uint32, bool)) error {
err := writeUploadRequest(c.Conn, l)
if err != nil {
return err
}
ok, msg, err := readUploadResponse(c.Conn)
if err != nil {
return err
}
if !ok {
return fmt.Errorf("server rejected upload request: %s", msg)
}
var counter uint32
stopChan := make(chan struct{})
defer close(stopChan)
// Call the callback function every second,
// with the time since the last call and the number of bytes sent in that time.
go func() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
t := time.Now()
for {
select {
case <-stopChan:
return
case <-ticker.C:
cb(time.Since(t), atomic.SwapUint32(&counter, 0), false)
t = time.Now()
}
}
}()
buf := make([]byte, chunkSize)
remaining := l
for remaining > 0 {
n := remaining
if n > chunkSize {
n = chunkSize
}
_, err := c.Conn.Write(buf[:n])
if err != nil {
return err
}
remaining -= n
atomic.AddUint32(&counter, n)
}
// Now we should receive the upload summary from the server.
elapsed, received, err := readUploadSummary(c.Conn)
if err != nil {
return err
}
// One last call to the callback function to report the total time and bytes sent.
cb(elapsed, received, true)
return nil
}

View file

@ -0,0 +1,152 @@
package speedtest
import (
"encoding/binary"
"io"
"time"
)
const (
typeDownload = 0x1
typeUpload = 0x2
)
// DownloadRequest format:
// 0x1 (byte)
// Request data length (uint32 BE)
func readDownloadRequest(r io.Reader) (uint32, error) {
var l uint32
err := binary.Read(r, binary.BigEndian, &l)
return l, err
}
func writeDownloadRequest(w io.Writer, l uint32) error {
buf := make([]byte, 5)
buf[0] = typeDownload
binary.BigEndian.PutUint32(buf[1:], l)
_, err := w.Write(buf)
return err
}
// DownloadResponse format:
// Status (byte, 0=ok, 1=error)
// Message length (uint16 BE)
// Message (bytes)
func readDownloadResponse(r io.Reader) (bool, string, error) {
var status [1]byte
if _, err := io.ReadFull(r, status[:]); err != nil {
return false, "", err
}
var msgLen uint16
if err := binary.Read(r, binary.BigEndian, &msgLen); err != nil {
return false, "", err
}
// No message is fine
if msgLen == 0 {
return status[0] == 0, "", nil
}
msgBuf := make([]byte, msgLen)
_, err := io.ReadFull(r, msgBuf)
if err != nil {
return false, "", err
}
return status[0] == 0, string(msgBuf), nil
}
func writeDownloadResponse(w io.Writer, ok bool, msg string) error {
sz := 1 + 2 + len(msg)
buf := make([]byte, sz)
if ok {
buf[0] = 0
} else {
buf[0] = 1
}
binary.BigEndian.PutUint16(buf[1:], uint16(len(msg)))
copy(buf[3:], msg)
_, err := w.Write(buf)
return err
}
// UploadRequest format:
// 0x2 (byte)
// Upload data length (uint32 BE)
func readUploadRequest(r io.Reader) (uint32, error) {
var l uint32
err := binary.Read(r, binary.BigEndian, &l)
return l, err
}
func writeUploadRequest(w io.Writer, l uint32) error {
buf := make([]byte, 5)
buf[0] = typeUpload
binary.BigEndian.PutUint32(buf[1:], l)
_, err := w.Write(buf)
return err
}
// UploadResponse format:
// Status (byte, 0=ok, 1=error)
// Message length (uint16 BE)
// Message (bytes)
func readUploadResponse(r io.Reader) (bool, string, error) {
var status [1]byte
if _, err := io.ReadFull(r, status[:]); err != nil {
return false, "", err
}
var msgLen uint16
if err := binary.Read(r, binary.BigEndian, &msgLen); err != nil {
return false, "", err
}
// No message is fine
if msgLen == 0 {
return status[0] == 0, "", nil
}
msgBuf := make([]byte, msgLen)
_, err := io.ReadFull(r, msgBuf)
if err != nil {
return false, "", err
}
return status[0] == 0, string(msgBuf), nil
}
func writeUploadResponse(w io.Writer, ok bool, msg string) error {
sz := 1 + 2 + len(msg)
buf := make([]byte, sz)
if ok {
buf[0] = 0
} else {
buf[0] = 1
}
binary.BigEndian.PutUint16(buf[1:], uint16(len(msg)))
copy(buf[3:], msg)
_, err := w.Write(buf)
return err
}
// UploadSummary format:
// Duration (in milliseconds, uint32 BE)
// Received data length (uint32 BE)
func readUploadSummary(r io.Reader) (time.Duration, uint32, error) {
var duration uint32
if err := binary.Read(r, binary.BigEndian, &duration); err != nil {
return 0, 0, err
}
var l uint32
if err := binary.Read(r, binary.BigEndian, &l); err != nil {
return 0, 0, err
}
return time.Duration(duration) * time.Millisecond, l, nil
}
func writeUploadSummary(w io.Writer, duration time.Duration, l uint32) error {
buf := make([]byte, 8)
binary.BigEndian.PutUint32(buf, uint32(duration/time.Millisecond))
binary.BigEndian.PutUint32(buf[4:], l)
_, err := w.Write(buf)
return err
}

View file

@ -0,0 +1,446 @@
package speedtest
import (
"bytes"
"testing"
"time"
)
func TestReadDownloadRequest(t *testing.T) {
tests := []struct {
name string
data []byte
want uint32
wantErr bool
}{
{
name: "normal",
data: []byte{0x0, 0x1, 0xBD, 0xC2},
want: 114114,
wantErr: false,
},
{
name: "normal zero",
data: []byte{0x0, 0x0, 0x0, 0x0},
want: 0,
wantErr: false,
},
{
name: "incomplete",
data: []byte{0x0, 0x1, 0x2},
want: 0,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := bytes.NewReader(tt.data)
got, err := readDownloadRequest(r)
if (err != nil) != tt.wantErr {
t.Errorf("readDownloadRequest() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("readDownloadRequest() got = %v, want %v", got, tt.want)
}
})
}
}
func TestWriteDownloadRequest(t *testing.T) {
tests := []struct {
name string
l uint32
wantW string
wantErr bool
}{
{
name: "normal",
l: 78909912,
wantW: "\x01\x04\xB4\x11\xD8",
wantErr: false,
},
{
name: "normal zero",
l: 0,
wantW: "\x01\x00\x00\x00\x00",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := &bytes.Buffer{}
err := writeDownloadRequest(w, tt.l)
if (err != nil) != tt.wantErr {
t.Errorf("writeDownloadRequest() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotW := w.String(); gotW != tt.wantW {
t.Errorf("writeDownloadRequest() gotW = %v, want %v", gotW, tt.wantW)
}
})
}
}
func TestReadDownloadResponse(t *testing.T) {
tests := []struct {
name string
data []byte
want bool
want1 string
wantErr bool
}{
{
name: "normal ok",
data: []byte{0x0, 0x0, 0x2, 0x41, 0x42},
want: true,
want1: "AB",
wantErr: false,
},
{
name: "normal ok no message",
data: []byte{0x0, 0x0, 0x0, 0x0},
want: true,
want1: "",
wantErr: false,
},
{
name: "normal error",
data: []byte{0x1, 0x0, 0x2, 0x43, 0x44},
want: false,
want1: "CD",
wantErr: false,
},
{
name: "incomplete",
data: []byte{0x0, 0x99, 0x99, 0x45, 0x46, 0x47},
want: false,
want1: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := bytes.NewReader(tt.data)
got, got1, err := readDownloadResponse(r)
if (err != nil) != tt.wantErr {
t.Errorf("readDownloadResponse() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("readDownloadResponse() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("readDownloadResponse() got1 = %v, want %v", got1, tt.want1)
}
})
}
}
func TestWriteDownloadResponse(t *testing.T) {
type args struct {
ok bool
msg string
}
tests := []struct {
name string
args args
wantW string
wantErr bool
}{
{
name: "normal ok",
args: args{ok: true, msg: "wahaha"},
wantW: "\x00\x00\x06wahaha",
wantErr: false,
},
{
name: "normal error",
args: args{ok: false, msg: "bullbull"},
wantW: "\x01\x00\x08bullbull",
wantErr: false,
},
{
name: "empty ok",
args: args{ok: true, msg: ""},
wantW: "\x00\x00\x00",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := &bytes.Buffer{}
err := writeDownloadResponse(w, tt.args.ok, tt.args.msg)
if (err != nil) != tt.wantErr {
t.Errorf("writeDownloadResponse() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotW := w.String(); gotW != tt.wantW {
t.Errorf("writeDownloadResponse() gotW = %v, want %v", gotW, tt.wantW)
}
})
}
}
func TestReadUploadRequest(t *testing.T) {
tests := []struct {
name string
data []byte
want uint32
wantErr bool
}{
{
name: "normal",
data: []byte{0x0, 0x0, 0x26, 0xEE},
want: 9966,
wantErr: false,
},
{
name: "normal zero",
data: []byte{0x0, 0x0, 0x0, 0x0},
want: 0,
wantErr: false,
},
{
name: "incomplete",
data: []byte{0x1},
want: 0,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := bytes.NewReader(tt.data)
got, err := readUploadRequest(r)
if (err != nil) != tt.wantErr {
t.Errorf("readUploadRequest() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("readUploadRequest() got = %v, want %v", got, tt.want)
}
})
}
}
func TestWriteUploadRequest(t *testing.T) {
tests := []struct {
name string
l uint32
wantW string
wantErr bool
}{
{
name: "normal",
l: 2291758882,
wantW: "\x02\x88\x99\x77\x22",
wantErr: false,
},
{
name: "normal zero",
l: 0,
wantW: "\x02\x00\x00\x00\x00",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := &bytes.Buffer{}
err := writeUploadRequest(w, tt.l)
if (err != nil) != tt.wantErr {
t.Errorf("writeUploadRequest() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotW := w.String(); gotW != tt.wantW {
t.Errorf("writeUploadRequest() gotW = %v, want %v", gotW, tt.wantW)
}
})
}
}
func TestReadUploadResponse(t *testing.T) {
tests := []struct {
name string
data []byte
want bool
want1 string
wantErr bool
}{
{
name: "normal ok",
data: []byte{0x0, 0x0, 0x2, 0x41, 0x42},
want: true,
want1: "AB",
wantErr: false,
},
{
name: "normal ok no message",
data: []byte{0x0, 0x0, 0x0},
want: true,
want1: "",
wantErr: false,
},
{
name: "normal error",
data: []byte{0x1, 0x0, 0x2, 0x43, 0x44},
want: false,
want1: "CD",
wantErr: false,
},
{
name: "incomplete",
data: []byte{0x0, 0x99, 0x99, 0x45, 0x46, 0x47},
want: false,
want1: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := bytes.NewReader(tt.data)
got, got1, err := readUploadResponse(r)
if (err != nil) != tt.wantErr {
t.Errorf("readUploadResponse() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("readUploadResponse() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("readUploadResponse() got1 = %v, want %v", got1, tt.want1)
}
})
}
}
func TestWriteUploadResponse(t *testing.T) {
type args struct {
ok bool
msg string
}
tests := []struct {
name string
args args
wantW string
wantErr bool
}{
{
name: "normal ok",
args: args{ok: true, msg: "lul"},
wantW: "\x00\x00\x03lul",
wantErr: false,
},
{
name: "normal error",
args: args{ok: false, msg: "notforu"},
wantW: "\x01\x00\x07notforu",
wantErr: false,
},
{
name: "empty ok",
args: args{ok: true, msg: ""},
wantW: "\x00\x00\x00",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := &bytes.Buffer{}
err := writeUploadResponse(w, tt.args.ok, tt.args.msg)
if (err != nil) != tt.wantErr {
t.Errorf("writeUploadResponse() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotW := w.String(); gotW != tt.wantW {
t.Errorf("writeUploadResponse() gotW = %v, want %v", gotW, tt.wantW)
}
})
}
}
func TestReadUploadSummary(t *testing.T) {
tests := []struct {
name string
data []byte
want time.Duration
want1 uint32
wantErr bool
}{
{
name: "normal",
data: []byte{0x0, 0x0, 0x14, 0x6E, 0x0, 0x26, 0x25, 0xA0},
want: 5230 * time.Millisecond,
want1: 2500000,
wantErr: false,
},
{
name: "zero",
data: []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
want: 0,
want1: 0,
wantErr: false,
},
{
name: "incomplete",
data: []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
want: 0,
want1: 0,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := bytes.NewReader(tt.data)
got, got1, err := readUploadSummary(r)
if (err != nil) != tt.wantErr {
t.Errorf("readUploadSummary() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("readUploadSummary() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("readUploadSummary() got1 = %v, want %v", got1, tt.want1)
}
})
}
}
func TestWriteUploadSummary(t *testing.T) {
type args struct {
duration time.Duration
l uint32
}
tests := []struct {
name string
args args
wantW string
wantErr bool
}{
{
name: "normal",
args: args{duration: 5230 * time.Millisecond, l: 2500000},
wantW: "\x00\x00\x14\x6E\x00\x26\x25\xA0",
wantErr: false,
},
{
name: "zero",
args: args{duration: 0, l: 0},
wantW: "\x00\x00\x00\x00\x00\x00\x00\x00",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := &bytes.Buffer{}
err := writeUploadSummary(w, tt.args.duration, tt.args.l)
if (err != nil) != tt.wantErr {
t.Errorf("writeUploadSummary() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotW := w.String(); gotW != tt.wantW {
t.Errorf("writeUploadSummary() gotW = %v, want %v", gotW, tt.wantW)
}
})
}
}

View file

@ -0,0 +1,99 @@
package speedtest
import (
"crypto/rand"
"fmt"
"io"
"net"
"time"
)
const (
chunkSize = 64 * 1024
)
// NewServerConn creates a new "pseudo" connection that implements the speed test protocol.
// It's called "pseudo" because it's not a real TCP connection - everything is done in memory.
func NewServerConn() net.Conn {
rConn, iConn := net.Pipe() // return conn & internal conn
// Start the server logic
go server(iConn)
return rConn
}
func server(conn net.Conn) error {
defer conn.Close()
// First byte determines the request type
var typ [1]byte
if _, err := io.ReadFull(conn, typ[:]); err != nil {
return err
}
switch typ[0] {
case typeDownload:
return handleDownload(conn)
case typeUpload:
return handleUpload(conn)
default:
return fmt.Errorf("unknown request type: %d", typ[0])
}
}
// handleDownload reads the download request and sends the requested amount of data.
func handleDownload(conn net.Conn) error {
l, err := readDownloadRequest(conn)
if err != nil {
return err
}
err = writeDownloadResponse(conn, true, "OK")
if err != nil {
return err
}
buf := make([]byte, chunkSize)
// Fill the buffer with random data.
// For now, we only do it once and repeat the same data for performance reasons.
_, err = rand.Read(buf)
if err != nil {
return err
}
remaining := l
for remaining > 0 {
n := remaining
if n > chunkSize {
n = chunkSize
}
_, err := conn.Write(buf[:n])
if err != nil {
return err
}
remaining -= n
}
return nil
}
// handleUpload reads the upload request, reads & discards the requested amount of data,
// and sends the upload summary.
func handleUpload(conn net.Conn) error {
l, err := readUploadRequest(conn)
if err != nil {
return err
}
err = writeUploadResponse(conn, true, "OK")
if err != nil {
return err
}
buf := make([]byte, chunkSize)
startTime := time.Now()
remaining := l
for remaining > 0 {
n := remaining
if n > chunkSize {
n = chunkSize
}
rn, err := conn.Read(buf[:n])
remaining -= uint32(rn)
if err != nil && !(remaining == 0 && err == io.EOF) {
return err
}
}
return writeUploadSummary(conn, time.Since(startTime), l)
}