diff --git a/app/cmd/client.go b/app/cmd/client.go index 8c593c0..0512867 100644 --- a/app/cmd/client.go +++ b/app/cmd/client.go @@ -5,6 +5,7 @@ import ( "crypto/x509" "encoding/hex" "errors" + "github.com/apernet/hysteria/extras/correctnet" "net" "os" "strconv" @@ -504,7 +505,7 @@ func clientSOCKS5(config socks5Config, c client.Client) error { if config.Listen == "" { return configError{Field: "listen", Err: errors.New("listen address is empty")} } - l, err := net.Listen("tcp", config.Listen) + l, err := correctnet.Listen("tcp", config.Listen) if err != nil { return configError{Field: "listen", Err: err} } @@ -529,7 +530,7 @@ func clientHTTP(config httpConfig, c client.Client) error { if config.Listen == "" { return configError{Field: "listen", Err: errors.New("listen address is empty")} } - l, err := net.Listen("tcp", config.Listen) + l, err := correctnet.Listen("tcp", config.Listen) if err != nil { return configError{Field: "listen", Err: err} } @@ -562,7 +563,7 @@ func clientTCPForwarding(entries []tcpForwardingEntry, c client.Client) error { if e.Remote == "" { return configError{Field: "remote", Err: errors.New("remote address is empty")} } - l, err := net.Listen("tcp", e.Listen) + l, err := correctnet.Listen("tcp", e.Listen) if err != nil { return configError{Field: "listen", Err: err} } @@ -589,7 +590,7 @@ func clientUDPForwarding(entries []udpForwardingEntry, c client.Client) error { if e.Remote == "" { return configError{Field: "remote", Err: errors.New("remote address is empty")} } - l, err := net.ListenPacket("udp", e.Listen) + l, err := correctnet.ListenPacket("udp", e.Listen) if err != nil { return configError{Field: "listen", Err: err} } diff --git a/app/cmd/server.go b/app/cmd/server.go index 9ec95e9..6784065 100644 --- a/app/cmd/server.go +++ b/app/cmd/server.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/apernet/hysteria/extras/correctnet" "net" "net/http" "net/http/httputil" @@ -219,7 +220,7 @@ func (c *serverConfig) fillConn(hyConfig *server.Config) error { if err != nil { return configError{Field: "listen", Err: err} } - conn, err := net.ListenUDP("udp", uAddr) + conn, err := correctnet.ListenUDP("udp", uAddr) if err != nil { return configError{Field: "listen", Err: err} } @@ -752,7 +753,7 @@ func runServer(cmd *cobra.Command, args []string) { func runTrafficStatsServer(listen string, handler http.Handler) { logger.Info("traffic stats server up and running", zap.String("listen", listen)) - if err := http.ListenAndServe(listen, handler); err != nil { + if err := correctnet.HTTPListenAndServe(listen, handler); err != nil { logger.Fatal("failed to serve traffic stats", zap.Error(err)) } } diff --git a/extras/correctnet/correctnet.go b/extras/correctnet/correctnet.go new file mode 100644 index 0000000..3ce0681 --- /dev/null +++ b/extras/correctnet/correctnet.go @@ -0,0 +1,92 @@ +package correctnet + +import ( + "net" + "net/http" + "strings" +) + +func extractIPFamily(ip net.IP) (family string) { + if len(ip) == 0 { + // real family independent wildcard address, such as ":443" + return "" + } + if p4 := ip.To4(); len(p4) == net.IPv4len { + return "4" + } + return "6" +} + +func tcpAddrNetwork(addr *net.TCPAddr) (network string) { + if addr == nil { + return "tcp" + } + return "tcp" + extractIPFamily(addr.IP) +} + +func udpAddrNetwork(addr *net.UDPAddr) (network string) { + if addr == nil { + return "udp" + } + return "udp" + extractIPFamily(addr.IP) +} + +func ipAddrNetwork(addr *net.IPAddr) (network string) { + if addr == nil { + return "ip" + } + return "ip" + extractIPFamily(addr.IP) +} + +func Listen(network string, address string) (net.Listener, error) { + if network == "tcp" { + tcpAddr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + return ListenTCP(network, tcpAddr) + } + return net.Listen(network, address) +} + +func ListenTCP(network string, laddr *net.TCPAddr) (*net.TCPListener, error) { + if network == "tcp" { + return net.ListenTCP(tcpAddrNetwork(laddr), laddr) + } + return net.ListenTCP(network, laddr) +} + +func ListenPacket(network string, address string) (listener net.PacketConn, err error) { + if network == "udp" { + udpAddr, err := net.ResolveUDPAddr(network, address) + if err != nil { + return nil, err + } + return ListenUDP(network, udpAddr) + } + if strings.HasPrefix(network, "ip:") { + proto := network[3:] + ipAddr, err := net.ResolveIPAddr(proto, address) + if err != nil { + return nil, err + } + return net.ListenIP(ipAddrNetwork(ipAddr)+":"+proto, ipAddr) + } + return net.ListenPacket(network, address) +} + +func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) { + if network == "udp" { + return net.ListenUDP(udpAddrNetwork(laddr), laddr) + } + return net.ListenUDP(network, laddr) +} + +func HTTPListenAndServe(address string, handler http.Handler) (err error) { + listener, err := Listen("tcp", address) + if err != nil { + return err + } + defer listener.Close() + return http.Serve(listener, handler) +} diff --git a/extras/masq/server.go b/extras/masq/server.go index 5600f7c..91b4334 100644 --- a/extras/masq/server.go +++ b/extras/masq/server.go @@ -4,6 +4,7 @@ import ( "bufio" "crypto/tls" "fmt" + "github.com/apernet/hysteria/extras/correctnet" "net" "net/http" ) @@ -20,7 +21,7 @@ type MasqTCPServer struct { } func (s *MasqTCPServer) ListenAndServeHTTP(addr string) error { - return http.ListenAndServe(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + return correctnet.HTTPListenAndServe(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if s.ForceHTTPS { if s.HTTPSPort == 0 || s.HTTPSPort == 443 { // Omit port if it's the default @@ -42,7 +43,12 @@ func (s *MasqTCPServer) ListenAndServeHTTPS(addr string) error { }), TLSConfig: s.TLSConfig, } - return server.ListenAndServeTLS("", "") + listener, err := correctnet.Listen("tcp", addr) + if err != nil { + return err + } + defer listener.Close() + return server.ServeTLS(listener, "", "") } var _ http.ResponseWriter = (*altSvcHijackResponseWriter)(nil)