chore: move code around

This commit is contained in:
Toby 2023-08-23 16:26:38 -07:00
parent 3c3c2a51a8
commit 332d2ea32d
8 changed files with 137 additions and 126 deletions

View file

@ -21,6 +21,7 @@ import (
"github.com/apernet/hysteria/app/internal/http"
"github.com/apernet/hysteria/app/internal/socks5"
"github.com/apernet/hysteria/app/internal/tproxy"
"github.com/apernet/hysteria/app/internal/utils"
"github.com/apernet/hysteria/core/client"
"github.com/apernet/hysteria/extras/obfs"
)
@ -222,13 +223,13 @@ func (c *clientConfig) fillBandwidthConfig(hyConfig *client.Config) error {
// New core now allows users to omit bandwidth values and use built-in congestion control
var err error
if c.Bandwidth.Up != "" {
hyConfig.BandwidthConfig.MaxTx, err = convBandwidth(c.Bandwidth.Up)
hyConfig.BandwidthConfig.MaxTx, err = utils.ConvBandwidth(c.Bandwidth.Up)
if err != nil {
return configError{Field: "bandwidth.up", Err: err}
}
}
if c.Bandwidth.Down != "" {
hyConfig.BandwidthConfig.MaxRx, err = convBandwidth(c.Bandwidth.Down)
hyConfig.BandwidthConfig.MaxRx, err = utils.ConvBandwidth(c.Bandwidth.Down)
if err != nil {
return configError{Field: "bandwidth.down", Err: err}
}
@ -369,7 +370,7 @@ func runClient(cmd *cobra.Command, args []string) {
uri := config.URI()
logger.Info("use this URI to share your server", zap.String("uri", uri))
if showQR {
printQR(uri)
utils.PrintQR(uri)
}
// Modes
@ -594,6 +595,15 @@ func parseServerAddrString(addrStr string) (host, hostPort string) {
return h, addrStr
}
// normalizeCertHash normalizes a certificate hash string.
// It converts all characters to lowercase and removes possible separators such as ":" and "-".
func normalizeCertHash(hash string) string {
r := strings.ToLower(hash)
r = strings.ReplaceAll(r, ":", "")
r = strings.ReplaceAll(r, "-", "")
return r
}
// obfsConnFactory adds obfuscation to a function that creates net.PacketConn.
type obfsConnFactory struct {
NewFunc func(addr net.Addr) (net.PacketConn, error)

18
app/cmd/errors.go Normal file
View file

@ -0,0 +1,18 @@
package cmd
import (
"fmt"
)
type configError struct {
Field string
Err error
}
func (e configError) Error() string {
return fmt.Sprintf("invalid config: %s: %s", e.Field, e.Err)
}
func (e configError) Unwrap() error {
return e.Err
}

View file

@ -16,6 +16,7 @@ import (
"github.com/spf13/viper"
"go.uber.org/zap"
"github.com/apernet/hysteria/app/internal/utils"
"github.com/apernet/hysteria/core/server"
"github.com/apernet/hysteria/extras/auth"
"github.com/apernet/hysteria/extras/obfs"
@ -378,7 +379,7 @@ func (c *serverConfig) fillOutboundConfig(hyConfig *server.Config) error {
if c.ACL.File != "" && len(c.ACL.Inline) > 0 {
return configError{Field: "acl", Err: errors.New("cannot set both acl.file and acl.inline")}
}
gLoader := &geoipLoader{
gLoader := &utils.GeoIPLoader{
Filename: c.ACL.GeoIP,
DownloadFunc: geoipDownloadFunc,
DownloadErrFunc: geoipDownloadErrFunc,
@ -442,13 +443,13 @@ func (c *serverConfig) fillOutboundConfig(hyConfig *server.Config) error {
func (c *serverConfig) fillBandwidthConfig(hyConfig *server.Config) error {
var err error
if c.Bandwidth.Up != "" {
hyConfig.BandwidthConfig.MaxTx, err = convBandwidth(c.Bandwidth.Up)
hyConfig.BandwidthConfig.MaxTx, err = utils.ConvBandwidth(c.Bandwidth.Up)
if err != nil {
return configError{Field: "bandwidth.up", Err: err}
}
}
if c.Bandwidth.Down != "" {
hyConfig.BandwidthConfig.MaxRx, err = convBandwidth(c.Bandwidth.Down)
hyConfig.BandwidthConfig.MaxRx, err = utils.ConvBandwidth(c.Bandwidth.Down)
if err != nil {
return configError{Field: "bandwidth.down", Err: err}
}

View file

@ -1,120 +0,0 @@
package cmd
import (
"fmt"
"io"
"net/http"
"os"
"strings"
"github.com/apernet/hysteria/extras/utils"
"github.com/mdp/qrterminal/v3"
"github.com/oschwald/geoip2-golang"
)
const (
geoipDefaultFilename = "GeoLite2-Country.mmdb"
geoipDownloadURL = "https://git.io/GeoLite2-Country.mmdb"
)
// convBandwidth handles both string and int types for bandwidth.
// When using string, it will be parsed as a bandwidth string with units.
// When using int, it will be parsed as a raw bandwidth in bytes per second.
// It does NOT support float types.
func convBandwidth(bw interface{}) (uint64, error) {
switch bwT := bw.(type) {
case string:
return utils.StringToBps(bwT)
case int:
return uint64(bwT), nil
default:
return 0, fmt.Errorf("invalid type %T for bandwidth", bwT)
}
}
func printQR(str string) {
qrterminal.GenerateWithConfig(str, qrterminal.Config{
Level: qrterminal.L,
Writer: os.Stdout,
BlackChar: qrterminal.BLACK,
WhiteChar: qrterminal.WHITE,
})
}
type configError struct {
Field string
Err error
}
func (e configError) Error() string {
return fmt.Sprintf("invalid config: %s: %s", e.Field, e.Err)
}
func (e configError) Unwrap() error {
return e.Err
}
// geoipLoader provides the on-demand GeoIP database loading function required by the ACL engine.
type geoipLoader struct {
Filename string
DownloadFunc func(filename, url string) // Called when downloading the GeoIP database.
DownloadErrFunc func(err error) // Called when downloading the GeoIP database succeeds/fails.
db *geoip2.Reader
}
func (l *geoipLoader) download() error {
resp, err := http.Get(geoipDownloadURL)
if err != nil {
return err
}
defer resp.Body.Close()
f, err := os.Create(geoipDefaultFilename)
if err != nil {
return err
}
defer f.Close()
_, err = io.Copy(f, resp.Body)
return err
}
func (l *geoipLoader) Load() *geoip2.Reader {
if l.db == nil {
if l.Filename == "" {
// Filename not specified, try default.
if _, err := os.Stat(geoipDefaultFilename); err == nil {
// Default already exists, just use it.
l.Filename = geoipDefaultFilename
} else if os.IsNotExist(err) {
// Default doesn't exist, download it.
l.DownloadFunc(geoipDefaultFilename, geoipDownloadURL)
err := l.download()
l.DownloadErrFunc(err)
if err != nil {
return nil
}
l.Filename = geoipDefaultFilename
} else {
// Other error
return nil
}
}
db, err := geoip2.Open(l.Filename)
if err != nil {
return nil
}
l.db = db
}
return l.db
}
// normalizeCertHash normalizes a certificate hash string.
// It converts all characters to lowercase and removes possible separators such as ":" and "-".
func normalizeCertHash(hash string) string {
r := strings.ToLower(hash)
r = strings.ReplaceAll(r, ":", "")
r = strings.ReplaceAll(r, "-", "")
return r
}

View file

@ -2,6 +2,7 @@ package utils
import (
"errors"
"fmt"
"strconv"
"strings"
)
@ -50,3 +51,18 @@ func StringToBps(s string) (uint64, error) {
return 0, errors.New("unsupported unit")
}
}
// ConvBandwidth handles both string and int types for bandwidth.
// When using string, it will be parsed as a bandwidth string with units.
// When using int, it will be parsed as a raw bandwidth in bytes per second.
// It does NOT support float types.
func ConvBandwidth(bw interface{}) (uint64, error) {
switch bwT := bw.(type) {
case string:
return StringToBps(bwT)
case int:
return uint64(bwT), nil
default:
return 0, fmt.Errorf("invalid type %T for bandwidth", bwT)
}
}

View file

@ -0,0 +1,70 @@
package utils
import (
"io"
"net/http"
"os"
"github.com/oschwald/geoip2-golang"
)
const (
geoipDefaultFilename = "GeoLite2-Country.mmdb"
geoipDownloadURL = "https://git.io/GeoLite2-Country.mmdb"
)
// GeoIPLoader provides the on-demand GeoIP database loading function required by the ACL engine.
type GeoIPLoader struct {
Filename string
DownloadFunc func(filename, url string) // Called when downloading the GeoIP database.
DownloadErrFunc func(err error) // Called when downloading the GeoIP database succeeds/fails.
db *geoip2.Reader
}
func (l *GeoIPLoader) download() error {
resp, err := http.Get(geoipDownloadURL)
if err != nil {
return err
}
defer resp.Body.Close()
f, err := os.Create(geoipDefaultFilename)
if err != nil {
return err
}
defer f.Close()
_, err = io.Copy(f, resp.Body)
return err
}
func (l *GeoIPLoader) Load() *geoip2.Reader {
if l.db == nil {
if l.Filename == "" {
// Filename not specified, try default.
if _, err := os.Stat(geoipDefaultFilename); err == nil {
// Default already exists, just use it.
l.Filename = geoipDefaultFilename
} else if os.IsNotExist(err) {
// Default doesn't exist, download it.
l.DownloadFunc(geoipDefaultFilename, geoipDownloadURL)
err := l.download()
l.DownloadErrFunc(err)
if err != nil {
return nil
}
l.Filename = geoipDefaultFilename
} else {
// Other error
return nil
}
}
db, err := geoip2.Open(l.Filename)
if err != nil {
return nil
}
l.db = db
}
return l.db
}

16
app/internal/utils/qr.go Normal file
View file

@ -0,0 +1,16 @@
package utils
import (
"os"
"github.com/mdp/qrterminal/v3"
)
func PrintQR(str string) {
qrterminal.GenerateWithConfig(str, qrterminal.Config{
Level: qrterminal.L,
Writer: os.Stdout,
BlackChar: qrterminal.BLACK,
WhiteChar: qrterminal.WHITE,
})
}