hysteria/app/cmd/server.go
2023-05-25 20:24:24 -07:00

275 lines
8.3 KiB
Go

package cmd
import (
"context"
"crypto/tls"
"errors"
"net"
"strings"
"github.com/apernet/hysteria/core/server"
"github.com/apernet/hysteria/extras/auth"
"github.com/caddyserver/certmagic"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"go.uber.org/zap"
)
var serverCmd = &cobra.Command{
Use: "server",
Short: "Server mode",
Run: runServer,
}
func init() {
rootCmd.AddCommand(serverCmd)
initServerConfigDefaults()
}
func initServerConfigDefaults() {
viper.SetDefault("listen", ":443")
}
func runServer(cmd *cobra.Command, args []string) {
logger.Info("server mode")
if err := viper.ReadInConfig(); err != nil {
logger.Fatal("failed to read server config", zap.Error(err))
}
config, err := viperToServerConfig()
if err != nil {
logger.Fatal("failed to parse server config", zap.Error(err))
}
s, err := server.NewServer(config)
if err != nil {
logger.Fatal("failed to initialize server", zap.Error(err))
}
logger.Info("server up and running")
if err := s.Serve(); err != nil {
logger.Fatal("failed to serve", zap.Error(err))
}
}
func viperToServerConfig() (*server.Config, error) {
// Conn
conn, err := viperToServerConn()
if err != nil {
return nil, err
}
// TLS
tlsConfig, err := viperToServerTLSConfig()
if err != nil {
return nil, err
}
// QUIC
quicConfig := viperToServerQUICConfig()
// Bandwidth
bwConfig, err := viperToServerBandwidthConfig()
if err != nil {
return nil, err
}
// Disable UDP
disableUDP := viper.GetBool("disableUDP")
// Authenticator
authenticator, err := viperToAuthenticator()
if err != nil {
return nil, err
}
// Config
config := &server.Config{
TLSConfig: tlsConfig,
QUICConfig: quicConfig,
Conn: conn,
Outbound: nil, // TODO
BandwidthConfig: bwConfig,
DisableUDP: disableUDP,
Authenticator: authenticator,
EventLogger: &serverLogger{},
MasqHandler: nil, // TODO
}
return config, nil
}
func viperToServerConn() (net.PacketConn, error) {
listen := viper.GetString("listen")
if listen == "" {
return nil, configError{Field: "listen", Err: errors.New("empty listen address")}
}
uAddr, err := net.ResolveUDPAddr("udp", listen)
if err != nil {
return nil, configError{Field: "listen", Err: err}
}
conn, err := net.ListenUDP("udp", uAddr)
if err != nil {
return nil, configError{Field: "listen", Err: err}
}
return conn, nil
}
func viperToServerTLSConfig() (server.TLSConfig, error) {
vTLS, vACME := viper.Sub("tls"), viper.Sub("acme")
if vTLS == nil && vACME == nil {
return server.TLSConfig{}, configError{Field: "tls", Err: errors.New("must set either tls or acme")}
}
if vTLS != nil && vACME != nil {
return server.TLSConfig{}, configError{Field: "tls", Err: errors.New("cannot set both tls and acme")}
}
if vTLS != nil {
return viperToServerTLSConfigLocal(vTLS)
} else {
return viperToServerTLSConfigACME(vACME)
}
}
func viperToServerTLSConfigLocal(v *viper.Viper) (server.TLSConfig, error) {
certPath, keyPath := v.GetString("cert"), v.GetString("key")
if certPath == "" || keyPath == "" {
return server.TLSConfig{}, configError{Field: "tls", Err: errors.New("empty cert or key path")}
}
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return server.TLSConfig{}, configError{Field: "tls", Err: err}
}
return server.TLSConfig{
Certificates: []tls.Certificate{cert},
}, nil
}
func viperToServerTLSConfigACME(v *viper.Viper) (server.TLSConfig, error) {
dataDir := v.GetString("dir")
if dataDir == "" {
dataDir = "acme"
}
cfg := &certmagic.Config{
RenewalWindowRatio: certmagic.DefaultRenewalWindowRatio,
KeySource: certmagic.DefaultKeyGenerator,
Storage: &certmagic.FileStorage{Path: dataDir},
Logger: logger,
}
issuer := certmagic.NewACMEIssuer(cfg, certmagic.ACMEIssuer{
Email: v.GetString("email"),
Agreed: true,
DisableHTTPChallenge: v.GetBool("disableHTTP"),
DisableTLSALPNChallenge: v.GetBool("disableTLSALPN"),
AltHTTPPort: v.GetInt("altHTTPPort"),
AltTLSALPNPort: v.GetInt("altTLSALPNPort"),
Logger: logger,
})
switch strings.ToLower(v.GetString("ca")) {
case "letsencrypt", "le", "":
// Default to Let's Encrypt
issuer.CA = certmagic.LetsEncryptProductionCA
case "zerossl", "zero":
issuer.CA = certmagic.ZeroSSLProductionCA
default:
return server.TLSConfig{}, configError{Field: "acme.ca", Err: errors.New("unknown CA")}
}
cfg.Issuers = []certmagic.Issuer{issuer}
cache := certmagic.NewCache(certmagic.CacheOptions{
GetConfigForCert: func(cert certmagic.Certificate) (*certmagic.Config, error) {
return cfg, nil
},
Logger: logger,
})
cfg = certmagic.New(cache, *cfg)
domains := v.GetStringSlice("domains")
if len(domains) == 0 {
return server.TLSConfig{}, configError{Field: "acme.domains", Err: errors.New("empty domains")}
}
err := cfg.ManageSync(context.Background(), domains)
if err != nil {
return server.TLSConfig{}, configError{Field: "acme", Err: err}
}
return server.TLSConfig{
GetCertificate: cfg.GetCertificate,
}, nil
}
func viperToServerQUICConfig() server.QUICConfig {
return server.QUICConfig{
InitialStreamReceiveWindow: viper.GetUint64("quic.initStreamReceiveWindow"),
MaxStreamReceiveWindow: viper.GetUint64("quic.maxStreamReceiveWindow"),
InitialConnectionReceiveWindow: viper.GetUint64("quic.initConnReceiveWindow"),
MaxConnectionReceiveWindow: viper.GetUint64("quic.maxConnReceiveWindow"),
MaxIdleTimeout: viper.GetDuration("quic.maxIdleTimeout"),
MaxIncomingStreams: viper.GetInt64("quic.maxIncomingStreams"),
DisablePathMTUDiscovery: viper.GetBool("quic.disablePathMTUDiscovery"),
}
}
func viperToServerBandwidthConfig() (server.BandwidthConfig, error) {
bw := server.BandwidthConfig{}
upStr, downStr := viper.GetString("bandwidth.up"), viper.GetString("bandwidth.down")
if upStr != "" {
up, err := convBandwidth(upStr)
if err != nil {
return server.BandwidthConfig{}, configError{Field: "bandwidth.up", Err: err}
}
bw.MaxTx = up
}
if downStr != "" {
down, err := convBandwidth(downStr)
if err != nil {
return server.BandwidthConfig{}, configError{Field: "bandwidth.down", Err: err}
}
bw.MaxRx = down
}
return bw, nil
}
func viperToAuthenticator() (server.Authenticator, error) {
authType := viper.GetString("auth.type")
if authType == "" {
return nil, configError{Field: "auth.type", Err: errors.New("empty auth type")}
}
switch authType {
case "password":
pw := viper.GetString("auth.password")
if pw == "" {
return nil, configError{Field: "auth.password", Err: errors.New("empty auth password")}
}
return &auth.PasswordAuthenticator{Password: pw}, nil
default:
return nil, configError{Field: "auth.type", Err: errors.New("unsupported auth type")}
}
}
type serverLogger struct{}
func (l *serverLogger) Connect(addr net.Addr, id string, tx uint64) {
logger.Info("client connected", zap.String("addr", addr.String()), zap.String("id", id), zap.Uint64("tx", tx))
}
func (l *serverLogger) Disconnect(addr net.Addr, id string, err error) {
logger.Info("client disconnected", zap.String("addr", addr.String()), zap.String("id", id), zap.Error(err))
}
func (l *serverLogger) TCPRequest(addr net.Addr, id, reqAddr string) {
logger.Debug("TCP request", zap.String("addr", addr.String()), zap.String("id", id), zap.String("reqAddr", reqAddr))
}
func (l *serverLogger) TCPError(addr net.Addr, id, reqAddr string, err error) {
if err == nil {
logger.Debug("TCP closed", zap.String("addr", addr.String()), zap.String("id", id), zap.String("reqAddr", reqAddr))
} else {
logger.Error("TCP error", zap.String("addr", addr.String()), zap.String("id", id), zap.String("reqAddr", reqAddr), zap.Error(err))
}
}
func (l *serverLogger) UDPRequest(addr net.Addr, id string, sessionID uint32) {
logger.Debug("UDP request", zap.String("addr", addr.String()), zap.String("id", id), zap.Uint32("sessionID", sessionID))
}
func (l *serverLogger) UDPError(addr net.Addr, id string, sessionID uint32, err error) {
if err == nil {
logger.Debug("UDP closed", zap.String("addr", addr.String()), zap.String("id", id), zap.Uint32("sessionID", sessionID))
} else {
logger.Error("UDP error", zap.String("addr", addr.String()), zap.String("id", id), zap.Uint32("sessionID", sessionID), zap.Error(err))
}
}