diff --git a/cmd/client.go b/cmd/client.go index 9bfbac9..e4e53f7 100644 --- a/cmd/client.go +++ b/cmd/client.go @@ -12,7 +12,6 @@ import ( "time" "github.com/apernet/hysteria/pkg/pktconns" - "github.com/apernet/hysteria/pkg/pmtud" "github.com/apernet/hysteria/pkg/redirect" "github.com/oschwald/geoip2-golang" @@ -29,14 +28,6 @@ import ( "github.com/sirupsen/logrus" ) -var clientPacketConnFuncFactoryMap = map[string]pktconns.ClientPacketConnFuncFactory{ - "": pktconns.NewClientUDPConnFunc, - "udp": pktconns.NewClientUDPConnFunc, - "wechat": pktconns.NewClientWeChatConnFunc, - "wechat-video": pktconns.NewClientWeChatConnFunc, - "faketcp": pktconns.NewClientFakeTCPConnFunc, -} - func client(config *clientConfig) { logrus.WithField("config", config.String()).Info("Client configuration loaded") config.Fill() // Fill default values @@ -96,7 +87,7 @@ func client(config *clientConfig) { auth = []byte(config.AuthString) } // Packet conn - pktConnFuncFactory := clientPacketConnFuncFactoryMap[config.Protocol] + pktConnFuncFactory := pktconns.ClientPacketConnFuncFactoryMap[config.Protocol] if pktConnFuncFactory == nil { logrus.WithFields(logrus.Fields{ "protocol": config.Protocol, diff --git a/cmd/server.go b/cmd/server.go index f69b829..6f9ce6e 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -7,11 +7,10 @@ import ( "net/http" "time" - "github.com/apernet/hysteria/pkg/pktconns" - "github.com/apernet/hysteria/cmd/auth" "github.com/apernet/hysteria/pkg/acl" "github.com/apernet/hysteria/pkg/core" + "github.com/apernet/hysteria/pkg/pktconns" "github.com/apernet/hysteria/pkg/pmtud" "github.com/apernet/hysteria/pkg/sockopt" "github.com/apernet/hysteria/pkg/transport" @@ -23,14 +22,6 @@ import ( "github.com/yosuke-furukawa/json5/encoding/json5" ) -var serverPacketConnFuncFactoryMap = map[string]pktconns.ServerPacketConnFuncFactory{ - "": pktconns.NewServerUDPConnFunc, - "udp": pktconns.NewServerUDPConnFunc, - "wechat": pktconns.NewServerWeChatConnFunc, - "wechat-video": pktconns.NewServerWeChatConnFunc, - "faketcp": pktconns.NewServerFakeTCPConnFunc, -} - func server(config *serverConfig) { logrus.WithField("config", config.String()).Info("Server configuration loaded") config.Fill() // Fill default values @@ -207,7 +198,7 @@ func server(config *serverConfig) { }() } // Packet conn - pktConnFuncFactory := serverPacketConnFuncFactoryMap[config.Protocol] + pktConnFuncFactory := pktconns.ServerPacketConnFuncFactoryMap[config.Protocol] if pktConnFuncFactory == nil { logrus.WithField("protocol", config.Protocol).Fatal("Unsupported protocol") } diff --git a/pkg/core/client.go b/pkg/core/client.go index 25415ed..e48260d 100644 --- a/pkg/core/client.go +++ b/pkg/core/client.go @@ -189,7 +189,9 @@ func (c *Client) openStreamWithReconnect() (quic.Connection, quic.Stream, error) // Temporary error, just return return nil, nil, err } - c.quicReconnectFunc(err) + if c.quicReconnectFunc != nil { + c.quicReconnectFunc(err) + } // Permanent error, need to reconnect if err := c.connect(); err != nil { // Still error, oops diff --git a/pkg/pktconns/funcs.go b/pkg/pktconns/funcs.go index 23427c4..7bdc409 100644 --- a/pkg/pktconns/funcs.go +++ b/pkg/pktconns/funcs.go @@ -21,6 +21,22 @@ type ( ServerPacketConnFuncFactory func(obfsPassword string) ServerPacketConnFunc ) +var ClientPacketConnFuncFactoryMap = map[string]ClientPacketConnFuncFactory{ + "": NewClientUDPConnFunc, + "udp": NewClientUDPConnFunc, + "wechat": NewClientWeChatConnFunc, + "wechat-video": NewClientWeChatConnFunc, + "faketcp": NewClientFakeTCPConnFunc, +} + +var ServerPacketConnFuncFactoryMap = map[string]ServerPacketConnFuncFactory{ + "": NewServerUDPConnFunc, + "udp": NewServerUDPConnFunc, + "wechat": NewServerWeChatConnFunc, + "wechat-video": NewServerWeChatConnFunc, + "faketcp": NewServerFakeTCPConnFunc, +} + func NewClientUDPConnFunc(obfsPassword string, hopInterval time.Duration) ClientPacketConnFunc { if obfsPassword == "" { return func(server string) (net.PacketConn, net.Addr, error) { diff --git a/sdk/client.go b/sdk/client.go new file mode 100644 index 0000000..dc356f5 --- /dev/null +++ b/sdk/client.go @@ -0,0 +1,249 @@ +// Package sdk provides an official API for integrating Hysteria client into other projects. +// It aims to be as stable & simple as possible, so that it can be easily maintained and +// widely adopted. +package sdk + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "net" + "time" + + "github.com/apernet/hysteria/pkg/core" + "github.com/lucas-clemente/quic-go" +) + +const ( + defaultALPN = "hysteria" + + defaultStreamReceiveWindow = 16777216 // 16 MB + defaultConnectionReceiveWindow = defaultStreamReceiveWindow * 5 / 2 // 40 MB + + defaultClientIdleTimeoutSec = 20 + + defaultClientHopIntervalSec = 10 +) + +type ( + Protocol string + ResolveFunc func(network string, address string) (net.Addr, error) +) + +const ( + ProtocolUDP Protocol = "udp" + ProtocolWeChat Protocol = "wechat" + ProtocolFakeTCP Protocol = "faketcp" +) + +// Client is a Hysteria client. +type Client interface { + // DialTCP dials a TCP connection to the specified address. + // The remote address must be in "host:port" format. + DialTCP(addr string) (net.Conn, error) + + // DialUDP dials a UDP connection. + // It is bound to a fixed port on the server side. + // Can be used to send and receive UDP packets to/from any address. + DialUDP() (HyUDPConn, error) + + // Close closes the client. + Close() error +} + +// HyUDPConn is a Hysteria-proxied UDP connection. +type HyUDPConn interface { + // ReadFrom reads a packet from the connection. + // It returns the data, the source address (in "host:port" format) and any error encountered. + ReadFrom() ([]byte, string, error) + + // WriteTo writes a packet to the connection. + // The remote address must be in "host:port" format. + WriteTo([]byte, string) error + + // Close closes the connection. + Close() error +} + +// ClientConfig is the configuration for a Hysteria client. +type ClientConfig struct { + // ServerAddress is the address of the Hysteria server. + // It must be in "host:port" format. + ServerAddress string + + // ResolveFunc is the function used to resolve the server address. + // If not set, the default resolver will be used. + ResolveFunc ResolveFunc + + // Protocol is the protocol to use. + // It must be one of the following: + // - ProtocolUDP + // - ProtocolWeChat + // - ProtocolFakeTCP + Protocol Protocol + + // Obfs is the obfuscation password. + // Empty = no obfuscation. + Obfs string + + // HopInterval is the port hopping interval. + // 0 = default 10s. + HopInterval time.Duration + + // Auth is the authentication payload to be sent to the server. + // It can be empty or nil if no authentication is required. + Auth []byte + + // SendBPS is the maximum sending speed in bytes per second. + // Required and cannot be 0. + SendBPS uint64 + + // RecvBPS is the maximum receiving speed in bytes per second. + // Required and cannot be 0. + RecvBPS uint64 + + // ALPN is the ALPN protocol to be used. + // Empty = default "hysteria". + ALPN string + + // ServerName is the SNI to be used. + // Empty = get from ServerAddress. + ServerName string + + // Insecure is whether to skip certificate verification. + // It is not recommended to set this to true. + Insecure bool + + // RootCAs is the root CA certificates to be used. + // Empty = use system default. + RootCAs *x509.CertPool + + // ReceiveWindowConn is the flow control receive window size for each connection. + // 0 = default 16MB. + ReceiveWindowConn uint64 + + // ReceiveWindow is the flow control receive window size for the whole client. + // 0 = default 40MB. + ReceiveWindow uint64 + + // HandshakeTimeout is the timeout for the initial handshake. + // 0 = default 5s. + HandshakeTimeout time.Duration + + // IdleTimeout is the timeout for idle connections. + // The client will send a heartbeat packet every 2/5 of this value. + // If the server does not respond within IdleTimeout, the connection will be closed. + // 0 = default 20s. + IdleTimeout time.Duration + + // DisableMTUDiscovery is whether to disable MTU discovery. + // Only disable this if you are having MTU issues. + DisableMTUDiscovery bool + + // TLSConfig, if not nil, will override all TLS-related fields above!!! + // Only set this if you know what you are doing. + TLSConfig *tls.Config + + // QUICConfig, if not nil, will override all QUIC-related fields above!!! + // Only set this if you know what you are doing. + QUICConfig *quic.Config +} + +// fill fills in the default values (if not set) for the configuration. +func (c *ClientConfig) fill() { + if c.ResolveFunc == nil { + c.ResolveFunc = func(network string, address string) (net.Addr, error) { + switch network { + case "tcp", "tcp4", "tcp6": + return net.ResolveTCPAddr(network, address) + case "udp", "udp4", "udp6": + return net.ResolveUDPAddr(network, address) + default: + return nil, errors.New("unsupported network type") + } + } + } + if c.Protocol == "" { + c.Protocol = ProtocolUDP + } + if c.HopInterval == 0 { + c.HopInterval = defaultClientHopIntervalSec * time.Second + } + if c.ALPN == "" { + c.ALPN = defaultALPN + } + if c.ReceiveWindowConn == 0 { + c.ReceiveWindowConn = defaultStreamReceiveWindow + } + if c.ReceiveWindow == 0 { + c.ReceiveWindow = defaultConnectionReceiveWindow + } + if c.IdleTimeout == 0 { + c.IdleTimeout = defaultClientIdleTimeoutSec * time.Second + } +} + +// NewClient creates a new Hysteria client. +func NewClient(config ClientConfig) (Client, error) { + // Fill in default values + config.fill() + // TLS config + var tlsConfig *tls.Config + if config.TLSConfig != nil { + tlsConfig = config.TLSConfig + } else { + tlsConfig = &tls.Config{ + NextProtos: []string{config.ALPN}, + ServerName: config.ServerName, + InsecureSkipVerify: config.Insecure, + RootCAs: config.RootCAs, + MinVersion: tls.VersionTLS13, + } + } + // QUIC config + var quicConfig *quic.Config + if config.QUICConfig != nil { + quicConfig = config.QUICConfig + } else { + quicConfig = &quic.Config{ + InitialStreamReceiveWindow: config.ReceiveWindowConn, + MaxStreamReceiveWindow: config.ReceiveWindowConn, + InitialConnectionReceiveWindow: config.ReceiveWindow, + MaxConnectionReceiveWindow: config.ReceiveWindow, + HandshakeIdleTimeout: config.HandshakeTimeout, + MaxIdleTimeout: config.IdleTimeout, + KeepAlivePeriod: config.IdleTimeout * 2 / 5, + DisablePathMTUDiscovery: config.DisableMTUDiscovery, + EnableDatagrams: true, + } + } + // Packet conn func + pff := clientPacketConnFuncFactoryMap[config.Protocol] + if pff == nil { + return nil, errors.New("unsupported protocol") + } + pf := pff(config.Obfs, config.HopInterval, config.ResolveFunc) + c, err := core.NewClient(config.ServerAddress, config.Auth, tlsConfig, quicConfig, pf, + config.SendBPS, config.RecvBPS, nil) + if err != nil { + return nil, err + } + return &clientImpl{c}, nil +} + +type clientImpl struct { + *core.Client +} + +func (c *clientImpl) DialTCP(addr string) (net.Conn, error) { + return c.Client.DialTCP(addr) +} + +func (c *clientImpl) DialUDP() (HyUDPConn, error) { + conn, err := c.Client.DialUDP() + return HyUDPConn(conn), err +} + +func (c *clientImpl) Close() error { + return c.Client.Close() +} diff --git a/sdk/example/main.go b/sdk/example/main.go new file mode 100644 index 0000000..2e663b0 --- /dev/null +++ b/sdk/example/main.go @@ -0,0 +1,43 @@ +package main + +import ( + "fmt" + "io" + + "github.com/apernet/hysteria/sdk" +) + +func main() { + config := sdk.ClientConfig{ + ServerAddress: "just.example.net:6677", + Protocol: sdk.ProtocolUDP, + Obfs: "password1234", + SendBPS: 524288, + RecvBPS: 524288, + } + client, err := sdk.NewClient(config) + if err != nil { + fmt.Println("NewClient:", err) + return + } + defer client.Close() + + conn, err := client.DialTCP("ipinfo.io:80") + if err != nil { + fmt.Println("DialTCP:", err) + return + } + defer conn.Close() + + _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ipinfo.io\r\nConnection: close\r\n\r\n")) + if err != nil { + fmt.Println("Write:", err) + return + } + bs, err := io.ReadAll(conn) + if err != nil { + fmt.Println("ReadAll:", err) + return + } + fmt.Println(string(bs)) +} diff --git a/sdk/pktconns.go b/sdk/pktconns.go new file mode 100644 index 0000000..0254515 --- /dev/null +++ b/sdk/pktconns.go @@ -0,0 +1,119 @@ +package sdk + +import ( + "net" + "strings" + "time" + + "github.com/apernet/hysteria/pkg/pktconns" + "github.com/apernet/hysteria/pkg/pktconns/faketcp" + "github.com/apernet/hysteria/pkg/pktconns/obfs" + "github.com/apernet/hysteria/pkg/pktconns/udp" + "github.com/apernet/hysteria/pkg/pktconns/wechat" +) + +type ( + clientPacketConnFuncFactory func(obfsPassword string, hopInterval time.Duration, resolveFunc ResolveFunc) pktconns.ClientPacketConnFunc +) + +var clientPacketConnFuncFactoryMap = map[Protocol]clientPacketConnFuncFactory{ + ProtocolUDP: newClientUDPConnFunc, + ProtocolWeChat: newClientWeChatConnFunc, + ProtocolFakeTCP: newClientFakeTCPConnFunc, +} + +func newClientUDPConnFunc(obfsPassword string, hopInterval time.Duration, resolveFunc ResolveFunc) pktconns.ClientPacketConnFunc { + if obfsPassword == "" { + return func(server string) (net.PacketConn, net.Addr, error) { + if isMultiPortAddr(server) { + return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, nil) + } + sAddr, err := resolveFunc("udp", server) + if err != nil { + return nil, nil, err + } + udpConn, err := net.ListenUDP("udp", nil) + return udpConn, sAddr, err + } + } else { + return func(server string) (net.PacketConn, net.Addr, error) { + if isMultiPortAddr(server) { + ob := obfs.NewXPlusObfuscator([]byte(obfsPassword)) + return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, ob) + } + sAddr, err := resolveFunc("udp", server) + if err != nil { + return nil, nil, err + } + udpConn, err := net.ListenUDP("udp", nil) + if err != nil { + return nil, nil, err + } + ob := obfs.NewXPlusObfuscator([]byte(obfsPassword)) + return udp.NewObfsUDPConn(udpConn, ob), sAddr, nil + } + } +} + +func newClientWeChatConnFunc(obfsPassword string, hopInterval time.Duration, resolveFunc ResolveFunc) pktconns.ClientPacketConnFunc { + if obfsPassword == "" { + return func(server string) (net.PacketConn, net.Addr, error) { + sAddr, err := resolveFunc("udp", server) + if err != nil { + return nil, nil, err + } + udpConn, err := net.ListenUDP("udp", nil) + if err != nil { + return nil, nil, err + } + return wechat.NewObfsWeChatUDPConn(udpConn, nil), sAddr, nil + } + } else { + return func(server string) (net.PacketConn, net.Addr, error) { + sAddr, err := resolveFunc("udp", server) + if err != nil { + return nil, nil, err + } + udpConn, err := net.ListenUDP("udp", nil) + if err != nil { + return nil, nil, err + } + ob := obfs.NewXPlusObfuscator([]byte(obfsPassword)) + return wechat.NewObfsWeChatUDPConn(udpConn, ob), sAddr, nil + } + } +} + +func newClientFakeTCPConnFunc(obfsPassword string, hopInterval time.Duration, resolveFunc ResolveFunc) pktconns.ClientPacketConnFunc { + if obfsPassword == "" { + return func(server string) (net.PacketConn, net.Addr, error) { + sAddr, err := resolveFunc("tcp", server) + if err != nil { + return nil, nil, err + } + fTCPConn, err := faketcp.Dial("tcp", server) + return fTCPConn, sAddr, err + } + } else { + return func(server string) (net.PacketConn, net.Addr, error) { + sAddr, err := resolveFunc("tcp", server) + if err != nil { + return nil, nil, err + } + fTCPConn, err := faketcp.Dial("tcp", server) + if err != nil { + return nil, nil, err + } + ob := obfs.NewXPlusObfuscator([]byte(obfsPassword)) + return faketcp.NewObfsFakeTCPConn(fTCPConn, ob), sAddr, nil + } + } +} + +func isMultiPortAddr(addr string) bool { + _, portStr, err := net.SplitHostPort(addr) + if err == nil && (strings.Contains(portStr, ",") || strings.Contains(portStr, "-")) { + return true + } + return false +}