Tons of refactoring

This commit is contained in:
Toby 2020-04-20 16:53:13 -07:00
parent 192e735f2a
commit a424a17af3
30 changed files with 1444 additions and 1456 deletions

3
.gitignore vendored
View file

@ -179,5 +179,4 @@ $RECYCLE.BIN/
# End of https://www.gitignore.io/api/go,linux,macos,windows,intellij+all
cmd/forwarder/*.json
cmd/forwarder/forwarder
cmd/relay/*.json

View file

@ -1,169 +0,0 @@
package main
import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"flag"
"fmt"
"github.com/lucas-clemente/quic-go/congestion"
hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion"
"github.com/tobyxdd/hysteria/pkg/forwarder"
"io/ioutil"
"log"
"net"
"os"
"os/user"
)
func loadCmdClientConfig(args []string) (CmdClientConfig, error) {
fs := flag.NewFlagSet("client", flag.ContinueOnError)
// Config file
configFile := fs.String("config", "", "Configuration file path")
// Listen
listen := fs.String("listen", "", "TCP listen address")
// Server
server := fs.String("server", "", "Server address")
// Name
name := fs.String("name", "", "Client name presented to the server")
// Insecure
var insecure optionalBoolFlag
fs.Var(&insecure, "insecure", "Ignore TLS certificate errors")
// Custom CA
customCAFile := fs.String("ca", "", "Specify a trusted CA file")
// Up Mbps
upMbps := fs.Int("up-mbps", 0, "Upload speed in Mbps")
// Down Mbps
downMbps := fs.Int("down-mbps", 0, "Download speed in Mbps")
// Receive window conn
recvWindowConn := fs.Uint64("recv-window-conn", 0, "Max receive window size per connection")
// Receive window
recvWindow := fs.Uint64("recv-window", 0, "Max receive window size")
// Parse
if err := fs.Parse(args); err != nil {
os.Exit(1)
}
// Put together the config
var config CmdClientConfig
// Load from file first
if len(*configFile) > 0 {
cb, err := ioutil.ReadFile(*configFile)
if err != nil {
return CmdClientConfig{}, err
}
if err := json.Unmarshal(cb, &config); err != nil {
return CmdClientConfig{}, err
}
}
// Then CLI options can override config
if len(*listen) > 0 {
config.ListenAddr = *listen
}
if len(*server) > 0 {
config.ServerAddr = *server
}
if len(*name) > 0 {
config.Name = *name
}
if insecure.Exists {
config.Insecure = insecure.Value
}
if len(*customCAFile) > 0 {
config.CustomCAFile = *customCAFile
}
if *upMbps != 0 {
config.UpMbps = *upMbps
}
if *downMbps != 0 {
config.DownMbps = *downMbps
}
if *recvWindowConn != 0 {
config.ReceiveWindowConn = *recvWindowConn
}
if *recvWindow != 0 {
config.ReceiveWindow = *recvWindow
}
return config, nil
}
func client(args []string) {
config, err := loadCmdClientConfig(args)
if err != nil {
log.Fatalln("Unable to load configuration:", err.Error())
}
if err := config.Check(); err != nil {
log.Fatalln("Configuration error:", err.Error())
}
if len(config.Name) == 0 {
usr, err := user.Current()
if err == nil {
config.Name = usr.Name
}
}
fmt.Printf("Configuration loaded: %+v\n", config)
tlsConfig := &tls.Config{
NextProtos: []string{forwarder.TLSAppProtocol},
MinVersion: tls.VersionTLS13,
}
// Load CA
if len(config.CustomCAFile) > 0 {
bs, err := ioutil.ReadFile(config.CustomCAFile)
if err != nil {
log.Fatalln("Unable to load CA file:", err)
}
cp := x509.NewCertPool()
if !cp.AppendCertsFromPEM(bs) {
log.Fatalln("Unable to parse CA file", config.CustomCAFile)
}
tlsConfig.RootCAs = cp
}
logChan := make(chan string, 4)
go func() {
_, err = forwarder.NewClient(config.ListenAddr, config.ServerAddr, forwarder.ClientConfig{
Name: config.Name,
TLSConfig: tlsConfig,
Speed: &forwarder.Speed{
SendBPS: uint64(config.UpMbps) * mbpsToBps,
ReceiveBPS: uint64(config.DownMbps) * mbpsToBps,
},
MaxReceiveWindowPerConnection: config.ReceiveWindowConn,
MaxReceiveWindow: config.ReceiveWindow,
CongestionFactory: func(refBPS uint64) congestion.SendAlgorithmWithDebugInfos {
return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS))
},
}, forwarder.ClientCallbacks{
ServerConnectedCallback: func(addr net.Addr, banner string, cSend uint64, cRecv uint64) {
logChan <- fmt.Sprintf("Connected to server %s, negotiated speed in Mbps: Up %d / Down %d",
addr.String(), cSend/mbpsToBps, cRecv/mbpsToBps)
logChan <- fmt.Sprintf("Server banner: [%s]", banner)
},
ServerErrorCallback: func(err error) {
logChan <- fmt.Sprintf("Error connecting to the server: %s", err.Error())
},
NewTCPConnectionCallback: func(addr net.Addr) {
logChan <- fmt.Sprintf("New connection: %s", addr.String())
},
TCPConnectionClosedCallback: func(addr net.Addr, err error) {
logChan <- fmt.Sprintf("Connection %s closed: %s", addr.String(), err.Error())
},
})
if err != nil {
log.Fatalln("Client startup failure:", err)
} else {
log.Println("The client is now up and running :)")
}
}()
for {
logStr := <-logChan
if len(logStr) == 0 {
break
}
log.Println(logStr)
}
}

View file

@ -1,76 +0,0 @@
package main
import (
"errors"
"fmt"
)
type CmdClientConfig struct {
ListenAddr string `json:"listen"`
ServerAddr string `json:"server"`
Name string `json:"name"`
Insecure bool `json:"insecure"`
CustomCAFile string `json:"ca"`
UpMbps int `json:"up_mbps"`
DownMbps int `json:"down_mbps"`
ReceiveWindowConn uint64 `json:"recv_window_conn"`
ReceiveWindow uint64 `json:"recv_window"`
}
func (c *CmdClientConfig) Check() error {
if len(c.ListenAddr) == 0 {
return errors.New("no listen address")
}
if len(c.ServerAddr) == 0 {
return errors.New("no server address")
}
if c.UpMbps <= 0 || c.DownMbps <= 0 {
return errors.New("invalid speed")
}
if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) ||
(c.ReceiveWindow != 0 && c.ReceiveWindow < 65536) {
return errors.New("invalid receive window size")
}
return nil
}
type ForwardEntry struct {
ListenAddr string `json:"listen"`
RemoteAddr string `json:"remote"`
}
func (e *ForwardEntry) String() string {
return fmt.Sprintf("%s <-> %s", e.ListenAddr, e.RemoteAddr)
}
type CmdServerConfig struct {
Entries []ForwardEntry `json:"entries"`
Banner string `json:"banner"`
CertFile string `json:"cert"`
KeyFile string `json:"key"`
UpMbps int `json:"up_mbps"`
DownMbps int `json:"down_mbps"`
ReceiveWindowConn uint64 `json:"recv_window_conn"`
ReceiveWindowClient uint64 `json:"recv_window_client"`
MaxConnClient int `json:"max_conn_client"`
}
func (c *CmdServerConfig) Check() error {
if len(c.Entries) == 0 {
return errors.New("no entries")
}
if len(c.CertFile) == 0 || len(c.KeyFile) == 0 {
return errors.New("TLS cert or key not provided")
}
if c.UpMbps < 0 || c.DownMbps < 0 {
return errors.New("invalid speed")
}
if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) ||
(c.ReceiveWindowClient != 0 && c.ReceiveWindowClient < 65536) {
return errors.New("invalid receive window size")
}
if c.MaxConnClient < 0 {
return errors.New("invalid max connections per client")
}
return nil
}

View file

@ -1,204 +0,0 @@
package main
import (
"crypto/tls"
"encoding/json"
"flag"
"fmt"
"github.com/lucas-clemente/quic-go/congestion"
hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion"
"github.com/tobyxdd/hysteria/pkg/forwarder"
"io/ioutil"
"log"
"net"
"os"
"strings"
)
const mbpsToBps = 125000
func loadCmdServerConfig(args []string) (CmdServerConfig, error) {
fs := flag.NewFlagSet("server", flag.ContinueOnError)
// Config file
configFile := fs.String("config", "", "Configuration file path")
// Entries
var entries stringSliceFlag
fs.Var(&entries, "entry", "Add a forwarding entry. Separate the listen address and the remote address with a comma. You can add this option multiple times. Example: localhost:444,google.com:443")
// Banner
banner := fs.String("banner", "", "A banner to present to clients")
// Cert file
certFile := fs.String("cert", "", "TLS certificate file")
// Key file
keyFile := fs.String("key", "", "TLS key file")
// Up Mbps
upMbps := fs.Int("up-mbps", 0, "Max upload speed per client in Mbps")
// Down Mbps
downMbps := fs.Int("down-mbps", 0, "Max download speed per client in Mbps")
// Receive window conn
recvWindowConn := fs.Uint64("recv-window-conn", 0, "Max receive window size per connection")
// Receive window client
recvWindowClient := fs.Uint64("recv-window-client", 0, "Max receive window size per client")
// Max conn client
maxConnClient := fs.Int("max-conn-client", 0, "Max simultaneous connections allowed per client")
// Parse
if err := fs.Parse(args); err != nil {
os.Exit(1)
}
// Put together the config
var config CmdServerConfig
// Load from file first
if len(*configFile) > 0 {
cb, err := ioutil.ReadFile(*configFile)
if err != nil {
return CmdServerConfig{}, err
}
if err := json.Unmarshal(cb, &config); err != nil {
return CmdServerConfig{}, err
}
}
// Then CLI options can override config
if len(entries) > 0 {
fe, err := flagToEntries(entries)
if err != nil {
return CmdServerConfig{}, err
}
config.Entries = append(config.Entries, fe...)
}
if len(*banner) > 0 {
config.Banner = *banner
}
if len(*certFile) > 0 {
config.CertFile = *certFile
}
if len(*keyFile) > 0 {
config.KeyFile = *keyFile
}
if *upMbps != 0 {
config.UpMbps = *upMbps
}
if *downMbps != 0 {
config.DownMbps = *downMbps
}
if *recvWindowConn != 0 {
config.ReceiveWindowConn = *recvWindowConn
}
if *recvWindowClient != 0 {
config.ReceiveWindowClient = *recvWindowClient
}
if *maxConnClient != 0 {
config.MaxConnClient = *maxConnClient
}
return config, nil
}
func flagToEntries(f stringSliceFlag) ([]ForwardEntry, error) {
out := make([]ForwardEntry, len(f))
for i, entry := range f {
es := strings.Split(entry, ",")
if len(es) != 2 {
return nil, fmt.Errorf("incorrect entry syntax: %s", entry)
}
out[i] = ForwardEntry{
ListenAddr: es[0],
RemoteAddr: es[1],
}
}
return out, nil
}
func server(args []string) {
config, err := loadCmdServerConfig(args)
if err != nil {
log.Fatalln("Unable to load configuration:", err.Error())
}
if err := config.Check(); err != nil {
log.Fatalln("Configuration error:", err.Error())
}
fmt.Printf("Configuration loaded: %+v\n", config)
// Load cert
cert, err := tls.LoadX509KeyPair(config.CertFile, config.KeyFile)
if err != nil {
log.Fatalln("Unable to load the certificate:", err)
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
NextProtos: []string{forwarder.TLSAppProtocol},
MinVersion: tls.VersionTLS13,
}
logChan := make(chan string, 4)
go func() {
server := forwarder.NewServer(forwarder.ServerConfig{
BannerMessage: config.Banner,
TLSConfig: tlsConfig,
MaxSpeedPerClient: &forwarder.Speed{
SendBPS: uint64(config.UpMbps) * mbpsToBps,
ReceiveBPS: uint64(config.DownMbps) * mbpsToBps,
},
MaxReceiveWindowPerConnection: config.ReceiveWindowConn,
MaxReceiveWindowPerClient: config.ReceiveWindowClient,
MaxConnectionPerClient: config.MaxConnClient,
CongestionFactory: func(refBPS uint64) congestion.SendAlgorithmWithDebugInfos {
return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS))
},
}, forwarder.ServerCallbacks{
ClientConnectedCallback: func(listenAddr string, clientAddr net.Addr, name string, sSend uint64, sRecv uint64) {
if len(name) > 0 {
logChan <- fmt.Sprintf("[%s] Client %s (%s) connected, negotiated speed in Mbps: Up %d / Down %d",
listenAddr, clientAddr.String(), name, sSend/mbpsToBps, sRecv/mbpsToBps)
} else {
logChan <- fmt.Sprintf("[%s] Client %s connected, negotiated speed in Mbps: Up %d / Down %d",
listenAddr, clientAddr.String(), sSend/mbpsToBps, sRecv/mbpsToBps)
}
},
ClientDisconnectedCallback: func(listenAddr string, clientAddr net.Addr, name string, err error) {
if len(name) > 0 {
logChan <- fmt.Sprintf("[%s] Client %s (%s) disconnected: %s",
listenAddr, clientAddr.String(), name, err.Error())
} else {
logChan <- fmt.Sprintf("[%s] Client %s disconnected: %s",
listenAddr, clientAddr.String(), err.Error())
}
},
ClientNewStreamCallback: func(listenAddr string, clientAddr net.Addr, name string, id int) {
if len(name) > 0 {
logChan <- fmt.Sprintf("[%s] Client %s (%s) opened stream ID %d",
listenAddr, clientAddr.String(), name, id)
} else {
logChan <- fmt.Sprintf("[%s] Client %s opened stream ID %d",
listenAddr, clientAddr.String(), id)
}
},
ClientStreamClosedCallback: func(listenAddr string, clientAddr net.Addr, name string, id int, err error) {
if len(name) > 0 {
logChan <- fmt.Sprintf("[%s] Client %s (%s) closed stream ID %d: %s",
listenAddr, clientAddr.String(), name, id, err.Error())
} else {
logChan <- fmt.Sprintf("[%s] Client %s closed stream ID %d: %s",
listenAddr, clientAddr.String(), id, err.Error())
}
},
TCPErrorCallback: func(listenAddr string, remoteAddr string, err error) {
logChan <- fmt.Sprintf("[%s] TCP error when connecting to %s: %s",
listenAddr, remoteAddr, err.Error())
},
})
for _, entry := range config.Entries {
log.Println("Starting", entry.String(), "...")
if err := server.Add(entry.ListenAddr, entry.RemoteAddr); err != nil {
log.Fatalln(err)
}
}
log.Println("The server is now up and running :)")
}()
for {
logStr := <-logChan
if len(logStr) == 0 {
break
}
log.Println(logStr)
}
}

118
cmd/relay/client.go Normal file
View file

@ -0,0 +1,118 @@
package main
import (
"crypto/tls"
"crypto/x509"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/tobyxdd/hysteria/internal/utils"
hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion"
"github.com/tobyxdd/hysteria/pkg/core"
"io"
"io/ioutil"
"log"
"net"
"os/user"
)
func client(args []string) {
var config cmdClientConfig
err := loadConfig(&config, args)
if err != nil {
log.Fatalln("Unable to load configuration:", err)
}
if err := config.Check(); err != nil {
log.Fatalln("Configuration error:", err)
}
if len(config.Name) == 0 {
usr, err := user.Current()
if err == nil {
config.Name = usr.Name
}
}
log.Printf("Configuration loaded: %+v\n", config)
tlsConfig := &tls.Config{
NextProtos: []string{TLSAppProtocol},
MinVersion: tls.VersionTLS13,
}
// Load CA
if len(config.CustomCAFile) > 0 {
bs, err := ioutil.ReadFile(config.CustomCAFile)
if err != nil {
log.Fatalln("Unable to load CA file:", err)
}
cp := x509.NewCertPool()
if !cp.AppendCertsFromPEM(bs) {
log.Fatalln("Unable to parse CA file", config.CustomCAFile)
}
tlsConfig.RootCAs = cp
}
quicConfig := &quic.Config{
MaxReceiveStreamFlowControlWindow: config.ReceiveWindowConn,
MaxReceiveConnectionFlowControlWindow: config.ReceiveWindow,
KeepAlive: true,
}
if quicConfig.MaxReceiveStreamFlowControlWindow == 0 {
quicConfig.MaxReceiveStreamFlowControlWindow = DefaultMaxReceiveStreamFlowControlWindow
}
if quicConfig.MaxReceiveConnectionFlowControlWindow == 0 {
quicConfig.MaxReceiveConnectionFlowControlWindow = DefaultMaxReceiveConnectionFlowControlWindow
}
client, err := core.NewClient(config.ServerAddr, config.Name, "", tlsConfig, quicConfig,
uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps,
func(refBPS uint64) congestion.SendAlgorithmWithDebugInfos {
return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS))
})
if err != nil {
log.Fatalln("Client initialization failed:", err)
}
defer client.Close()
log.Println("Client initialization complete, connected to", config.ServerAddr)
listener, err := net.Listen("tcp", config.ListenAddr)
if err != nil {
log.Fatalln("TCP listen failed:", err)
}
defer listener.Close()
log.Println("TCP listening on", listener.Addr().String())
for {
conn, err := listener.Accept()
if err != nil {
log.Fatalln("TCP accept failed:", err)
}
go clientHandleConn(conn, client)
}
}
func clientHandleConn(conn net.Conn, client core.Client) {
log.Println("New TCP connection from", conn.RemoteAddr().String())
var closeErr error
defer func() {
_ = conn.Close()
log.Println("TCP connection from", conn.RemoteAddr().String(), "closed", closeErr)
}()
rwc, err := client.Dial(false, "")
if err != nil {
closeErr = err
return
}
defer rwc.Close()
closeErr = pipePair(conn, rwc)
}
func pipePair(rw1, rw2 io.ReadWriter) error {
// Pipes
errChan := make(chan error, 2)
go func() {
errChan <- utils.Pipe(rw2, rw1, nil)
}()
go func() {
errChan <- utils.Pipe(rw1, rw2, nil)
}()
// We only need the first error
return <-errChan
}

150
cmd/relay/config.go Normal file
View file

@ -0,0 +1,150 @@
package main
import (
"encoding/json"
"errors"
"flag"
"io/ioutil"
"os"
"reflect"
"strings"
)
const (
mbpsToBps = 125000
TLSAppProtocol = "hysteria-relay"
DefaultMaxReceiveStreamFlowControlWindow = 33554432
DefaultMaxReceiveConnectionFlowControlWindow = 67108864
)
type cmdClientConfig struct {
ListenAddr string `json:"listen" desc:"TCP listen address"`
ServerAddr string `json:"server" desc:"Server address"`
Name string `json:"name" desc:"Client name presented to the server"`
Insecure bool `json:"insecure" desc:"Ignore TLS certificate errors"`
CustomCAFile string `json:"ca" desc:"Specify a trusted CA file"`
UpMbps int `json:"up_mbps" desc:"Upload speed in Mbps"`
DownMbps int `json:"down_mbps" desc:"Download speed in Mbps"`
ReceiveWindowConn uint64 `json:"recv_window_conn" desc:"Max receive window size per connection"`
ReceiveWindow uint64 `json:"recv_window" desc:"Max receive window size"`
}
func (c *cmdClientConfig) Check() error {
if len(c.ListenAddr) == 0 {
return errors.New("no listen address")
}
if len(c.ServerAddr) == 0 {
return errors.New("no server address")
}
if c.UpMbps <= 0 || c.DownMbps <= 0 {
return errors.New("invalid speed")
}
if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) ||
(c.ReceiveWindow != 0 && c.ReceiveWindow < 65536) {
return errors.New("invalid receive window size")
}
return nil
}
type cmdServerConfig struct {
ListenAddr string `json:"listen" desc:"Server listen address"`
RemoteAddr string `json:"remote" desc:"Remote relay address"`
CertFile string `json:"cert" desc:"TLS certificate file"`
KeyFile string `json:"key" desc:"TLS key file"`
UpMbps int `json:"up_mbps" desc:"Max upload speed per client in Mbps"`
DownMbps int `json:"down_mbps" desc:"Max download speed per client in Mbps"`
ReceiveWindowConn uint64 `json:"recv_window_conn" desc:"Max receive window size per connection"`
ReceiveWindowClient uint64 `json:"recv_window_client" desc:"Max receive window size per client"`
MaxConnClient int `json:"max_conn_client" desc:"Max simultaneous connections allowed per client"`
}
func (c *cmdServerConfig) Check() error {
if len(c.ListenAddr) == 0 {
return errors.New("no listen address")
}
if len(c.RemoteAddr) == 0 {
return errors.New("no remote address")
}
if len(c.CertFile) == 0 || len(c.KeyFile) == 0 {
return errors.New("TLS cert or key not provided")
}
if c.UpMbps < 0 || c.DownMbps < 0 {
return errors.New("invalid speed")
}
if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) ||
(c.ReceiveWindowClient != 0 && c.ReceiveWindowClient < 65536) {
return errors.New("invalid receive window size")
}
if c.MaxConnClient < 0 {
return errors.New("invalid max connections per client")
}
return nil
}
func loadConfig(cfg interface{}, args []string) error {
cfgVal := reflect.ValueOf(cfg).Elem()
fs := flag.NewFlagSet("", flag.ContinueOnError)
fsValMap := make(map[reflect.Value]interface{}, cfgVal.NumField())
for i := 0; i < cfgVal.NumField(); i++ {
structField := cfgVal.Type().Field(i)
tag := structField.Tag
switch structField.Type.Kind() {
case reflect.String:
fsValMap[cfgVal.Field(i)] =
fs.String(jsonTagToFlagName(tag.Get("json")), "", tag.Get("desc"))
case reflect.Int:
fsValMap[cfgVal.Field(i)] =
fs.Int(jsonTagToFlagName(tag.Get("json")), 0, tag.Get("desc"))
case reflect.Uint64:
fsValMap[cfgVal.Field(i)] =
fs.Uint64(jsonTagToFlagName(tag.Get("json")), 0, tag.Get("desc"))
case reflect.Bool:
var bf optionalBoolFlag
fs.Var(&bf, jsonTagToFlagName(tag.Get("json")), tag.Get("desc"))
fsValMap[cfgVal.Field(i)] = &bf
}
}
configFile := fs.String("config", "", "Configuration file")
// Parse
if err := fs.Parse(args); err != nil {
os.Exit(1)
}
// Put together the config
if len(*configFile) > 0 {
cb, err := ioutil.ReadFile(*configFile)
if err != nil {
return err
}
if err := json.Unmarshal(cb, cfg); err != nil {
return err
}
}
// Flags override config from file
for field, val := range fsValMap {
switch v := val.(type) {
case *string:
if len(*v) > 0 {
field.SetString(*v)
}
case *int:
if *v != 0 {
field.SetInt(int64(*v))
}
case *uint64:
if *v != 0 {
field.SetUint(*v)
}
case *optionalBoolFlag:
if v.Exists {
field.SetBool(v.Value)
}
}
}
return nil
}
func jsonTagToFlagName(tag string) string {
return strings.ReplaceAll(tag, "_", "-")
}

View file

@ -2,7 +2,6 @@ package main
import (
"strconv"
"strings"
)
type optionalBoolFlag struct {
@ -24,17 +23,6 @@ func (flag *optionalBoolFlag) Set(s string) error {
return nil
}
func (o *optionalBoolFlag) IsBoolFlag() bool {
func (flag *optionalBoolFlag) IsBoolFlag() bool {
return true
}
type stringSliceFlag []string
func (flag *stringSliceFlag) String() string {
return strings.Join(*flag, ";")
}
func (flag *stringSliceFlag) Set(s string) error {
*flag = append(*flag, s)
return nil
}

84
cmd/relay/server.go Normal file
View file

@ -0,0 +1,84 @@
package main
import (
"crypto/tls"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/congestion"
hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion"
"github.com/tobyxdd/hysteria/pkg/core"
"io"
"log"
"net"
)
func server(args []string) {
var config cmdServerConfig
err := loadConfig(&config, args)
if err != nil {
log.Fatalln("Unable to load configuration:", err)
}
if err := config.Check(); err != nil {
log.Fatalln("Configuration error:", err.Error())
}
log.Printf("Configuration loaded: %+v\n", config)
// Load cert
cert, err := tls.LoadX509KeyPair(config.CertFile, config.KeyFile)
if err != nil {
log.Fatalln("Unable to load the certificate:", err)
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
NextProtos: []string{TLSAppProtocol},
MinVersion: tls.VersionTLS13,
}
quicConfig := &quic.Config{
MaxReceiveStreamFlowControlWindow: config.ReceiveWindowConn,
MaxReceiveConnectionFlowControlWindow: config.ReceiveWindowClient,
KeepAlive: true,
}
if quicConfig.MaxReceiveStreamFlowControlWindow == 0 {
quicConfig.MaxReceiveStreamFlowControlWindow = DefaultMaxReceiveStreamFlowControlWindow
}
if quicConfig.MaxReceiveConnectionFlowControlWindow == 0 {
quicConfig.MaxReceiveConnectionFlowControlWindow = DefaultMaxReceiveConnectionFlowControlWindow
}
server, err := core.NewServer(config.ListenAddr, tlsConfig, quicConfig,
uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps,
func(refBPS uint64) congestion.SendAlgorithmWithDebugInfos {
return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS))
},
func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (core.AuthResult, string) {
// No authentication logic in relay, just log username and speed
log.Printf("Client %s connected, negotiated speed in Mbps: Up %d / Down %d\n",
addr.String(), sSend/mbpsToBps, sRecv/mbpsToBps)
return core.AuthSuccess, ""
},
func(addr net.Addr, username string, err error) {
log.Printf("Client %s (%s) disconnected: %s\n", addr.String(), username, err.Error())
},
func(addr net.Addr, username string, id int, isUDP bool, reqAddr string) (core.ConnectResult, string, io.ReadWriteCloser) {
log.Printf("Client %s (%s) opened stream ID %d\n", addr.String(), username, id)
if isUDP {
return core.ConnBlocked, "unsupported", nil
}
conn, err := net.Dial("tcp", config.RemoteAddr)
if err != nil {
log.Printf("TCP error when connecting to %s: %s", config.RemoteAddr, err.Error())
return core.ConnFailed, err.Error(), nil
}
return core.ConnSuccess, "", conn
},
func(addr net.Addr, username string, id int, isUDP bool, reqAddr string, err error) {
log.Printf("Client %s (%s) closed stream ID %d: %s", addr.String(), username, id, err.Error())
},
)
if err != nil {
log.Fatalln("Server initialization failed:", err)
}
defer server.Close()
log.Println("The server is now up and running :)")
log.Fatalln("Server error:", server.Serve())
}

175
internal/core/client.go Normal file
View file

@ -0,0 +1,175 @@
package core
import (
"context"
"crypto/tls"
"errors"
"fmt"
"github.com/lucas-clemente/quic-go"
"github.com/tobyxdd/hysteria/internal/utils"
"io"
"net"
"sync"
"sync/atomic"
)
var (
ErrClosed = errors.New("client closed")
)
type Client struct {
inboundBytes, outboundBytes uint64 // atomic
reconnectMutex sync.Mutex
closed bool
quicSession quic.Session
serverAddr string
username, password string
tlsConfig *tls.Config
quicConfig *quic.Config
sendBPS, recvBPS uint64
congestionFactory CongestionFactory
}
func NewClient(serverAddr string, username string, password string, tlsConfig *tls.Config, quicConfig *quic.Config,
sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory) (*Client, error) {
c := &Client{
serverAddr: serverAddr,
username: username,
password: password,
tlsConfig: tlsConfig,
quicConfig: quicConfig,
sendBPS: sendBPS,
recvBPS: recvBPS,
congestionFactory: congestionFactory,
}
if err := c.connectToServer(); err != nil {
return nil, err
}
return c, nil
}
func (c *Client) Dial(udp bool, addr string) (io.ReadWriteCloser, error) {
stream, err := c.openStreamWithReconnect()
if err != nil {
return nil, err
}
// Send request
req := &ClientConnectRequest{Address: addr}
if udp {
req.Type = ConnectionType_UDP
} else {
req.Type = ConnectionType_TCP
}
err = writeClientConnectRequest(stream, req)
if err != nil {
_ = stream.Close()
return nil, err
}
// Read response
resp, err := readServerConnectResponse(stream)
if err != nil {
_ = stream.Close()
return nil, err
}
if resp.Result != ConnectResult_CONN_SUCCESS {
_ = stream.Close()
return nil, fmt.Errorf("server rejected the connection %s (msg: %s)",
resp.Result.String(), resp.Message)
}
if udp {
return &utils.PacketReadWriteCloser{Orig: stream}, nil
} else {
return stream, nil
}
}
func (c *Client) Stats() (uint64, uint64) {
return atomic.LoadUint64(&c.inboundBytes), atomic.LoadUint64(&c.outboundBytes)
}
func (c *Client) Close() error {
c.reconnectMutex.Lock()
defer c.reconnectMutex.Unlock()
err := c.quicSession.CloseWithError(closeErrorCodeGeneric, "generic")
c.closed = true
return err
}
func (c *Client) connectToServer() error {
qs, err := quic.DialAddr(c.serverAddr, c.tlsConfig, c.quicConfig)
if err != nil {
return err
}
// Control stream
ctx, ctxCancel := context.WithTimeout(context.Background(), controlStreamTimeout)
ctlStream, err := qs.OpenStreamSync(ctx)
ctxCancel()
if err != nil {
_ = qs.CloseWithError(closeErrorCodeProtocolFailure, "control stream error")
return err
}
result, msg, err := c.handleControlStream(qs, ctlStream)
if err != nil {
_ = qs.CloseWithError(closeErrorCodeProtocolFailure, "control stream handling error")
return err
}
if result != AuthResult_AUTH_SUCCESS {
_ = qs.CloseWithError(closeErrorCodeProtocolFailure, "authentication failure")
return fmt.Errorf("authentication failure %s (msg: %s)", result.String(), msg)
}
// All good
c.quicSession = qs
return nil
}
func (c *Client) handleControlStream(qs quic.Session, stream quic.Stream) (AuthResult, string, error) {
err := writeClientAuthRequest(stream, &ClientAuthRequest{
Credential: &Credential{
Username: c.username,
Password: c.password,
},
Speed: &Speed{
SendBps: c.sendBPS,
ReceiveBps: c.recvBPS,
},
})
if err != nil {
return 0, "", err
}
// Response
resp, err := readServerAuthResponse(stream)
if err != nil {
return 0, "", err
}
// Set the congestion accordingly
if resp.Result == AuthResult_AUTH_SUCCESS && c.congestionFactory != nil {
qs.SetCongestion(c.congestionFactory(resp.Speed.ReceiveBps))
}
return resp.Result, resp.Message, nil
}
func (c *Client) openStreamWithReconnect() (quic.Stream, error) {
c.reconnectMutex.Lock()
defer c.reconnectMutex.Unlock()
if c.closed {
return nil, ErrClosed
}
stream, err := c.quicSession.OpenStream()
if err == nil {
// All good
return stream, nil
}
// Something is wrong
if nErr, ok := err.(net.Error); ok && nErr.Temporary() {
// Temporary error, just return
return nil, err
}
// Permanent error, need to reconnect
if err := c.connectToServer(); err != nil {
// Still error, oops
return nil, err
}
// We are not going to try again even if it still fails the second time
return c.quicSession.OpenStream()
}

103
internal/core/control.go Normal file
View file

@ -0,0 +1,103 @@
package core
import (
"encoding/binary"
"github.com/golang/protobuf/proto"
"io"
)
const (
closeErrorCodeGeneric = 0
closeErrorCodeProtocolFailure = 1
)
func readDataBlock(r io.Reader) ([]byte, error) {
var sz uint32
if err := binary.Read(r, controlProtocolEndian, &sz); err != nil {
return nil, err
}
buf := make([]byte, sz)
_, err := io.ReadFull(r, buf)
return buf, err
}
func writeDataBlock(w io.Writer, data []byte) error {
sz := uint32(len(data))
if err := binary.Write(w, controlProtocolEndian, &sz); err != nil {
return err
}
_, err := w.Write(data)
return err
}
func readClientAuthRequest(r io.Reader) (*ClientAuthRequest, error) {
bs, err := readDataBlock(r)
if err != nil {
return nil, err
}
var req ClientAuthRequest
err = proto.Unmarshal(bs, &req)
return &req, err
}
func writeClientAuthRequest(w io.Writer, req *ClientAuthRequest) error {
bs, err := proto.Marshal(req)
if err != nil {
return err
}
return writeDataBlock(w, bs)
}
func readServerAuthResponse(r io.Reader) (*ServerAuthResponse, error) {
bs, err := readDataBlock(r)
if err != nil {
return nil, err
}
var resp ServerAuthResponse
err = proto.Unmarshal(bs, &resp)
return &resp, err
}
func writeServerAuthResponse(w io.Writer, resp *ServerAuthResponse) error {
bs, err := proto.Marshal(resp)
if err != nil {
return err
}
return writeDataBlock(w, bs)
}
func readClientConnectRequest(r io.Reader) (*ClientConnectRequest, error) {
bs, err := readDataBlock(r)
if err != nil {
return nil, err
}
var req ClientConnectRequest
err = proto.Unmarshal(bs, &req)
return &req, err
}
func writeClientConnectRequest(w io.Writer, req *ClientConnectRequest) error {
bs, err := proto.Marshal(req)
if err != nil {
return err
}
return writeDataBlock(w, bs)
}
func readServerConnectResponse(r io.Reader) (*ServerConnectResponse, error) {
bs, err := readDataBlock(r)
if err != nil {
return nil, err
}
var resp ServerConnectResponse
err = proto.Unmarshal(bs, &resp)
return &resp, err
}
func writeServerConnectResponse(w io.Writer, resp *ServerConnectResponse) error {
bs, err := proto.Marshal(resp)
if err != nil {
return err
}
return writeDataBlock(w, bs)
}

439
internal/core/control.pb.go Normal file
View file

@ -0,0 +1,439 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: control.proto
package core
import (
fmt "fmt"
proto "github.com/golang/protobuf/proto"
math "math"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
type AuthResult int32
const (
AuthResult_AUTH_SUCCESS AuthResult = 0
AuthResult_AUTH_INVALID_CRED AuthResult = 1
AuthResult_AUTH_INTERNAL_ERROR AuthResult = 2
)
var AuthResult_name = map[int32]string{
0: "AUTH_SUCCESS",
1: "AUTH_INVALID_CRED",
2: "AUTH_INTERNAL_ERROR",
}
var AuthResult_value = map[string]int32{
"AUTH_SUCCESS": 0,
"AUTH_INVALID_CRED": 1,
"AUTH_INTERNAL_ERROR": 2,
}
func (x AuthResult) String() string {
return proto.EnumName(AuthResult_name, int32(x))
}
func (AuthResult) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_0c5120591600887d, []int{0}
}
type ConnectionType int32
const (
ConnectionType_TCP ConnectionType = 0
ConnectionType_UDP ConnectionType = 1
)
var ConnectionType_name = map[int32]string{
0: "TCP",
1: "UDP",
}
var ConnectionType_value = map[string]int32{
"TCP": 0,
"UDP": 1,
}
func (x ConnectionType) String() string {
return proto.EnumName(ConnectionType_name, int32(x))
}
func (ConnectionType) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_0c5120591600887d, []int{1}
}
type ConnectResult int32
const (
ConnectResult_CONN_SUCCESS ConnectResult = 0
ConnectResult_CONN_FAILED ConnectResult = 1
ConnectResult_CONN_BLOCKED ConnectResult = 2
)
var ConnectResult_name = map[int32]string{
0: "CONN_SUCCESS",
1: "CONN_FAILED",
2: "CONN_BLOCKED",
}
var ConnectResult_value = map[string]int32{
"CONN_SUCCESS": 0,
"CONN_FAILED": 1,
"CONN_BLOCKED": 2,
}
func (x ConnectResult) String() string {
return proto.EnumName(ConnectResult_name, int32(x))
}
func (ConnectResult) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_0c5120591600887d, []int{2}
}
type Speed struct {
SendBps uint64 `protobuf:"varint,1,opt,name=send_bps,json=sendBps,proto3" json:"send_bps,omitempty"`
ReceiveBps uint64 `protobuf:"varint,2,opt,name=receive_bps,json=receiveBps,proto3" json:"receive_bps,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Speed) Reset() { *m = Speed{} }
func (m *Speed) String() string { return proto.CompactTextString(m) }
func (*Speed) ProtoMessage() {}
func (*Speed) Descriptor() ([]byte, []int) {
return fileDescriptor_0c5120591600887d, []int{0}
}
func (m *Speed) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Speed.Unmarshal(m, b)
}
func (m *Speed) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Speed.Marshal(b, m, deterministic)
}
func (m *Speed) XXX_Merge(src proto.Message) {
xxx_messageInfo_Speed.Merge(m, src)
}
func (m *Speed) XXX_Size() int {
return xxx_messageInfo_Speed.Size(m)
}
func (m *Speed) XXX_DiscardUnknown() {
xxx_messageInfo_Speed.DiscardUnknown(m)
}
var xxx_messageInfo_Speed proto.InternalMessageInfo
func (m *Speed) GetSendBps() uint64 {
if m != nil {
return m.SendBps
}
return 0
}
func (m *Speed) GetReceiveBps() uint64 {
if m != nil {
return m.ReceiveBps
}
return 0
}
type Credential struct {
Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"`
Password string `protobuf:"bytes,2,opt,name=password,proto3" json:"password,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Credential) Reset() { *m = Credential{} }
func (m *Credential) String() string { return proto.CompactTextString(m) }
func (*Credential) ProtoMessage() {}
func (*Credential) Descriptor() ([]byte, []int) {
return fileDescriptor_0c5120591600887d, []int{1}
}
func (m *Credential) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Credential.Unmarshal(m, b)
}
func (m *Credential) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Credential.Marshal(b, m, deterministic)
}
func (m *Credential) XXX_Merge(src proto.Message) {
xxx_messageInfo_Credential.Merge(m, src)
}
func (m *Credential) XXX_Size() int {
return xxx_messageInfo_Credential.Size(m)
}
func (m *Credential) XXX_DiscardUnknown() {
xxx_messageInfo_Credential.DiscardUnknown(m)
}
var xxx_messageInfo_Credential proto.InternalMessageInfo
func (m *Credential) GetUsername() string {
if m != nil {
return m.Username
}
return ""
}
func (m *Credential) GetPassword() string {
if m != nil {
return m.Password
}
return ""
}
type ClientAuthRequest struct {
Credential *Credential `protobuf:"bytes,1,opt,name=credential,proto3" json:"credential,omitempty"`
Speed *Speed `protobuf:"bytes,2,opt,name=speed,proto3" json:"speed,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ClientAuthRequest) Reset() { *m = ClientAuthRequest{} }
func (m *ClientAuthRequest) String() string { return proto.CompactTextString(m) }
func (*ClientAuthRequest) ProtoMessage() {}
func (*ClientAuthRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_0c5120591600887d, []int{2}
}
func (m *ClientAuthRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ClientAuthRequest.Unmarshal(m, b)
}
func (m *ClientAuthRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ClientAuthRequest.Marshal(b, m, deterministic)
}
func (m *ClientAuthRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_ClientAuthRequest.Merge(m, src)
}
func (m *ClientAuthRequest) XXX_Size() int {
return xxx_messageInfo_ClientAuthRequest.Size(m)
}
func (m *ClientAuthRequest) XXX_DiscardUnknown() {
xxx_messageInfo_ClientAuthRequest.DiscardUnknown(m)
}
var xxx_messageInfo_ClientAuthRequest proto.InternalMessageInfo
func (m *ClientAuthRequest) GetCredential() *Credential {
if m != nil {
return m.Credential
}
return nil
}
func (m *ClientAuthRequest) GetSpeed() *Speed {
if m != nil {
return m.Speed
}
return nil
}
type ServerAuthResponse struct {
Result AuthResult `protobuf:"varint,1,opt,name=result,proto3,enum=core.AuthResult" json:"result,omitempty"`
Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"`
Speed *Speed `protobuf:"bytes,3,opt,name=speed,proto3" json:"speed,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ServerAuthResponse) Reset() { *m = ServerAuthResponse{} }
func (m *ServerAuthResponse) String() string { return proto.CompactTextString(m) }
func (*ServerAuthResponse) ProtoMessage() {}
func (*ServerAuthResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_0c5120591600887d, []int{3}
}
func (m *ServerAuthResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ServerAuthResponse.Unmarshal(m, b)
}
func (m *ServerAuthResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ServerAuthResponse.Marshal(b, m, deterministic)
}
func (m *ServerAuthResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_ServerAuthResponse.Merge(m, src)
}
func (m *ServerAuthResponse) XXX_Size() int {
return xxx_messageInfo_ServerAuthResponse.Size(m)
}
func (m *ServerAuthResponse) XXX_DiscardUnknown() {
xxx_messageInfo_ServerAuthResponse.DiscardUnknown(m)
}
var xxx_messageInfo_ServerAuthResponse proto.InternalMessageInfo
func (m *ServerAuthResponse) GetResult() AuthResult {
if m != nil {
return m.Result
}
return AuthResult_AUTH_SUCCESS
}
func (m *ServerAuthResponse) GetMessage() string {
if m != nil {
return m.Message
}
return ""
}
func (m *ServerAuthResponse) GetSpeed() *Speed {
if m != nil {
return m.Speed
}
return nil
}
type ClientConnectRequest struct {
Type ConnectionType `protobuf:"varint,1,opt,name=type,proto3,enum=core.ConnectionType" json:"type,omitempty"`
Address string `protobuf:"bytes,2,opt,name=address,proto3" json:"address,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ClientConnectRequest) Reset() { *m = ClientConnectRequest{} }
func (m *ClientConnectRequest) String() string { return proto.CompactTextString(m) }
func (*ClientConnectRequest) ProtoMessage() {}
func (*ClientConnectRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_0c5120591600887d, []int{4}
}
func (m *ClientConnectRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ClientConnectRequest.Unmarshal(m, b)
}
func (m *ClientConnectRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ClientConnectRequest.Marshal(b, m, deterministic)
}
func (m *ClientConnectRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_ClientConnectRequest.Merge(m, src)
}
func (m *ClientConnectRequest) XXX_Size() int {
return xxx_messageInfo_ClientConnectRequest.Size(m)
}
func (m *ClientConnectRequest) XXX_DiscardUnknown() {
xxx_messageInfo_ClientConnectRequest.DiscardUnknown(m)
}
var xxx_messageInfo_ClientConnectRequest proto.InternalMessageInfo
func (m *ClientConnectRequest) GetType() ConnectionType {
if m != nil {
return m.Type
}
return ConnectionType_TCP
}
func (m *ClientConnectRequest) GetAddress() string {
if m != nil {
return m.Address
}
return ""
}
type ServerConnectResponse struct {
Result ConnectResult `protobuf:"varint,1,opt,name=result,proto3,enum=core.ConnectResult" json:"result,omitempty"`
Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ServerConnectResponse) Reset() { *m = ServerConnectResponse{} }
func (m *ServerConnectResponse) String() string { return proto.CompactTextString(m) }
func (*ServerConnectResponse) ProtoMessage() {}
func (*ServerConnectResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_0c5120591600887d, []int{5}
}
func (m *ServerConnectResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ServerConnectResponse.Unmarshal(m, b)
}
func (m *ServerConnectResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ServerConnectResponse.Marshal(b, m, deterministic)
}
func (m *ServerConnectResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_ServerConnectResponse.Merge(m, src)
}
func (m *ServerConnectResponse) XXX_Size() int {
return xxx_messageInfo_ServerConnectResponse.Size(m)
}
func (m *ServerConnectResponse) XXX_DiscardUnknown() {
xxx_messageInfo_ServerConnectResponse.DiscardUnknown(m)
}
var xxx_messageInfo_ServerConnectResponse proto.InternalMessageInfo
func (m *ServerConnectResponse) GetResult() ConnectResult {
if m != nil {
return m.Result
}
return ConnectResult_CONN_SUCCESS
}
func (m *ServerConnectResponse) GetMessage() string {
if m != nil {
return m.Message
}
return ""
}
func init() {
proto.RegisterEnum("core.AuthResult", AuthResult_name, AuthResult_value)
proto.RegisterEnum("core.ConnectionType", ConnectionType_name, ConnectionType_value)
proto.RegisterEnum("core.ConnectResult", ConnectResult_name, ConnectResult_value)
proto.RegisterType((*Speed)(nil), "core.Speed")
proto.RegisterType((*Credential)(nil), "core.Credential")
proto.RegisterType((*ClientAuthRequest)(nil), "core.ClientAuthRequest")
proto.RegisterType((*ServerAuthResponse)(nil), "core.ServerAuthResponse")
proto.RegisterType((*ClientConnectRequest)(nil), "core.ClientConnectRequest")
proto.RegisterType((*ServerConnectResponse)(nil), "core.ServerConnectResponse")
}
func init() {
proto.RegisterFile("control.proto", fileDescriptor_0c5120591600887d)
}
var fileDescriptor_0c5120591600887d = []byte{
// 431 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xd1, 0x6e, 0xd3, 0x30,
0x14, 0x86, 0xd7, 0xae, 0x5b, 0xb7, 0x13, 0x36, 0x32, 0x6f, 0x13, 0x83, 0x1b, 0x20, 0x57, 0x55,
0x91, 0x2a, 0x34, 0x9e, 0x20, 0x75, 0x82, 0xa8, 0xa8, 0xd2, 0xc9, 0x69, 0xb9, 0xe0, 0x82, 0x2a,
0x4b, 0x8e, 0x58, 0xa5, 0xcc, 0x36, 0xb6, 0x33, 0x34, 0xf1, 0xf2, 0x28, 0x8e, 0x93, 0xae, 0x48,
0x48, 0xbb, 0xeb, 0x39, 0xe7, 0xd7, 0xff, 0xf9, 0xab, 0x02, 0x27, 0xb9, 0xe0, 0x46, 0x89, 0x72,
0x22, 0x95, 0x30, 0x82, 0x0c, 0x72, 0xa1, 0x30, 0xa0, 0x70, 0x90, 0x4a, 0xc4, 0x82, 0xbc, 0x86,
0x23, 0x8d, 0xbc, 0x58, 0xdf, 0x4a, 0x7d, 0xd5, 0x7b, 0xd7, 0x1b, 0x0d, 0xd8, 0xb0, 0x9e, 0xa7,
0x52, 0x93, 0xb7, 0xe0, 0x29, 0xcc, 0x71, 0xf3, 0x80, 0xf6, 0xda, 0xb7, 0x57, 0x70, 0xab, 0xa9,
0xd4, 0x41, 0x04, 0x40, 0x15, 0x16, 0xc8, 0xcd, 0x26, 0x2b, 0xc9, 0x1b, 0x38, 0xaa, 0x34, 0x2a,
0x9e, 0xdd, 0xa3, 0x6d, 0x3a, 0x66, 0xdd, 0x5c, 0xdf, 0x64, 0xa6, 0xf5, 0x6f, 0xa1, 0x0a, 0xdb,
0x73, 0xcc, 0xba, 0x39, 0xb8, 0x83, 0x33, 0x5a, 0x6e, 0x90, 0x9b, 0xb0, 0x32, 0x77, 0x0c, 0x7f,
0x55, 0xa8, 0x0d, 0xf9, 0x08, 0x90, 0x77, 0xd5, 0xb6, 0xce, 0xbb, 0xf6, 0x27, 0xf5, 0xd3, 0x27,
0x5b, 0x24, 0x7b, 0x92, 0x21, 0xef, 0xe1, 0x40, 0xd7, 0x46, 0xb6, 0xdf, 0xbb, 0xf6, 0x9a, 0xb0,
0x95, 0x64, 0xcd, 0x25, 0xf8, 0x03, 0x24, 0x45, 0xf5, 0x80, 0xaa, 0x21, 0x69, 0x29, 0xb8, 0x46,
0x32, 0x82, 0x43, 0x85, 0xba, 0x2a, 0x8d, 0xc5, 0x9c, 0xb6, 0x18, 0x97, 0xa9, 0x4a, 0xc3, 0xdc,
0x9d, 0x5c, 0xc1, 0xf0, 0x1e, 0xb5, 0xce, 0x7e, 0xa2, 0x93, 0x68, 0xc7, 0x2d, 0x7c, 0xff, 0xbf,
0xf0, 0xef, 0x70, 0xd1, 0x68, 0x52, 0xc1, 0x39, 0xe6, 0xa6, 0x35, 0x1d, 0xc1, 0xc0, 0x3c, 0x4a,
0x74, 0xf0, 0x0b, 0xe7, 0xd8, 0x64, 0x36, 0x82, 0x2f, 0x1f, 0x25, 0x32, 0x9b, 0xa8, 0xf1, 0x59,
0x51, 0x28, 0xd4, 0xba, 0xc5, 0xbb, 0x31, 0xf8, 0x01, 0x97, 0x8d, 0x58, 0xd7, 0xed, 0xdc, 0x3e,
0xfc, 0xe3, 0x76, 0xbe, 0x53, 0xff, 0x5c, 0xbd, 0x71, 0x02, 0xb0, 0xfd, 0x3b, 0x88, 0x0f, 0x2f,
0xc2, 0xd5, 0xf2, 0xcb, 0x3a, 0x5d, 0x51, 0x1a, 0xa7, 0xa9, 0xbf, 0x47, 0x2e, 0xe1, 0xcc, 0x6e,
0x66, 0xc9, 0xb7, 0x70, 0x3e, 0x8b, 0xd6, 0x94, 0xc5, 0x91, 0xdf, 0x23, 0xaf, 0xe0, 0xdc, 0xad,
0x97, 0x31, 0x4b, 0xc2, 0xf9, 0x3a, 0x66, 0x6c, 0xc1, 0xfc, 0xfe, 0x38, 0x80, 0xd3, 0x5d, 0x43,
0x32, 0x84, 0xfd, 0x25, 0xbd, 0xf1, 0xf7, 0xea, 0x1f, 0xab, 0xe8, 0xc6, 0xef, 0x8d, 0x23, 0x38,
0xd9, 0x79, 0x66, 0x8d, 0xa5, 0x8b, 0x24, 0x79, 0x82, 0x7d, 0x09, 0x9e, 0xdd, 0x7c, 0x0e, 0x67,
0x73, 0x0b, 0x6c, 0x23, 0xd3, 0xf9, 0x82, 0x7e, 0x8d, 0x23, 0xbf, 0x7f, 0x7b, 0x68, 0x3f, 0xfa,
0x4f, 0x7f, 0x03, 0x00, 0x00, 0xff, 0xff, 0x65, 0xfc, 0xeb, 0x5c, 0x05, 0x03, 0x00, 0x00,
}

View file

@ -0,0 +1,50 @@
syntax = "proto3";
package core;
message Speed {
uint64 send_bps = 1;
uint64 receive_bps = 2;
}
message Credential {
string username = 1;
string password = 2;
}
enum AuthResult {
AUTH_SUCCESS = 0;
AUTH_INVALID_CRED = 1;
AUTH_INTERNAL_ERROR = 2;
}
message ClientAuthRequest {
Credential credential = 1;
Speed speed = 2;
}
message ServerAuthResponse {
AuthResult result = 1;
string message = 2;
Speed speed = 3;
}
enum ConnectionType {
TCP = 0;
UDP = 1;
}
enum ConnectResult {
CONN_SUCCESS = 0;
CONN_FAILED = 1;
CONN_BLOCKED = 2;
}
message ClientConnectRequest {
ConnectionType type = 1;
string address = 2;
}
message ServerConnectResponse {
ConnectResult result = 1;
string message = 2;
}

View file

@ -1,3 +1,3 @@
package forwarder
package core
//go:generate protoc --go_out=. control.proto

199
internal/core/server.go Normal file
View file

@ -0,0 +1,199 @@
package core
import (
"context"
"crypto/tls"
"errors"
"fmt"
"github.com/lucas-clemente/quic-go"
"github.com/tobyxdd/hysteria/internal/utils"
"io"
"net"
"sync/atomic"
)
type ClientAuthFunc func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (AuthResult, string)
type ClientDisconnectedFunc func(addr net.Addr, username string, err error)
type HandleRequestFunc func(addr net.Addr, username string, id int, reqType ConnectionType, reqAddr string) (ConnectResult, string, io.ReadWriteCloser)
type RequestClosedFunc func(addr net.Addr, username string, id int, reqType ConnectionType, reqAddr string, err error)
type Server struct {
inboundBytes, outboundBytes uint64 // atomic
listener quic.Listener
sendBPS, recvBPS uint64
congestionFactory CongestionFactory
clientAuthFunc ClientAuthFunc
clientDisconnectedFunc ClientDisconnectedFunc
handleRequestFunc HandleRequestFunc
requestClosedFunc RequestClosedFunc
}
func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config,
sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory,
clientAuthFunc ClientAuthFunc,
clientDisconnectedFunc ClientDisconnectedFunc,
handleRequestFunc HandleRequestFunc,
requestClosedFunc RequestClosedFunc) (*Server, error) {
listener, err := quic.ListenAddr(addr, tlsConfig, quicConfig)
if err != nil {
return nil, err
}
s := &Server{
listener: listener,
sendBPS: sendBPS,
recvBPS: recvBPS,
congestionFactory: congestionFactory,
clientAuthFunc: clientAuthFunc,
clientDisconnectedFunc: clientDisconnectedFunc,
handleRequestFunc: handleRequestFunc,
requestClosedFunc: requestClosedFunc,
}
return s, nil
}
func (s *Server) Serve() error {
for {
cs, err := s.listener.Accept(context.Background())
if err != nil {
return err
}
go s.handleClient(cs)
}
}
func (s *Server) Stats() (uint64, uint64) {
return atomic.LoadUint64(&s.inboundBytes), atomic.LoadUint64(&s.outboundBytes)
}
func (s *Server) Close() error {
return s.listener.Close()
}
func (s *Server) handleClient(cs quic.Session) {
// Expect the client to create a control stream to send its own information
ctx, ctxCancel := context.WithTimeout(context.Background(), controlStreamTimeout)
ctlStream, err := cs.AcceptStream(ctx)
ctxCancel()
if err != nil {
_ = cs.CloseWithError(closeErrorCodeProtocolFailure, "control stream error")
return
}
// Handle the control stream
username, ok, err := s.handleControlStream(cs, ctlStream)
if err != nil {
_ = cs.CloseWithError(closeErrorCodeProtocolFailure, "control stream handling error")
return
}
if !ok {
_ = cs.CloseWithError(closeErrorCodeGeneric, "authentication failure")
return
}
// Start accepting streams
var closeErr error
for {
stream, err := cs.AcceptStream(context.Background())
if err != nil {
closeErr = err
break
}
go s.handleStream(cs.RemoteAddr(), username, stream)
}
s.clientDisconnectedFunc(cs.RemoteAddr(), username, closeErr)
_ = cs.CloseWithError(closeErrorCodeGeneric, "generic")
}
// Auth & negotiate speed
func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) (string, bool, error) {
req, err := readClientAuthRequest(stream)
if err != nil {
return "", false, err
}
// Speed
if req.Speed == nil || req.Speed.SendBps == 0 || req.Speed.ReceiveBps == 0 {
return "", false, errors.New("incorrect speed provided by the client")
}
serverSendBPS, serverReceiveBPS := req.Speed.ReceiveBps, req.Speed.SendBps
if s.sendBPS > 0 && serverSendBPS > s.sendBPS {
serverSendBPS = s.sendBPS
}
if s.recvBPS > 0 && serverReceiveBPS > s.recvBPS {
serverReceiveBPS = s.recvBPS
}
// Auth
if req.Credential == nil {
return "", false, errors.New("incorrect credential provided by the client")
}
authResult, msg := s.clientAuthFunc(cs.RemoteAddr(), req.Credential.Username, req.Credential.Password,
serverSendBPS, serverReceiveBPS)
// Response
err = writeServerAuthResponse(stream, &ServerAuthResponse{
Result: authResult,
Message: msg,
Speed: &Speed{
SendBps: serverSendBPS,
ReceiveBps: serverReceiveBPS,
},
})
if err != nil {
return "", false, err
}
// Set the congestion accordingly
if authResult == AuthResult_AUTH_SUCCESS && s.congestionFactory != nil {
cs.SetCongestion(s.congestionFactory(serverSendBPS))
}
return req.Credential.Username, authResult == AuthResult_AUTH_SUCCESS, nil
}
func (s *Server) handleStream(addr net.Addr, username string, stream quic.Stream) {
defer stream.Close()
// Read request
req, err := readClientConnectRequest(stream)
if err != nil {
return
}
// Create connection with the handler
result, msg, conn := s.handleRequestFunc(addr, username, int(stream.StreamID()), req.Type, req.Address)
defer func() {
if conn != nil {
_ = conn.Close()
}
}()
// Send response
err = writeServerConnectResponse(stream, &ServerConnectResponse{
Result: result,
Message: msg,
})
if err != nil {
s.requestClosedFunc(addr, username, int(stream.StreamID()), req.Type, req.Address, err)
return
}
if result != ConnectResult_CONN_SUCCESS {
s.requestClosedFunc(addr, username, int(stream.StreamID()), req.Type, req.Address,
fmt.Errorf("handler returned an unsuccessful state %s (msg: %s)", result.String(), msg))
return
}
switch req.Type {
case ConnectionType_TCP:
err = s.pipePair(stream, conn)
case ConnectionType_UDP:
err = s.pipePair(&utils.PacketReadWriteCloser{Orig: stream}, conn)
default:
err = fmt.Errorf("unsupported connection type %s", req.Type.String())
}
s.requestClosedFunc(addr, username, int(stream.StreamID()), req.Type, req.Address, err)
}
func (s *Server) pipePair(rw1, rw2 io.ReadWriter) error {
// Pipes
errChan := make(chan error, 2)
go func() {
errChan <- utils.Pipe(rw2, rw1, &s.outboundBytes)
}()
go func() {
errChan <- utils.Pipe(rw1, rw2, &s.inboundBytes)
}()
// We only need the first error
return <-errChan
}

13
internal/core/types.go Normal file
View file

@ -0,0 +1,13 @@
package core
import (
"encoding/binary"
"github.com/lucas-clemente/quic-go/congestion"
"time"
)
const controlStreamTimeout = 10 * time.Second
var controlProtocolEndian = binary.BigEndian
type CongestionFactory func(refBPS uint64) congestion.SendAlgorithmWithDebugInfos

View file

@ -1,203 +0,0 @@
package forwarder
import (
"context"
"crypto/tls"
"errors"
"github.com/lucas-clemente/quic-go"
"github.com/tobyxdd/hysteria/internal/utils"
"net"
"sync"
"sync/atomic"
)
type QUICClient struct {
inboundBytes, outboundBytes uint64 // atomic
reconnectMutex sync.Mutex
quicSession quic.Session
listener net.Listener
remoteAddr string
name string
tlsConfig *tls.Config
sendBPS, recvBPS uint64
recvWindowConn, recvWindow uint64
closed bool
newCongestion CongestionFactory
onServerConnected ServerConnectedCallback
onServerError ServerErrorCallback
onNewTCPConnection NewTCPConnectionCallback
onTCPConnectionClosed TCPConnectionClosedCallback
}
func NewQUICClient(addr string, remoteAddr string, name string, tlsConfig *tls.Config,
sendBPS uint64, recvBPS uint64, recvWindowConn uint64, recvWindow uint64,
newCongestion CongestionFactory,
onServerConnected ServerConnectedCallback,
onServerError ServerErrorCallback,
onNewTCPConnection NewTCPConnectionCallback,
onTCPConnectionClosed TCPConnectionClosedCallback) (*QUICClient, error) {
// Local TCP listener
listener, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
c := &QUICClient{
listener: listener,
remoteAddr: remoteAddr,
name: name,
tlsConfig: tlsConfig,
sendBPS: sendBPS,
recvBPS: recvBPS,
recvWindowConn: recvWindowConn,
recvWindow: recvWindow,
newCongestion: newCongestion,
onServerConnected: onServerConnected,
onServerError: onServerError,
onNewTCPConnection: onNewTCPConnection,
onTCPConnectionClosed: onTCPConnectionClosed,
}
if err := c.connectToServer(); err != nil {
_ = c.listener.Close()
return nil, err
}
go c.acceptLoop()
return c, nil
}
func (c *QUICClient) Close() error {
err1 := c.listener.Close()
c.reconnectMutex.Lock()
err2 := c.quicSession.CloseWithError(closeErrorCodeGeneric, "generic")
c.closed = true
c.reconnectMutex.Unlock()
if err1 != nil {
return err1
}
return err2
}
func (c *QUICClient) Stats() (string, uint64, uint64) {
return c.remoteAddr, atomic.LoadUint64(&c.inboundBytes), atomic.LoadUint64(&c.outboundBytes)
}
func (c *QUICClient) acceptLoop() {
for {
conn, err := c.listener.Accept()
if err != nil {
break
}
go c.handleConn(conn)
}
}
func (c *QUICClient) connectToServer() error {
qs, err := quic.DialAddr(c.remoteAddr, c.tlsConfig, &quic.Config{
MaxReceiveStreamFlowControlWindow: c.recvWindowConn,
MaxReceiveConnectionFlowControlWindow: c.recvWindow,
KeepAlive: true,
})
if err != nil {
c.onServerError(err)
return err
}
// Control stream
ctx, ctxCancel := context.WithTimeout(context.Background(), controlStreamTimeout)
ctlStream, err := qs.OpenStreamSync(ctx)
ctxCancel()
if err != nil {
_ = qs.CloseWithError(closeErrorCodeProtocolFailure, "control stream error")
c.onServerError(err)
return err
}
banner, cSendBPS, cRecvBPS, err := handleControlStream(qs, ctlStream, c.name, c.sendBPS, c.recvBPS, c.newCongestion)
if err != nil {
_ = qs.CloseWithError(closeErrorCodeProtocolFailure, "control stream handling error")
c.onServerError(err)
return err
}
// All good
c.quicSession = qs
c.onServerConnected(qs.RemoteAddr(), banner, cSendBPS, cRecvBPS)
return nil
}
func (c *QUICClient) openStreamWithReconnect() (quic.Stream, error) {
c.reconnectMutex.Lock()
defer c.reconnectMutex.Unlock()
if c.closed {
return nil, errors.New("client closed")
}
stream, err := c.quicSession.OpenStream()
if err == nil {
// All good
return stream, nil
}
// Something is wrong
c.onServerError(err)
if nErr, ok := err.(net.Error); ok && nErr.Temporary() {
// Temporary error, just return
return nil, err
}
// Permanent error, need to reconnect
if err := c.connectToServer(); err != nil {
// Still error, oops
return nil, err
}
// We are not going to try again even if it still fails the second time
stream, err = c.quicSession.OpenStream()
if err != nil {
c.onServerError(err)
}
return stream, err
}
// Negotiate speed, return banner, send & receive speed
func handleControlStream(qs quic.Session, stream quic.Stream, name string, sendBPS uint64, recvBPS uint64,
newCongestion CongestionFactory) (string, uint64, uint64, error) {
err := writeClientSpeedRequest(stream, &ClientSpeedRequest{
Name: name,
Speed: &Speed{
SendBps: sendBPS,
ReceiveBps: recvBPS,
},
})
if err != nil {
return "", 0, 0, err
}
// Response
resp, err := readServerSpeedResponse(stream)
if err != nil {
return "", 0, 0, err
}
// Set the congestion accordingly
if newCongestion != nil {
qs.SetCongestion(newCongestion(resp.Speed.ReceiveBps))
}
return resp.Banner, resp.Speed.ReceiveBps, resp.Speed.SendBps, nil
}
func (c *QUICClient) handleConn(conn net.Conn) {
c.onNewTCPConnection(conn.RemoteAddr())
defer conn.Close()
stream, err := c.openStreamWithReconnect()
if err != nil {
c.onTCPConnectionClosed(conn.RemoteAddr(), err)
return
}
defer stream.Close()
// Pipes
errChan := make(chan error, 2)
go func() {
// TCP to QUIC
errChan <- utils.Pipe(conn, stream, &c.outboundBytes)
}()
go func() {
// QUIC to TCP
errChan <- utils.Pipe(stream, conn, &c.inboundBytes)
}()
// We only need the first error
err = <-errChan
c.onTCPConnectionClosed(conn.RemoteAddr(), err)
}

View file

@ -1,67 +0,0 @@
package forwarder
import (
"encoding/binary"
"github.com/golang/protobuf/proto"
"io"
)
const (
closeErrorCodeGeneric = 0
closeErrorCodeProtocolFailure = 1
)
func readDataBlock(r io.Reader) ([]byte, error) {
var sz uint32
if err := binary.Read(r, controlProtocolEndian, &sz); err != nil {
return nil, err
}
buf := make([]byte, sz)
_, err := io.ReadFull(r, buf)
return buf, err
}
func writeDataBlock(w io.Writer, data []byte) error {
sz := uint32(len(data))
if err := binary.Write(w, controlProtocolEndian, &sz); err != nil {
return err
}
_, err := w.Write(data)
return err
}
func readClientSpeedRequest(r io.Reader) (*ClientSpeedRequest, error) {
bs, err := readDataBlock(r)
if err != nil {
return nil, err
}
var req ClientSpeedRequest
err = proto.Unmarshal(bs, &req)
return &req, err
}
func writeClientSpeedRequest(w io.Writer, req *ClientSpeedRequest) error {
bs, err := proto.Marshal(req)
if err != nil {
return err
}
return writeDataBlock(w, bs)
}
func readServerSpeedResponse(r io.Reader) (*ServerSpeedResponse, error) {
bs, err := readDataBlock(r)
if err != nil {
return nil, err
}
var resp ServerSpeedResponse
err = proto.Unmarshal(bs, &resp)
return &resp, err
}
func writeServerSpeedResponse(w io.Writer, resp *ServerSpeedResponse) error {
bs, err := proto.Marshal(resp)
if err != nil {
return err
}
return writeDataBlock(w, bs)
}

View file

@ -1,206 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: control.proto
package forwarder
import (
fmt "fmt"
proto "github.com/golang/protobuf/proto"
math "math"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
type Speed struct {
SendBps uint64 `protobuf:"varint,1,opt,name=send_bps,json=sendBps,proto3" json:"send_bps,omitempty"`
ReceiveBps uint64 `protobuf:"varint,2,opt,name=receive_bps,json=receiveBps,proto3" json:"receive_bps,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Speed) Reset() { *m = Speed{} }
func (m *Speed) String() string { return proto.CompactTextString(m) }
func (*Speed) ProtoMessage() {}
func (*Speed) Descriptor() ([]byte, []int) {
return fileDescriptor_0c5120591600887d, []int{0}
}
func (m *Speed) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Speed.Unmarshal(m, b)
}
func (m *Speed) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Speed.Marshal(b, m, deterministic)
}
func (m *Speed) XXX_Merge(src proto.Message) {
xxx_messageInfo_Speed.Merge(m, src)
}
func (m *Speed) XXX_Size() int {
return xxx_messageInfo_Speed.Size(m)
}
func (m *Speed) XXX_DiscardUnknown() {
xxx_messageInfo_Speed.DiscardUnknown(m)
}
var xxx_messageInfo_Speed proto.InternalMessageInfo
func (m *Speed) GetSendBps() uint64 {
if m != nil {
return m.SendBps
}
return 0
}
func (m *Speed) GetReceiveBps() uint64 {
if m != nil {
return m.ReceiveBps
}
return 0
}
type ClientSpeedRequest struct {
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
Speed *Speed `protobuf:"bytes,2,opt,name=speed,proto3" json:"speed,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ClientSpeedRequest) Reset() { *m = ClientSpeedRequest{} }
func (m *ClientSpeedRequest) String() string { return proto.CompactTextString(m) }
func (*ClientSpeedRequest) ProtoMessage() {}
func (*ClientSpeedRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_0c5120591600887d, []int{1}
}
func (m *ClientSpeedRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ClientSpeedRequest.Unmarshal(m, b)
}
func (m *ClientSpeedRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ClientSpeedRequest.Marshal(b, m, deterministic)
}
func (m *ClientSpeedRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_ClientSpeedRequest.Merge(m, src)
}
func (m *ClientSpeedRequest) XXX_Size() int {
return xxx_messageInfo_ClientSpeedRequest.Size(m)
}
func (m *ClientSpeedRequest) XXX_DiscardUnknown() {
xxx_messageInfo_ClientSpeedRequest.DiscardUnknown(m)
}
var xxx_messageInfo_ClientSpeedRequest proto.InternalMessageInfo
func (m *ClientSpeedRequest) GetName() string {
if m != nil {
return m.Name
}
return ""
}
func (m *ClientSpeedRequest) GetSpeed() *Speed {
if m != nil {
return m.Speed
}
return nil
}
type ServerSpeedResponse struct {
Banner string `protobuf:"bytes,1,opt,name=banner,proto3" json:"banner,omitempty"`
Limited bool `protobuf:"varint,2,opt,name=limited,proto3" json:"limited,omitempty"`
Limit *Speed `protobuf:"bytes,3,opt,name=limit,proto3" json:"limit,omitempty"`
Speed *Speed `protobuf:"bytes,4,opt,name=speed,proto3" json:"speed,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ServerSpeedResponse) Reset() { *m = ServerSpeedResponse{} }
func (m *ServerSpeedResponse) String() string { return proto.CompactTextString(m) }
func (*ServerSpeedResponse) ProtoMessage() {}
func (*ServerSpeedResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_0c5120591600887d, []int{2}
}
func (m *ServerSpeedResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ServerSpeedResponse.Unmarshal(m, b)
}
func (m *ServerSpeedResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ServerSpeedResponse.Marshal(b, m, deterministic)
}
func (m *ServerSpeedResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_ServerSpeedResponse.Merge(m, src)
}
func (m *ServerSpeedResponse) XXX_Size() int {
return xxx_messageInfo_ServerSpeedResponse.Size(m)
}
func (m *ServerSpeedResponse) XXX_DiscardUnknown() {
xxx_messageInfo_ServerSpeedResponse.DiscardUnknown(m)
}
var xxx_messageInfo_ServerSpeedResponse proto.InternalMessageInfo
func (m *ServerSpeedResponse) GetBanner() string {
if m != nil {
return m.Banner
}
return ""
}
func (m *ServerSpeedResponse) GetLimited() bool {
if m != nil {
return m.Limited
}
return false
}
func (m *ServerSpeedResponse) GetLimit() *Speed {
if m != nil {
return m.Limit
}
return nil
}
func (m *ServerSpeedResponse) GetSpeed() *Speed {
if m != nil {
return m.Speed
}
return nil
}
func init() {
proto.RegisterType((*Speed)(nil), "forwarder.Speed")
proto.RegisterType((*ClientSpeedRequest)(nil), "forwarder.ClientSpeedRequest")
proto.RegisterType((*ServerSpeedResponse)(nil), "forwarder.ServerSpeedResponse")
}
func init() {
proto.RegisterFile("control.proto", fileDescriptor_0c5120591600887d)
}
var fileDescriptor_0c5120591600887d = []byte{
// 220 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x7c, 0x90, 0x4d, 0x4a, 0xc6, 0x30,
0x10, 0x86, 0xa9, 0xf6, 0xfb, 0x9b, 0x0f, 0x41, 0x46, 0x90, 0xba, 0x52, 0xba, 0x10, 0x57, 0x5d,
0xe8, 0x0d, 0xbe, 0x5e, 0x40, 0xd2, 0x03, 0x48, 0x7f, 0x46, 0x08, 0xb4, 0x49, 0x9c, 0x89, 0xf5,
0x28, 0x5e, 0x57, 0x3a, 0x8d, 0xba, 0xd2, 0xdd, 0xbc, 0x79, 0x92, 0x67, 0x5e, 0x02, 0x17, 0xbd,
0x77, 0x91, 0xfd, 0x58, 0x05, 0xf6, 0xd1, 0xe3, 0xe1, 0xd5, 0xf3, 0x47, 0xcb, 0x03, 0x71, 0x59,
0xc3, 0xa6, 0x09, 0x44, 0x03, 0xde, 0xc0, 0x5e, 0xc8, 0x0d, 0x2f, 0x5d, 0x90, 0x22, 0xbb, 0xcb,
0x1e, 0x72, 0xb3, 0x5b, 0xf2, 0x29, 0x08, 0xde, 0xc2, 0x91, 0xa9, 0x27, 0x3b, 0x93, 0xd2, 0x33,
0xa5, 0x90, 0x8e, 0x4e, 0x41, 0xca, 0x67, 0xc0, 0x7a, 0xb4, 0xe4, 0xa2, 0xaa, 0x0c, 0xbd, 0xbd,
0x93, 0x44, 0x44, 0xc8, 0x5d, 0x3b, 0x91, 0xda, 0x0e, 0x46, 0x67, 0xbc, 0x87, 0x8d, 0x2c, 0x77,
0x54, 0x72, 0x7c, 0xbc, 0xac, 0x7e, 0x9a, 0x54, 0xeb, 0xdb, 0x15, 0x97, 0x9f, 0x19, 0x5c, 0x35,
0xc4, 0x33, 0x71, 0x52, 0x4a, 0xf0, 0x4e, 0x08, 0xaf, 0x61, 0xdb, 0xb5, 0xce, 0x11, 0x27, 0x6b,
0x4a, 0x58, 0xc0, 0x6e, 0xb4, 0x93, 0x8d, 0xc9, 0xbc, 0x37, 0xdf, 0x71, 0xd9, 0xa8, 0x63, 0x71,
0xfe, 0xd7, 0x46, 0xc5, 0xbf, 0xcd, 0xf2, 0x7f, 0x9b, 0x75, 0x5b, 0xfd, 0xc2, 0xa7, 0xaf, 0x00,
0x00, 0x00, 0xff, 0xff, 0xb2, 0x10, 0x5a, 0xf2, 0x53, 0x01, 0x00, 0x00,
}

View file

@ -1,19 +0,0 @@
syntax = "proto3";
package forwarder;
message Speed {
uint64 send_bps = 1;
uint64 receive_bps = 2;
}
message ClientSpeedRequest {
string name = 1;
Speed speed = 2;
}
message ServerSpeedResponse {
string banner = 1;
bool limited = 2;
Speed limit = 3;
Speed speed = 4;
}

View file

@ -1,10 +0,0 @@
package forwarder
import (
"encoding/binary"
"time"
)
const controlStreamTimeout = 10 * time.Second
var controlProtocolEndian = binary.BigEndian

View file

@ -1,176 +0,0 @@
package forwarder
import (
"context"
"crypto/tls"
"errors"
"github.com/lucas-clemente/quic-go"
"github.com/tobyxdd/hysteria/internal/utils"
"net"
"sync/atomic"
)
type QUICServer struct {
inboundBytes, outboundBytes uint64 // atomic
listener quic.Listener
remoteAddr string
banner string
sendBPS, recvBPS uint64
newCongestion CongestionFactory
onClientConnected ClientConnectedCallback
onClientDisconnected ClientDisconnectedCallback
onClientNewStream ClientNewStreamCallback
onClientStreamClosed ClientStreamClosedCallback
onTCPError TCPErrorCallback
}
func NewQUICServer(addr string, remoteAddr string, banner string, tlsConfig *tls.Config,
sendBPS uint64, recvBPS uint64, recvWindowConn uint64, recvWindowClients uint64,
clientMaxConn int, newCongestion CongestionFactory,
onClientConnected ClientConnectedCallback,
onClientDisconnected ClientDisconnectedCallback,
onClientNewStream ClientNewStreamCallback,
onClientStreamClosed ClientStreamClosedCallback,
onTCPError TCPErrorCallback) (*QUICServer, error) {
listener, err := quic.ListenAddr(addr, tlsConfig, &quic.Config{
MaxReceiveStreamFlowControlWindow: recvWindowConn,
MaxReceiveConnectionFlowControlWindow: recvWindowClients,
MaxIncomingStreams: clientMaxConn,
KeepAlive: true,
})
if err != nil {
return nil, err
}
s := &QUICServer{
listener: listener,
remoteAddr: remoteAddr,
banner: banner,
sendBPS: sendBPS,
recvBPS: recvBPS,
newCongestion: newCongestion,
onClientConnected: onClientConnected,
onClientDisconnected: onClientDisconnected,
onClientNewStream: onClientNewStream,
onClientStreamClosed: onClientStreamClosed,
onTCPError: onTCPError,
}
go s.acceptLoop()
return s, nil
}
func (s *QUICServer) Close() error {
return s.listener.Close()
}
func (s *QUICServer) Stats() (string, uint64, uint64) {
return s.remoteAddr, atomic.LoadUint64(&s.inboundBytes), atomic.LoadUint64(&s.outboundBytes)
}
func (s *QUICServer) acceptLoop() {
for {
cs, err := s.listener.Accept(context.Background())
if err != nil {
break
}
go s.handleClient(cs)
}
}
func (s *QUICServer) handleClient(cs quic.Session) {
// Expect the client to create a control stream and send its own information
ctx, ctxCancel := context.WithTimeout(context.Background(), controlStreamTimeout)
ctlStream, err := cs.AcceptStream(ctx)
ctxCancel()
if err != nil {
_ = cs.CloseWithError(closeErrorCodeProtocolFailure, "control stream error")
return
}
name, sSend, sRecv, err := s.handleControlStream(cs, ctlStream)
if err != nil {
_ = cs.CloseWithError(closeErrorCodeProtocolFailure, "control stream handling error")
return
}
// Only after a successful exchange of information do we consider this a valid client
s.onClientConnected(cs.RemoteAddr(), name, sSend, sRecv)
// Start accepting streams to be forwarded
var closeErr error
for {
stream, err := cs.AcceptStream(context.Background())
if err != nil {
closeErr = err
break
}
go s.handleStream(cs.RemoteAddr(), name, stream)
}
s.onClientDisconnected(cs.RemoteAddr(), name, closeErr)
_ = cs.CloseWithError(closeErrorCodeGeneric, "generic")
}
// Negotiate speed & return client name
func (s *QUICServer) handleControlStream(cs quic.Session, stream quic.Stream) (string, uint64, uint64, error) {
req, err := readClientSpeedRequest(stream)
if err != nil {
return "", 0, 0, err
}
if req.Speed == nil || req.Speed.SendBps == 0 || req.Speed.ReceiveBps == 0 {
return "", 0, 0, errors.New("incorrect speed information provided by the client")
}
limited := false
serverSendBPS, serverReceiveBPS := req.Speed.ReceiveBps, req.Speed.SendBps
if s.sendBPS > 0 && serverSendBPS > s.sendBPS {
limited = true
serverSendBPS = s.sendBPS
}
if s.recvBPS > 0 && serverReceiveBPS > s.recvBPS {
limited = true
serverReceiveBPS = s.recvBPS
}
// Response
err = writeServerSpeedResponse(stream, &ServerSpeedResponse{
Banner: s.banner,
Limited: limited,
Limit: &Speed{
SendBps: s.sendBPS,
ReceiveBps: s.recvBPS,
},
Speed: &Speed{
SendBps: serverSendBPS,
ReceiveBps: serverReceiveBPS,
},
})
if err != nil {
return "", 0, 0, err
}
// Set the congestion accordingly
if s.newCongestion != nil {
cs.SetCongestion(s.newCongestion(serverSendBPS))
}
return req.Name, serverSendBPS, serverReceiveBPS, nil
}
func (s *QUICServer) handleStream(addr net.Addr, name string, stream quic.Stream) {
s.onClientNewStream(addr, name, int(stream.StreamID()))
defer stream.Close()
tcpConn, err := net.Dial("tcp", s.remoteAddr)
if err != nil {
s.onTCPError(s.remoteAddr, err)
s.onClientStreamClosed(addr, name, int(stream.StreamID()), err)
return
}
defer tcpConn.Close()
// Pipes
errChan := make(chan error, 2)
go func() {
// TCP to QUIC
errChan <- utils.Pipe(tcpConn, stream, &s.outboundBytes)
}()
go func() {
// QUIC to TCP
errChan <- utils.Pipe(stream, tcpConn, &s.inboundBytes)
}()
// We only need the first error
err = <-errChan
s.onClientStreamClosed(addr, name, int(stream.StreamID()), err)
}

View file

@ -1,21 +0,0 @@
package forwarder
import (
"github.com/lucas-clemente/quic-go/congestion"
"net"
)
type CongestionFactory func(refBPS uint64) congestion.SendAlgorithmWithDebugInfos
// For server
type ClientConnectedCallback func(addr net.Addr, name string, sSend uint64, sRecv uint64)
type ClientDisconnectedCallback func(addr net.Addr, name string, err error)
type ClientNewStreamCallback func(addr net.Addr, name string, id int)
type ClientStreamClosedCallback func(addr net.Addr, name string, id int, err error)
type TCPErrorCallback func(remoteAddr string, err error)
// For client
type ServerConnectedCallback func(addr net.Addr, banner string, cSend uint64, cRecv uint64)
type ServerErrorCallback func(err error)
type NewTCPConnectionCallback func(addr net.Addr)
type TCPConnectionClosedCallback func(addr net.Addr, err error)

View file

@ -0,0 +1,35 @@
package utils
import (
"encoding/binary"
"fmt"
"io"
)
type PacketReadWriteCloser struct {
Orig io.ReadWriteCloser
}
func (rw *PacketReadWriteCloser) Read(p []byte) (n int, err error) {
var sz uint32
if err := binary.Read(rw.Orig, binary.BigEndian, &sz); err != nil {
return 0, err
}
if int(sz) <= len(p) {
return io.ReadFull(rw.Orig, p[:sz])
} else {
return 0, fmt.Errorf("the buffer is too small to hold %d bytes of packet data", sz)
}
}
func (rw *PacketReadWriteCloser) Write(p []byte) (n int, err error) {
sz := uint32(len(p))
if err := binary.Write(rw.Orig, binary.BigEndian, &sz); err != nil {
return 0, err
}
return rw.Orig.Write(p)
}
func (rw *PacketReadWriteCloser) Close() error {
return rw.Orig.Close()
}

View file

@ -5,7 +5,7 @@ import (
"sync/atomic"
)
const pipeBufferSize = 16384
const pipeBufferSize = 65536
func Pipe(src, dst io.ReadWriter, atomicCounter *uint64) error {
buf := make([]byte, pipeBufferSize)
@ -13,7 +13,9 @@ func Pipe(src, dst io.ReadWriter, atomicCounter *uint64) error {
rn, err := src.Read(buf)
if rn > 0 {
wn, err := dst.Write(buf[:rn])
atomic.AddUint64(atomicCounter, uint64(wn))
if atomicCounter != nil {
atomic.AddUint64(atomicCounter, uint64(wn))
}
if err != nil {
return err
}

71
pkg/core/interface.go Normal file
View file

@ -0,0 +1,71 @@
package core
import (
"crypto/tls"
"github.com/lucas-clemente/quic-go"
"github.com/tobyxdd/hysteria/internal/core"
"io"
"net"
)
type AuthResult int32
const (
AuthSuccess = AuthResult(iota)
AuthInvalidCred
AuthInternalError
)
type ConnectResult int32
const (
ConnSuccess = ConnectResult(iota)
ConnFailed
ConnBlocked
)
type CongestionFactory core.CongestionFactory
type ClientAuthFunc func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (AuthResult, string)
type ClientDisconnectedFunc core.ClientDisconnectedFunc
type HandleRequestFunc func(addr net.Addr, username string, id int, isUDP bool, reqAddr string) (ConnectResult, string, io.ReadWriteCloser)
type RequestClosedFunc func(addr net.Addr, username string, id int, isUDP bool, reqAddr string, err error)
type Server interface {
Serve() error
Stats() (inbound uint64, outbound uint64)
Close() error
}
func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config,
sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory,
clientAuthFunc ClientAuthFunc,
clientDisconnectedFunc ClientDisconnectedFunc,
handleRequestFunc HandleRequestFunc,
requestClosedFunc RequestClosedFunc) (Server, error) {
return core.NewServer(addr, tlsConfig, quicConfig, sendBPS, recvBPS, core.CongestionFactory(congestionFactory),
func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (core.AuthResult, string) {
r, msg := clientAuthFunc(addr, username, password, sSend, sRecv)
return core.AuthResult(r), msg
},
core.ClientDisconnectedFunc(clientDisconnectedFunc),
func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string) (core.ConnectResult, string, io.ReadWriteCloser) {
r, msg, conn := handleRequestFunc(addr, username, id, reqType == core.ConnectionType_UDP, reqAddr)
return core.ConnectResult(r), msg, conn
},
func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string, err error) {
requestClosedFunc(addr, username, id, reqType == core.ConnectionType_UDP, reqAddr, err)
})
}
type Client interface {
Dial(udp bool, addr string) (io.ReadWriteCloser, error)
Stats() (inbound uint64, outbound uint64)
Close() error
}
func NewClient(serverAddr string, username string, password string,
tlsConfig *tls.Config, quicConfig *quic.Config, sendBPS uint64, recvBPS uint64,
congestionFactory CongestionFactory) (Client, error) {
return core.NewClient(serverAddr, username, password, tlsConfig, quicConfig, sendBPS, recvBPS,
core.CongestionFactory(congestionFactory))
}

View file

@ -1,70 +0,0 @@
package forwarder
import (
"crypto/tls"
"errors"
"github.com/tobyxdd/hysteria/internal/forwarder"
"net"
)
type client struct {
qc *forwarder.QUICClient
}
func NewClient(localAddr string, remoteAddr string, config ClientConfig, callbacks ClientCallbacks) (Client, error) {
// Fix config first
if config.Speed == nil || config.Speed.SendBPS == 0 || config.Speed.ReceiveBPS == 0 {
return nil, errors.New("invalid speed")
}
if config.TLSConfig == nil {
config.TLSConfig = &tls.Config{NextProtos: []string{TLSAppProtocol}}
}
if config.MaxReceiveWindowPerConnection == 0 {
config.MaxReceiveWindowPerConnection = defaultReceiveWindowConn
}
if config.MaxReceiveWindow == 0 {
config.MaxReceiveWindow = defaultReceiveWindow
}
qc, err := forwarder.NewQUICClient(localAddr, remoteAddr, config.Name, config.TLSConfig,
config.Speed.SendBPS, config.Speed.ReceiveBPS,
config.MaxReceiveWindowPerConnection, config.MaxReceiveWindow,
forwarder.CongestionFactory(config.CongestionFactory),
func(addr net.Addr, banner string, cSend uint64, cRecv uint64) {
if callbacks.ServerConnectedCallback != nil {
callbacks.ServerConnectedCallback(addr, banner, cSend, cRecv)
}
},
func(err error) {
if callbacks.ServerErrorCallback != nil {
callbacks.ServerErrorCallback(err)
}
},
func(addr net.Addr) {
if callbacks.NewTCPConnectionCallback != nil {
callbacks.NewTCPConnectionCallback(addr)
}
},
func(addr net.Addr, err error) {
if callbacks.TCPConnectionClosedCallback != nil {
callbacks.TCPConnectionClosedCallback(addr, err)
}
},
)
if err != nil {
return nil, err
}
return &client{qc: qc}, nil
}
func (c *client) Stats() Stats {
addr, in, out := c.qc.Stats()
return Stats{
RemoteAddr: addr,
inboundBytes: in,
outboundBytes: out,
}
}
func (c *client) Close() error {
return c.Close()
}

View file

@ -1,89 +0,0 @@
package forwarder
import (
"crypto/tls"
"github.com/tobyxdd/hysteria/internal/forwarder"
"net"
)
type CongestionFactory forwarder.CongestionFactory
// A server can support multiple forwarding entries (listenAddr/remoteAddr pairs)
type Server interface {
Add(listenAddr, remoteAddr string) error
Remove(listenAddr string) error
Stats() map[string]Stats
}
// An empty ServerConfig is a valid one
type ServerConfig struct {
// A banner message that will be sent to the client after the connection is established.
// No message if not set.
BannerMessage string
// TLSConfig is used to configure the TLS server.
// Use an insecure self-signed certificate if not set.
TLSConfig *tls.Config
// MaxSpeedPerClient is the maximum allowed sending and receiving speed for each client.
// Sending speed will never exceed this limit, even if a client demands a larger value.
// No restrictions if not set.
MaxSpeedPerClient *Speed
// Corresponds to MaxReceiveStreamFlowControlWindow in QUIC.
MaxReceiveWindowPerConnection uint64
// Corresponds to MaxReceiveConnectionFlowControlWindow in QUIC.
MaxReceiveWindowPerClient uint64
// Max number of simultaneous connections allowed for a client
MaxConnectionPerClient int
// Congestion factory
CongestionFactory CongestionFactory
}
type ServerCallbacks struct {
ClientConnectedCallback func(listenAddr string, clientAddr net.Addr, name string, sSend uint64, sRecv uint64)
ClientDisconnectedCallback func(listenAddr string, clientAddr net.Addr, name string, err error)
ClientNewStreamCallback func(listenAddr string, clientAddr net.Addr, name string, id int)
ClientStreamClosedCallback func(listenAddr string, clientAddr net.Addr, name string, id int, err error)
TCPErrorCallback func(listenAddr string, remoteAddr string, err error)
}
// A client supports one forwarding entry
type Client interface {
Stats() Stats
Close() error
}
// An empty ClientConfig is NOT a valid one, as Speed must be set
type ClientConfig struct {
// A client can report its name to the server after the connection is established.
// No name if not set.
Name string
// TLSConfig is used to configure the TLS client.
// Use default settings if not set.
TLSConfig *tls.Config
// Speed reported by the client when negotiating with the server.
// The actual speed will also depend on the configuration of the server.
Speed *Speed
// Corresponds to MaxReceiveStreamFlowControlWindow in QUIC.
MaxReceiveWindowPerConnection uint64
// Corresponds to MaxReceiveConnectionFlowControlWindow in QUIC.
MaxReceiveWindow uint64
// Congestion factory
CongestionFactory CongestionFactory
}
type ClientCallbacks struct {
ServerConnectedCallback func(addr net.Addr, banner string, cSend uint64, cRecv uint64)
ServerErrorCallback func(err error)
NewTCPConnectionCallback func(addr net.Addr)
TCPConnectionClosedCallback func(addr net.Addr, err error)
}
type Speed struct {
SendBPS uint64
ReceiveBPS uint64
}
type Stats struct {
RemoteAddr string
inboundBytes uint64
outboundBytes uint64
}

View file

@ -1,9 +0,0 @@
package forwarder
const (
TLSAppProtocol = "hysteria-forwarder"
defaultReceiveWindowConn = 33554432
defaultReceiveWindow = 67108864
defaultMaxClientConn = 100
)

View file

@ -1,119 +0,0 @@
package forwarder
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"github.com/tobyxdd/hysteria/internal/forwarder"
"math/big"
"net"
)
type server struct {
config ServerConfig
callbacks ServerCallbacks
entries map[string]*forwarder.QUICServer
}
func NewServer(config ServerConfig, callbacks ServerCallbacks) Server {
// Fix config first
if config.TLSConfig == nil {
config.TLSConfig = generateInsecureTLSConfig()
}
if config.MaxSpeedPerClient == nil {
config.MaxSpeedPerClient = &Speed{0, 0}
}
if config.MaxReceiveWindowPerConnection == 0 {
config.MaxReceiveWindowPerConnection = defaultReceiveWindowConn
}
if config.MaxReceiveWindowPerClient == 0 {
config.MaxReceiveWindowPerClient = defaultReceiveWindow
}
if config.MaxConnectionPerClient <= 0 {
config.MaxConnectionPerClient = defaultMaxClientConn
}
return &server{config: config, callbacks: callbacks, entries: make(map[string]*forwarder.QUICServer)}
}
func (s *server) Add(listenAddr, remoteAddr string) error {
qs, err := forwarder.NewQUICServer(listenAddr, remoteAddr, s.config.BannerMessage, s.config.TLSConfig,
s.config.MaxSpeedPerClient.SendBPS, s.config.MaxSpeedPerClient.ReceiveBPS,
s.config.MaxReceiveWindowPerConnection, s.config.MaxReceiveWindowPerClient,
s.config.MaxConnectionPerClient, forwarder.CongestionFactory(s.config.CongestionFactory),
func(addr net.Addr, name string, sSend uint64, sRecv uint64) {
if s.callbacks.ClientConnectedCallback != nil {
s.callbacks.ClientConnectedCallback(listenAddr, addr, name, sSend, sRecv)
}
},
func(addr net.Addr, name string, err error) {
if s.callbacks.ClientDisconnectedCallback != nil {
s.callbacks.ClientDisconnectedCallback(listenAddr, addr, name, err)
}
},
func(addr net.Addr, name string, id int) {
if s.callbacks.ClientNewStreamCallback != nil {
s.callbacks.ClientNewStreamCallback(listenAddr, addr, name, id)
}
},
func(addr net.Addr, name string, id int, err error) {
if s.callbacks.ClientStreamClosedCallback != nil {
s.callbacks.ClientStreamClosedCallback(listenAddr, addr, name, id, err)
}
},
func(remoteAddr string, err error) {
if s.callbacks.TCPErrorCallback != nil {
s.callbacks.TCPErrorCallback(listenAddr, remoteAddr, err)
}
},
)
if err != nil {
return err
}
s.entries[listenAddr] = qs
return nil
}
func (s *server) Remove(listenAddr string) error {
defer delete(s.entries, listenAddr)
if qs, ok := s.entries[listenAddr]; ok && qs != nil {
return qs.Close()
}
return nil
}
func (s *server) Stats() map[string]Stats {
r := make(map[string]Stats, len(s.entries))
for laddr, sv := range s.entries {
addr, in, out := sv.Stats()
r[laddr] = Stats{
RemoteAddr: addr,
inboundBytes: in,
outboundBytes: out,
}
}
return r
}
func generateInsecureTLSConfig() *tls.Config {
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
panic(err)
}
template := x509.Certificate{SerialNumber: big.NewInt(1)}
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
if err != nil {
panic(err)
}
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
panic(err)
}
return &tls.Config{
Certificates: []tls.Certificate{tlsCert},
NextProtos: []string{TLSAppProtocol},
}
}