diff --git a/app/cmd/client.go b/app/cmd/client.go index 42199d7..8e7317a 100644 --- a/app/cmd/client.go +++ b/app/cmd/client.go @@ -21,6 +21,7 @@ import ( "github.com/apernet/hysteria/app/internal/http" "github.com/apernet/hysteria/app/internal/socks5" "github.com/apernet/hysteria/app/internal/tproxy" + "github.com/apernet/hysteria/app/internal/utils" "github.com/apernet/hysteria/core/client" "github.com/apernet/hysteria/extras/obfs" ) @@ -222,13 +223,13 @@ func (c *clientConfig) fillBandwidthConfig(hyConfig *client.Config) error { // New core now allows users to omit bandwidth values and use built-in congestion control var err error if c.Bandwidth.Up != "" { - hyConfig.BandwidthConfig.MaxTx, err = convBandwidth(c.Bandwidth.Up) + hyConfig.BandwidthConfig.MaxTx, err = utils.ConvBandwidth(c.Bandwidth.Up) if err != nil { return configError{Field: "bandwidth.up", Err: err} } } if c.Bandwidth.Down != "" { - hyConfig.BandwidthConfig.MaxRx, err = convBandwidth(c.Bandwidth.Down) + hyConfig.BandwidthConfig.MaxRx, err = utils.ConvBandwidth(c.Bandwidth.Down) if err != nil { return configError{Field: "bandwidth.down", Err: err} } @@ -369,7 +370,7 @@ func runClient(cmd *cobra.Command, args []string) { uri := config.URI() logger.Info("use this URI to share your server", zap.String("uri", uri)) if showQR { - printQR(uri) + utils.PrintQR(uri) } // Modes @@ -594,6 +595,15 @@ func parseServerAddrString(addrStr string) (host, hostPort string) { return h, addrStr } +// normalizeCertHash normalizes a certificate hash string. +// It converts all characters to lowercase and removes possible separators such as ":" and "-". +func normalizeCertHash(hash string) string { + r := strings.ToLower(hash) + r = strings.ReplaceAll(r, ":", "") + r = strings.ReplaceAll(r, "-", "") + return r +} + // obfsConnFactory adds obfuscation to a function that creates net.PacketConn. type obfsConnFactory struct { NewFunc func(addr net.Addr) (net.PacketConn, error) diff --git a/app/cmd/errors.go b/app/cmd/errors.go new file mode 100644 index 0000000..3d0234a --- /dev/null +++ b/app/cmd/errors.go @@ -0,0 +1,18 @@ +package cmd + +import ( + "fmt" +) + +type configError struct { + Field string + Err error +} + +func (e configError) Error() string { + return fmt.Sprintf("invalid config: %s: %s", e.Field, e.Err) +} + +func (e configError) Unwrap() error { + return e.Err +} diff --git a/app/cmd/server.go b/app/cmd/server.go index 435d2c0..940a6a7 100644 --- a/app/cmd/server.go +++ b/app/cmd/server.go @@ -16,6 +16,7 @@ import ( "github.com/spf13/viper" "go.uber.org/zap" + "github.com/apernet/hysteria/app/internal/utils" "github.com/apernet/hysteria/core/server" "github.com/apernet/hysteria/extras/auth" "github.com/apernet/hysteria/extras/obfs" @@ -378,7 +379,7 @@ func (c *serverConfig) fillOutboundConfig(hyConfig *server.Config) error { if c.ACL.File != "" && len(c.ACL.Inline) > 0 { return configError{Field: "acl", Err: errors.New("cannot set both acl.file and acl.inline")} } - gLoader := &geoipLoader{ + gLoader := &utils.GeoIPLoader{ Filename: c.ACL.GeoIP, DownloadFunc: geoipDownloadFunc, DownloadErrFunc: geoipDownloadErrFunc, @@ -442,13 +443,13 @@ func (c *serverConfig) fillOutboundConfig(hyConfig *server.Config) error { func (c *serverConfig) fillBandwidthConfig(hyConfig *server.Config) error { var err error if c.Bandwidth.Up != "" { - hyConfig.BandwidthConfig.MaxTx, err = convBandwidth(c.Bandwidth.Up) + hyConfig.BandwidthConfig.MaxTx, err = utils.ConvBandwidth(c.Bandwidth.Up) if err != nil { return configError{Field: "bandwidth.up", Err: err} } } if c.Bandwidth.Down != "" { - hyConfig.BandwidthConfig.MaxRx, err = convBandwidth(c.Bandwidth.Down) + hyConfig.BandwidthConfig.MaxRx, err = utils.ConvBandwidth(c.Bandwidth.Down) if err != nil { return configError{Field: "bandwidth.down", Err: err} } diff --git a/app/cmd/utils.go b/app/cmd/utils.go deleted file mode 100644 index 516340a..0000000 --- a/app/cmd/utils.go +++ /dev/null @@ -1,120 +0,0 @@ -package cmd - -import ( - "fmt" - "io" - "net/http" - "os" - "strings" - - "github.com/apernet/hysteria/extras/utils" - "github.com/mdp/qrterminal/v3" - "github.com/oschwald/geoip2-golang" -) - -const ( - geoipDefaultFilename = "GeoLite2-Country.mmdb" - geoipDownloadURL = "https://git.io/GeoLite2-Country.mmdb" -) - -// convBandwidth handles both string and int types for bandwidth. -// When using string, it will be parsed as a bandwidth string with units. -// When using int, it will be parsed as a raw bandwidth in bytes per second. -// It does NOT support float types. -func convBandwidth(bw interface{}) (uint64, error) { - switch bwT := bw.(type) { - case string: - return utils.StringToBps(bwT) - case int: - return uint64(bwT), nil - default: - return 0, fmt.Errorf("invalid type %T for bandwidth", bwT) - } -} - -func printQR(str string) { - qrterminal.GenerateWithConfig(str, qrterminal.Config{ - Level: qrterminal.L, - Writer: os.Stdout, - BlackChar: qrterminal.BLACK, - WhiteChar: qrterminal.WHITE, - }) -} - -type configError struct { - Field string - Err error -} - -func (e configError) Error() string { - return fmt.Sprintf("invalid config: %s: %s", e.Field, e.Err) -} - -func (e configError) Unwrap() error { - return e.Err -} - -// geoipLoader provides the on-demand GeoIP database loading function required by the ACL engine. -type geoipLoader struct { - Filename string - DownloadFunc func(filename, url string) // Called when downloading the GeoIP database. - DownloadErrFunc func(err error) // Called when downloading the GeoIP database succeeds/fails. - - db *geoip2.Reader -} - -func (l *geoipLoader) download() error { - resp, err := http.Get(geoipDownloadURL) - if err != nil { - return err - } - defer resp.Body.Close() - - f, err := os.Create(geoipDefaultFilename) - if err != nil { - return err - } - defer f.Close() - - _, err = io.Copy(f, resp.Body) - return err -} - -func (l *geoipLoader) Load() *geoip2.Reader { - if l.db == nil { - if l.Filename == "" { - // Filename not specified, try default. - if _, err := os.Stat(geoipDefaultFilename); err == nil { - // Default already exists, just use it. - l.Filename = geoipDefaultFilename - } else if os.IsNotExist(err) { - // Default doesn't exist, download it. - l.DownloadFunc(geoipDefaultFilename, geoipDownloadURL) - err := l.download() - l.DownloadErrFunc(err) - if err != nil { - return nil - } - l.Filename = geoipDefaultFilename - } else { - // Other error - return nil - } - } - db, err := geoip2.Open(l.Filename) - if err != nil { - return nil - } - l.db = db - } - return l.db -} - -// normalizeCertHash normalizes a certificate hash string. -// It converts all characters to lowercase and removes possible separators such as ":" and "-". -func normalizeCertHash(hash string) string { - r := strings.ToLower(hash) - r = strings.ReplaceAll(r, ":", "") - r = strings.ReplaceAll(r, "-", "") - return r -} diff --git a/extras/utils/bpsconv.go b/app/internal/utils/bpsconv.go similarity index 66% rename from extras/utils/bpsconv.go rename to app/internal/utils/bpsconv.go index 9147ef2..97a7d3f 100644 --- a/extras/utils/bpsconv.go +++ b/app/internal/utils/bpsconv.go @@ -2,6 +2,7 @@ package utils import ( "errors" + "fmt" "strconv" "strings" ) @@ -50,3 +51,18 @@ func StringToBps(s string) (uint64, error) { return 0, errors.New("unsupported unit") } } + +// ConvBandwidth handles both string and int types for bandwidth. +// When using string, it will be parsed as a bandwidth string with units. +// When using int, it will be parsed as a raw bandwidth in bytes per second. +// It does NOT support float types. +func ConvBandwidth(bw interface{}) (uint64, error) { + switch bwT := bw.(type) { + case string: + return StringToBps(bwT) + case int: + return uint64(bwT), nil + default: + return 0, fmt.Errorf("invalid type %T for bandwidth", bwT) + } +} diff --git a/extras/utils/bpsconv_test.go b/app/internal/utils/bpsconv_test.go similarity index 100% rename from extras/utils/bpsconv_test.go rename to app/internal/utils/bpsconv_test.go diff --git a/app/internal/utils/geoip.go b/app/internal/utils/geoip.go new file mode 100644 index 0000000..2144ecb --- /dev/null +++ b/app/internal/utils/geoip.go @@ -0,0 +1,70 @@ +package utils + +import ( + "io" + "net/http" + "os" + + "github.com/oschwald/geoip2-golang" +) + +const ( + geoipDefaultFilename = "GeoLite2-Country.mmdb" + geoipDownloadURL = "https://git.io/GeoLite2-Country.mmdb" +) + +// GeoIPLoader provides the on-demand GeoIP database loading function required by the ACL engine. +type GeoIPLoader struct { + Filename string + DownloadFunc func(filename, url string) // Called when downloading the GeoIP database. + DownloadErrFunc func(err error) // Called when downloading the GeoIP database succeeds/fails. + + db *geoip2.Reader +} + +func (l *GeoIPLoader) download() error { + resp, err := http.Get(geoipDownloadURL) + if err != nil { + return err + } + defer resp.Body.Close() + + f, err := os.Create(geoipDefaultFilename) + if err != nil { + return err + } + defer f.Close() + + _, err = io.Copy(f, resp.Body) + return err +} + +func (l *GeoIPLoader) Load() *geoip2.Reader { + if l.db == nil { + if l.Filename == "" { + // Filename not specified, try default. + if _, err := os.Stat(geoipDefaultFilename); err == nil { + // Default already exists, just use it. + l.Filename = geoipDefaultFilename + } else if os.IsNotExist(err) { + // Default doesn't exist, download it. + l.DownloadFunc(geoipDefaultFilename, geoipDownloadURL) + err := l.download() + l.DownloadErrFunc(err) + if err != nil { + return nil + } + l.Filename = geoipDefaultFilename + } else { + // Other error + return nil + } + } + db, err := geoip2.Open(l.Filename) + if err != nil { + return nil + } + l.db = db + } + return l.db +} diff --git a/app/internal/utils/qr.go b/app/internal/utils/qr.go new file mode 100644 index 0000000..f0c1d39 --- /dev/null +++ b/app/internal/utils/qr.go @@ -0,0 +1,16 @@ +package utils + +import ( + "os" + + "github.com/mdp/qrterminal/v3" +) + +func PrintQR(str string) { + qrterminal.GenerateWithConfig(str, qrterminal.Config{ + Level: qrterminal.L, + Writer: os.Stdout, + BlackChar: qrterminal.BLACK, + WhiteChar: qrterminal.WHITE, + }) +}