feat: client side ACL

This commit is contained in:
WoaShieShei 2024-01-31 11:12:03 +00:00
parent 6f88b9dcf7
commit 3a15aa522b
18 changed files with 389 additions and 133 deletions

View file

@ -58,7 +58,6 @@ type clientConfig struct {
QUIC clientConfigQUIC `mapstructure:"quic"` QUIC clientConfigQUIC `mapstructure:"quic"`
Bandwidth clientConfigBandwidth `mapstructure:"bandwidth"` Bandwidth clientConfigBandwidth `mapstructure:"bandwidth"`
ACL clientConfigACL `mapstructure:"acl"` ACL clientConfigACL `mapstructure:"acl"`
Resolver clientConfigResolver `mapstructure:"resolver"`
FastOpen bool `mapstructure:"fastOpen"` FastOpen bool `mapstructure:"fastOpen"`
Lazy bool `mapstructure:"lazy"` Lazy bool `mapstructure:"lazy"`
SOCKS5 *socks5Config `mapstructure:"socks5"` SOCKS5 *socks5Config `mapstructure:"socks5"`
@ -145,38 +144,6 @@ type clientConfigOutboundHTTP struct {
Insecure bool `mapstructure:"insecure"` Insecure bool `mapstructure:"insecure"`
} }
type clientConfigResolverTCP struct {
Addr string `mapstructure:"addr"`
Timeout time.Duration `mapstructure:"timeout"`
}
type clientConfigResolverUDP struct {
Addr string `mapstructure:"addr"`
Timeout time.Duration `mapstructure:"timeout"`
}
type clientConfigResolverTLS struct {
Addr string `mapstructure:"addr"`
Timeout time.Duration `mapstructure:"timeout"`
SNI string `mapstructure:"sni"`
Insecure bool `mapstructure:"insecure"`
}
type clientConfigResolverHTTPS struct {
Addr string `mapstructure:"addr"`
Timeout time.Duration `mapstructure:"timeout"`
SNI string `mapstructure:"sni"`
Insecure bool `mapstructure:"insecure"`
}
type clientConfigResolver struct {
Type string `mapstructure:"type"`
TCP clientConfigResolverTCP `mapstructure:"tcp"`
UDP clientConfigResolverUDP `mapstructure:"udp"`
TLS clientConfigResolverTLS `mapstructure:"tls"`
HTTPS clientConfigResolverHTTPS `mapstructure:"https"`
}
type socks5Config struct { type socks5Config struct {
Listen string `mapstructure:"listen"` Listen string `mapstructure:"listen"`
Username string `mapstructure:"username"` Username string `mapstructure:"username"`
@ -354,6 +321,76 @@ func (c *clientConfig) fillFastOpen(hyConfig *client.Config) error {
return nil return nil
} }
func (c *clientConfig) fillOutbounds(hyConfig *client.Config) error {
var obs []outbounds.OutboundEntry
if len(c.Outbounds) == 0 {
// The items for which 'outbound' is set to 'nil'
// will have their traffic proxied by remote hysteria server
obs = []outbounds.OutboundEntry{
{
Name: "direct",
Outbound: outbounds.NewDirectOutboundSimple(outbounds.DirectOutboundModeAuto),
},
{
Name: "default",
Outbound: nil,
},
}
} else {
obs = make([]outbounds.OutboundEntry, len(c.Outbounds))
for i, entry := range c.Outbounds {
if entry.Name == "" {
return configError{Field: "outbounds.name", Err: errors.New("empty outbound name")}
}
var ob outbounds.PluggableOutbound
var err error
switch strings.ToLower(entry.Type) {
case "default":
case "hysteria":
ob, err = nil, nil
case "direct":
ob, err = clientConfigOutboundDirectToOutbound(entry.Direct)
case "socks5":
ob, err = clientConfigOutboundSOCKS5ToOutbound(entry.SOCKS5)
case "http":
ob, err = clientConfigOutboundHTTPToOutbound(entry.HTTP)
default:
err = configError{Field: "outbounds.type", Err: errors.New("unsupported outbound type")}
}
if err != nil {
return err
}
obs[i] = outbounds.OutboundEntry{Name: entry.Name, Outbound: ob}
}
}
hyConfig.Outbounds = obs
return nil
}
func (c *clientConfig) fillACLs(hyConfig *client.Config) error {
// ACL
if c.ACL.File != "" && len(c.ACL.Inline) > 0 {
return configError{Field: "acl", Err: errors.New("cannot set both acl.file and acl.inline")}
}
hyConfig.GeoLoader = &utils.GeoLoader{
GeoIPFilename: c.ACL.GeoIP,
GeoSiteFilename: c.ACL.GeoSite,
UpdateInterval: c.ACL.GeoUpdateInterval,
DownloadFunc: geoDownloadFunc,
DownloadErrFunc: geoDownloadErrFunc,
}
if c.ACL.File != "" {
bs, err := os.ReadFile(c.ACL.File)
if err != nil {
return configError{Field: "acl", Err: errors.New("cannot load acl file")}
}
hyConfig.ACLs = string(bs)
} else if len(c.ACL.Inline) > 0 {
hyConfig.ACLs = strings.Join(c.ACL.Inline, "\n")
}
return nil
}
// URI generates a URI for sharing the config with others. // URI generates a URI for sharing the config with others.
// Note that only the bare minimum of information required to // Note that only the bare minimum of information required to
// connect to the server is included in the URI, specifically: // connect to the server is included in the URI, specifically:
@ -500,6 +537,8 @@ func (c *clientConfig) Config() (*client.Config, error) {
c.fillQUICConfig, c.fillQUICConfig,
c.fillBandwidthConfig, c.fillBandwidthConfig,
c.fillFastOpen, c.fillFastOpen,
c.fillOutbounds,
c.fillACLs,
} }
for _, f := range fillers { for _, f := range fillers {
if err := f(hyConfig); err != nil { if err := f(hyConfig); err != nil {

View file

@ -35,23 +35,25 @@ func (t *atomicTime) Get() time.Time {
} }
type sessionEntry struct { type sessionEntry struct {
HyConn client.HyUDPConn HyConn client.UDPConn
Last *atomicTime Last *atomicTime
Timeout bool // true if the session is closed due to timeout Timeout bool // true if the session is closed due to timeout
} }
func (e *sessionEntry) Feed(data []byte, addr string) error { func (e *sessionEntry) Feed(data []byte, addr string) error {
e.Last.Set(time.Now()) e.Last.Set(time.Now())
return e.HyConn.Send(data, addr) _, err := e.HyConn.WriteTo(data, addr)
return err
} }
func (e *sessionEntry) ReceiveLoop(pc net.PacketConn, addr net.Addr) error { func (e *sessionEntry) ReceiveLoop(pc net.PacketConn, addr net.Addr) error {
buf := make([]byte, udpBufferSize)
for { for {
data, _, err := e.HyConn.Receive() n, _, err := e.HyConn.ReadFrom(buf)
if err != nil { if err != nil {
return err return err
} }
_, err = pc.WriteTo(data, addr) _, err = pc.WriteTo(buf[:n], addr)
if err != nil { if err != nil {
return err return err
} }

View file

@ -24,7 +24,7 @@ func (c *mockHyClient) TCP(addr string) (net.Conn, error) {
return net.Dial("tcp", addr) return net.Dial("tcp", addr)
} }
func (c *mockHyClient) UDP() (client.HyUDPConn, error) { func (c *mockHyClient) UDP() (client.UDPConn, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
@ -32,6 +32,14 @@ func (c *mockHyClient) Close() error {
return nil return nil
} }
func (rc *mockHyClient) Outbound() *client.Hy2ClientOutbound {
return nil
}
func (c *mockHyClient) Config() *client.Config {
return nil
}
func TestServer(t *testing.T) { func TestServer(t *testing.T) {
// Start the server // Start the server
l, err := net.Listen("tcp", "127.0.0.1:18080") l, err := net.Listen("tcp", "127.0.0.1:18080")

View file

@ -220,7 +220,7 @@ func (s *Server) handleUDP(conn net.Conn, req *socks5.Request) {
closeErr = <-errChan closeErr = <-errChan
} }
func (s *Server) udpServer(udpConn *net.UDPConn, hyUDP client.HyUDPConn) error { func (s *Server) udpServer(udpConn *net.UDPConn, hyUDP client.UDPConn) error {
var clientAddr *net.UDPAddr var clientAddr *net.UDPAddr
buf := make([]byte, udpBufferSize) buf := make([]byte, udpBufferSize)
// local -> remote // local -> remote
@ -243,8 +243,9 @@ func (s *Server) udpServer(udpConn *net.UDPConn, hyUDP client.HyUDPConn) error {
// Now that we know the client's address, we can start the // Now that we know the client's address, we can start the
// remote -> local direction. // remote -> local direction.
go func() { go func() {
buf := make([]byte, udpBufferSize)
for { for {
bs, from, err := hyUDP.Receive() n, from, err := hyUDP.ReadFrom(buf)
if err != nil { if err != nil {
// Close the UDP conn so that the local -> remote direction will exit // Close the UDP conn so that the local -> remote direction will exit
_ = udpConn.Close() _ = udpConn.Close()
@ -260,7 +261,7 @@ func (s *Server) udpServer(udpConn *net.UDPConn, hyUDP client.HyUDPConn) error {
// So we must remove it here. // So we must remove it here.
addr = addr[1:] addr = addr[1:]
} }
d := socks5.NewDatagram(atyp, addr, port, bs) d := socks5.NewDatagram(atyp, addr, port, buf[:n])
_, _ = udpConn.WriteToUDP(d.Bytes(), clientAddr) _, _ = udpConn.WriteToUDP(d.Bytes(), clientAddr)
} }
}() }()
@ -269,7 +270,7 @@ func (s *Server) udpServer(udpConn *net.UDPConn, hyUDP client.HyUDPConn) error {
continue continue
} }
// Send to remote // Send to remote
_ = hyUDP.Send(d.Data, d.Address()) _, _ = hyUDP.WriteTo(d.Data, d.Address())
} }
} }

View file

@ -68,7 +68,7 @@ func (r *UDPTProxy) newPair(srcAddr, dstAddr *net.UDPAddr, initPkt []byte) {
return return
} }
// Send the first packet // Send the first packet
err = hyConn.Send(initPkt, dstAddr.String()) _, err = hyConn.WriteTo(initPkt, dstAddr.String())
if err != nil { if err != nil {
_ = conn.Close() _ = conn.Close()
_ = hyConn.Close() _ = hyConn.Close()
@ -91,17 +91,18 @@ func (r *UDPTProxy) newPair(srcAddr, dstAddr *net.UDPAddr, initPkt []byte) {
}() }()
} }
func (r *UDPTProxy) forwarding(conn *net.UDPConn, hyConn client.HyUDPConn, dst string) error { func (r *UDPTProxy) forwarding(conn *net.UDPConn, hyConn client.UDPConn, dst string) error {
errChan := make(chan error, 2) errChan := make(chan error, 2)
// Local <- Remote // Local <- Remote
go func() { go func() {
buf := make([]byte, udpBufferSize)
for { for {
bs, _, err := hyConn.Receive() n, _, err := hyConn.ReadFrom(buf)
if err != nil { if err != nil {
errChan <- err errChan <- err
return return
} }
_, err = conn.Write(bs) _, err = conn.Write(buf[:n])
if err != nil { if err != nil {
errChan <- err errChan <- err
return return
@ -116,7 +117,7 @@ func (r *UDPTProxy) forwarding(conn *net.UDPConn, hyConn client.HyUDPConn, dst s
_ = r.updateConnDeadline(conn) _ = r.updateConnDeadline(conn)
n, err := conn.Read(buf) n, err := conn.Read(buf)
if n > 0 { if n > 0 {
err := hyConn.Send(buf[:n], dst) _, err := hyConn.WriteTo(buf[:n], dst)
if err != nil { if err != nil {
errChan <- err errChan <- err
return return

View file

@ -6,9 +6,11 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"strconv"
"time" "time"
"github.com/apernet/hysteria/core/client" "github.com/apernet/hysteria/core/client"
"github.com/apernet/hysteria/extras/outbounds"
) )
const ( const (
@ -55,7 +57,13 @@ func NewClientUpdateChecker(currentVersion, platform, architecture, channel stri
Transport: &http.Transport{ Transport: &http.Transport{
DialContext: func(_ context.Context, network, addr string) (net.Conn, error) { DialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
// Unfortunately HyClient doesn't support context for now // Unfortunately HyClient doesn't support context for now
return hyClient.TCP(addr) host, port, _ := net.SplitHostPort(addr)
portInt, _ := strconv.Atoi(port)
return hyClient.Outbound().TCP(&outbounds.AddrEx{
Host: host,
Port: uint16(portInt),
})
}, },
}, },
}, },

View file

@ -16,7 +16,7 @@ func (c *MockEchoHyClient) TCP(addr string) (net.Conn, error) {
}, nil }, nil
} }
func (c *MockEchoHyClient) UDP() (client.HyUDPConn, error) { func (c *MockEchoHyClient) UDP() (client.UDPConn, error) {
return &mockEchoUDPConn{ return &mockEchoUDPConn{
BufChan: make(chan mockEchoUDPPacket, 10), BufChan: make(chan mockEchoUDPPacket, 10),
}, nil }, nil
@ -26,6 +26,14 @@ func (c *MockEchoHyClient) Close() error {
return nil return nil
} }
func (rc *MockEchoHyClient) Outbound() *client.Hy2ClientOutbound {
return nil
}
func (c *MockEchoHyClient) Config() *client.Config {
return nil
}
type mockEchoTCPConn struct { type mockEchoTCPConn struct {
BufChan chan []byte BufChan chan []byte
} }
@ -83,21 +91,22 @@ type mockEchoUDPConn struct {
BufChan chan mockEchoUDPPacket BufChan chan mockEchoUDPPacket
} }
func (c *mockEchoUDPConn) Receive() ([]byte, string, error) { func (c *mockEchoUDPConn) ReadFrom(d []byte) (int, string, error) {
p := <-c.BufChan p := <-c.BufChan
if p.Data == nil { if p.Data == nil {
// EOF // EOF
return nil, "", io.EOF return 0, "", io.EOF
} }
return p.Data, p.Addr, nil copy(d, p.Data)
return len(d), p.Addr, nil
} }
func (c *mockEchoUDPConn) Send(bytes []byte, s string) error { func (c *mockEchoUDPConn) WriteTo(bytes []byte, s string) (int, error) {
c.BufChan <- mockEchoUDPPacket{ c.BufChan <- mockEchoUDPPacket{
Data: bytes, Data: bytes,
Addr: s, Addr: s,
} }
return nil return len(bytes), nil
} }
func (c *mockEchoUDPConn) Close() error { func (c *mockEchoUDPConn) Close() error {

View file

@ -3,6 +3,7 @@ package client
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"fmt"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@ -12,6 +13,7 @@ import (
"github.com/apernet/hysteria/core/internal/congestion" "github.com/apernet/hysteria/core/internal/congestion"
"github.com/apernet/hysteria/core/internal/protocol" "github.com/apernet/hysteria/core/internal/protocol"
"github.com/apernet/hysteria/core/internal/utils" "github.com/apernet/hysteria/core/internal/utils"
"github.com/apernet/hysteria/extras/outbounds"
"github.com/apernet/quic-go" "github.com/apernet/quic-go"
"github.com/apernet/quic-go/http3" "github.com/apernet/quic-go/http3"
@ -24,13 +26,9 @@ const (
type Client interface { type Client interface {
TCP(addr string) (net.Conn, error) TCP(addr string) (net.Conn, error)
UDP() (HyUDPConn, error) UDP() (UDPConn, error)
Close() error Config() *Config
} Outbound() *Hy2ClientOutbound
type HyUDPConn interface {
Receive() ([]byte, string, error)
Send([]byte, string) error
Close() error Close() error
} }
@ -43,8 +41,10 @@ func NewClient(config *Config) (Client, *HandshakeInfo, error) {
if err := config.verifyAndFill(); err != nil { if err := config.verifyAndFill(); err != nil {
return nil, nil, err return nil, nil, err
} }
c := &clientImpl{ c := &ClientImpl{
config: config, Hy2ClientOutbound{
config: config,
},
} }
info, err := c.connect() info, err := c.connect()
if err != nil { if err != nil {
@ -53,7 +53,7 @@ func NewClient(config *Config) (Client, *HandshakeInfo, error) {
return c, info, nil return c, info, nil
} }
type clientImpl struct { type Hy2ClientOutbound struct {
config *Config config *Config
pktConn net.PacketConn pktConn net.PacketConn
@ -62,26 +62,26 @@ type clientImpl struct {
udpSM *udpSessionManager udpSM *udpSessionManager
} }
func (c *clientImpl) connect() (*HandshakeInfo, error) { func (ob *Hy2ClientOutbound) connect() (*HandshakeInfo, error) {
pktConn, err := c.config.ConnFactory.New(c.config.ServerAddr) pktConn, err := ob.config.ConnFactory.New(ob.config.ServerAddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Convert config to TLS config & QUIC config // Convert config to TLS config & QUIC config
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
ServerName: c.config.TLSConfig.ServerName, ServerName: ob.config.TLSConfig.ServerName,
InsecureSkipVerify: c.config.TLSConfig.InsecureSkipVerify, InsecureSkipVerify: ob.config.TLSConfig.InsecureSkipVerify,
VerifyPeerCertificate: c.config.TLSConfig.VerifyPeerCertificate, VerifyPeerCertificate: ob.config.TLSConfig.VerifyPeerCertificate,
RootCAs: c.config.TLSConfig.RootCAs, RootCAs: ob.config.TLSConfig.RootCAs,
} }
quicConfig := &quic.Config{ quicConfig := &quic.Config{
InitialStreamReceiveWindow: c.config.QUICConfig.InitialStreamReceiveWindow, InitialStreamReceiveWindow: ob.config.QUICConfig.InitialStreamReceiveWindow,
MaxStreamReceiveWindow: c.config.QUICConfig.MaxStreamReceiveWindow, MaxStreamReceiveWindow: ob.config.QUICConfig.MaxStreamReceiveWindow,
InitialConnectionReceiveWindow: c.config.QUICConfig.InitialConnectionReceiveWindow, InitialConnectionReceiveWindow: ob.config.QUICConfig.InitialConnectionReceiveWindow,
MaxConnectionReceiveWindow: c.config.QUICConfig.MaxConnectionReceiveWindow, MaxConnectionReceiveWindow: ob.config.QUICConfig.MaxConnectionReceiveWindow,
MaxIdleTimeout: c.config.QUICConfig.MaxIdleTimeout, MaxIdleTimeout: ob.config.QUICConfig.MaxIdleTimeout,
KeepAlivePeriod: c.config.QUICConfig.KeepAlivePeriod, KeepAlivePeriod: ob.config.QUICConfig.KeepAlivePeriod,
DisablePathMTUDiscovery: c.config.QUICConfig.DisablePathMTUDiscovery, DisablePathMTUDiscovery: ob.config.QUICConfig.DisablePathMTUDiscovery,
EnableDatagrams: true, EnableDatagrams: true,
} }
// Prepare RoundTripper // Prepare RoundTripper
@ -91,7 +91,7 @@ func (c *clientImpl) connect() (*HandshakeInfo, error) {
TLSClientConfig: tlsConfig, TLSClientConfig: tlsConfig,
QuicConfig: quicConfig, QuicConfig: quicConfig,
Dial: func(ctx context.Context, _ string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { Dial: func(ctx context.Context, _ string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
qc, err := quic.DialEarly(ctx, pktConn, c.config.ServerAddr, tlsCfg, cfg) qc, err := quic.DialEarly(ctx, pktConn, ob.config.ServerAddr, tlsCfg, cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -110,8 +110,8 @@ func (c *clientImpl) connect() (*HandshakeInfo, error) {
Header: make(http.Header), Header: make(http.Header),
} }
protocol.AuthRequestToHeader(req.Header, protocol.AuthRequest{ protocol.AuthRequestToHeader(req.Header, protocol.AuthRequest{
Auth: c.config.Auth, Auth: ob.config.Auth,
Rx: c.config.BandwidthConfig.MaxRx, Rx: ob.config.BandwidthConfig.MaxRx,
}) })
resp, err := rt.RoundTrip(req) resp, err := rt.RoundTrip(req)
if err != nil { if err != nil {
@ -136,9 +136,9 @@ func (c *clientImpl) connect() (*HandshakeInfo, error) {
} else { } else {
// actualTx = min(serverRx, clientTx) // actualTx = min(serverRx, clientTx)
actualTx = authResp.Rx actualTx = authResp.Rx
if actualTx == 0 || actualTx > c.config.BandwidthConfig.MaxTx { if actualTx == 0 || actualTx > ob.config.BandwidthConfig.MaxTx {
// Server doesn't have a limit, or our clientTx is smaller than serverRx // Server doesn't have a limit, or our clientTx is smaller than serverRx
actualTx = c.config.BandwidthConfig.MaxTx actualTx = ob.config.BandwidthConfig.MaxTx
} }
if actualTx > 0 { if actualTx > 0 {
congestion.UseBrutal(conn, actualTx) congestion.UseBrutal(conn, actualTx)
@ -149,10 +149,37 @@ func (c *clientImpl) connect() (*HandshakeInfo, error) {
} }
_ = resp.Body.Close() _ = resp.Body.Close()
c.pktConn = pktConn if ob.config.Outbound == nil {
c.conn = conn var uOb outbounds.PluggableOutbound // "unified" outbound
for n, entry := range ob.config.Outbounds {
if entry.Outbound == nil {
ob.config.Outbounds[n].Outbound = ob
}
}
// we use the first entry of the outbound by default
uOb = ob.config.Outbounds[0].Outbound
// ACL
if ob.config.ACLs != "" {
acl, err := outbounds.NewACLEngineFromString(ob.config.ACLs, ob.config.Outbounds, ob.config.GeoLoader)
if err == nil {
uOb = acl
} else {
panic(err)
}
}
fmt.Println(ob.config.ACLs)
ob.config.Outbound = &PluggableClientOutboundAdapter{PluggableOutbound: uOb}
}
ob.pktConn = pktConn
ob.conn = conn
if authResp.UDPEnabled { if authResp.UDPEnabled {
c.udpSM = newUDPSessionManager(&udpIOImpl{Conn: conn}) ob.udpSM = newUDPSessionManager(&udpIOImpl{Conn: conn})
} }
return &HandshakeInfo{ return &HandshakeInfo{
UDPEnabled: authResp.UDPEnabled, UDPEnabled: authResp.UDPEnabled,
@ -161,33 +188,33 @@ func (c *clientImpl) connect() (*HandshakeInfo, error) {
} }
// openStream wraps the stream with QStream, which handles Close() properly // openStream wraps the stream with QStream, which handles Close() properly
func (c *clientImpl) openStream() (quic.Stream, error) { func (ob *Hy2ClientOutbound) openStream() (quic.Stream, error) {
stream, err := c.conn.OpenStream() stream, err := ob.conn.OpenStream()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &utils.QStream{Stream: stream}, nil return &utils.QStream{Stream: stream}, nil
} }
func (c *clientImpl) TCP(addr string) (net.Conn, error) { func (ob *Hy2ClientOutbound) TCP(reqAddr *outbounds.AddrEx) (net.Conn, error) {
stream, err := c.openStream() stream, err := ob.openStream()
if err != nil { if err != nil {
return nil, wrapIfConnectionClosed(err) return nil, wrapIfConnectionClosed(err)
} }
// Send request // Send request
err = protocol.WriteTCPRequest(stream, addr) err = protocol.WriteTCPRequest(stream, reqAddr.String())
if err != nil { if err != nil {
_ = stream.Close() _ = stream.Close()
return nil, wrapIfConnectionClosed(err) return nil, wrapIfConnectionClosed(err)
} }
if c.config.FastOpen { if ob.config.FastOpen {
// Don't wait for the response when fast open is enabled. // Don't wait for the response when fast open is enabled.
// Return the connection immediately, defer the response handling // Return the connection immediately, defer the response handling
// to the first Read() call. // to the first Read() call.
return &tcpConn{ return &tcpConn{
Orig: stream, Orig: stream,
PseudoLocalAddr: c.conn.LocalAddr(), PseudoLocalAddr: ob.conn.LocalAddr(),
PseudoRemoteAddr: c.conn.RemoteAddr(), PseudoRemoteAddr: ob.conn.RemoteAddr(),
Established: false, Established: false,
}, nil }, nil
} }
@ -203,25 +230,54 @@ func (c *clientImpl) TCP(addr string) (net.Conn, error) {
} }
return &tcpConn{ return &tcpConn{
Orig: stream, Orig: stream,
PseudoLocalAddr: c.conn.LocalAddr(), PseudoLocalAddr: ob.conn.LocalAddr(),
PseudoRemoteAddr: c.conn.RemoteAddr(), PseudoRemoteAddr: ob.conn.RemoteAddr(),
Established: true, Established: true,
}, nil }, nil
} }
func (c *clientImpl) UDP() (HyUDPConn, error) { func (ob *Hy2ClientOutbound) UDP(reqAddr *outbounds.AddrEx) (outbounds.UDPConn, error) {
if c.udpSM == nil { if ob.udpSM == nil {
return nil, coreErrs.DialError{Message: "UDP not enabled"} return nil, coreErrs.DialError{Message: "UDP not enabled"}
} }
return c.udpSM.NewUDP() return ob.udpSM.NewUDP()
} }
func (c *clientImpl) Close() error { func (ob *Hy2ClientOutbound) Close() error {
_ = c.conn.CloseWithError(closeErrCodeOK, "") _ = ob.conn.CloseWithError(closeErrCodeOK, "")
_ = c.pktConn.Close() _ = ob.pktConn.Close()
return nil return nil
} }
type ClientImpl struct {
ob Hy2ClientOutbound
}
// Outbound implements Client.
func (c *ClientImpl) Outbound() *Hy2ClientOutbound {
return &c.ob
}
func (c *ClientImpl) connect() (*HandshakeInfo, error) {
return c.ob.connect()
}
func (c *ClientImpl) Config() *Config {
return c.ob.config
}
func (c *ClientImpl) TCP(addr string) (net.Conn, error) {
return c.ob.config.Outbound.TCP(addr)
}
func (c *ClientImpl) UDP() (UDPConn, error) {
return c.ob.config.Outbound.UDP("localhost:0")
}
func (c *ClientImpl) Close() error {
return c.ob.Close()
}
// wrapIfConnectionClosed checks if the error returned by quic-go // wrapIfConnectionClosed checks if the error returned by quic-go
// indicates that the QUIC connection has been permanently closed, // indicates that the QUIC connection has been permanently closed,
// and if so, wraps the error with coreErrs.ClosedError. // and if so, wraps the error with coreErrs.ClosedError.

View file

@ -3,19 +3,77 @@ package client
import ( import (
"crypto/x509" "crypto/x509"
"net" "net"
"strconv"
"time" "time"
"github.com/apernet/hysteria/core/errors" "github.com/apernet/hysteria/core/errors"
"github.com/apernet/hysteria/core/internal/pmtud" "github.com/apernet/hysteria/core/internal/pmtud"
"github.com/apernet/hysteria/extras/outbounds"
"github.com/apernet/hysteria/extras/outbounds/acl"
) )
const ( const (
udpBufferSize = 4096
defaultStreamReceiveWindow = 8388608 // 8MB defaultStreamReceiveWindow = 8388608 // 8MB
defaultConnReceiveWindow = defaultStreamReceiveWindow * 5 / 2 // 20MB defaultConnReceiveWindow = defaultStreamReceiveWindow * 5 / 2 // 20MB
defaultMaxIdleTimeout = 30 * time.Second defaultMaxIdleTimeout = 30 * time.Second
defaultKeepAlivePeriod = 10 * time.Second defaultKeepAlivePeriod = 10 * time.Second
) )
// Outbound provides the implementation of how the server should connect to remote servers.
// Although UDP includes a reqAddr, the implementation does not necessarily have to use it
// to make a "connected" UDP connection that does not accept packets from other addresses.
// In fact, the default implementation simply uses net.ListenUDP for a "full-cone" behavior.
type Outbound interface {
TCP(reqAddr string) (net.Conn, error)
UDP(reqAddr string) (UDPConn, error)
}
// UDPConn is like net.PacketConn, but uses string for addresses.
type UDPConn interface {
ReadFrom(b []byte) (int, string, error)
WriteTo(b []byte, addr string) (int, error)
Close() error
}
type PluggableClientOutboundAdapter struct {
outbounds.PluggableOutbound
}
func (a *PluggableClientOutboundAdapter) TCP(reqAddr string) (net.Conn, error) {
host, port, err := net.SplitHostPort(reqAddr)
if err != nil {
return nil, err
}
portInt, err := strconv.Atoi(port)
if err != nil {
return nil, err
}
return a.PluggableOutbound.TCP(&outbounds.AddrEx{
Host: host,
Port: uint16(portInt),
})
}
func (a *PluggableClientOutboundAdapter) UDP(reqAddr string) (UDPConn, error) {
host, port, err := net.SplitHostPort(reqAddr)
if err != nil {
return nil, err
}
portInt, err := strconv.Atoi(port)
if err != nil {
return nil, err
}
conn, err := a.PluggableOutbound.UDP(&outbounds.AddrEx{
Host: host,
Port: uint16(portInt),
})
if err != nil {
return nil, err
}
return &outbounds.UdpConnAdapter{UDPConn: conn}, nil
}
type Config struct { type Config struct {
ConnFactory ConnFactory ConnFactory ConnFactory
ServerAddr net.Addr ServerAddr net.Addr
@ -24,6 +82,10 @@ type Config struct {
QUICConfig QUICConfig QUICConfig QUICConfig
BandwidthConfig BandwidthConfig BandwidthConfig BandwidthConfig
FastOpen bool FastOpen bool
Outbound Outbound
Outbounds []outbounds.OutboundEntry
GeoLoader acl.GeoLoader
ACLs string
filled bool // whether the fields have been verified and filled filled bool // whether the fields have been verified and filled
} }

View file

@ -18,6 +18,14 @@ type reconnectableClientImpl struct {
closed bool // permanent close closed bool // permanent close
} }
func (rc *reconnectableClientImpl) Outbound() *Hy2ClientOutbound {
return rc.client.Outbound()
}
func (rc *reconnectableClientImpl) Config() *Config {
return rc.client.Config()
}
// NewReconnectableClient creates a reconnectable client. // NewReconnectableClient creates a reconnectable client.
// If lazy is true, the client will not connect until the first call to TCP() or UDP(). // If lazy is true, the client will not connect until the first call to TCP() or UDP().
// We use a function for config mainly to delay config evaluation // We use a function for config mainly to delay config evaluation
@ -99,13 +107,13 @@ func (rc *reconnectableClientImpl) TCP(addr string) (net.Conn, error) {
} }
} }
func (rc *reconnectableClientImpl) UDP() (HyUDPConn, error) { func (rc *reconnectableClientImpl) UDP() (UDPConn, error) {
if c, err := rc.clientDo(func(client Client) (interface{}, error) { if c, err := rc.clientDo(func(client Client) (interface{}, error) {
return client.UDP() return client.UDP()
}); err != nil { }); err != nil {
return nil, err return nil, err
} else { } else {
return c.(HyUDPConn), nil return c.(UDPConn), nil
} }
} }

View file

@ -4,6 +4,8 @@ import (
"errors" "errors"
"io" "io"
"math/rand" "math/rand"
"net"
"strconv"
"sync" "sync"
"github.com/apernet/quic-go" "github.com/apernet/quic-go"
@ -11,6 +13,7 @@ import (
coreErrs "github.com/apernet/hysteria/core/errors" coreErrs "github.com/apernet/hysteria/core/errors"
"github.com/apernet/hysteria/core/internal/frag" "github.com/apernet/hysteria/core/internal/frag"
"github.com/apernet/hysteria/core/internal/protocol" "github.com/apernet/hysteria/core/internal/protocol"
"github.com/apernet/hysteria/extras/outbounds"
) )
const ( const (
@ -32,6 +35,38 @@ type udpConn struct {
Closed bool Closed bool
} }
func (u *udpConn) ReadFrom(b []byte) (n int, src *outbounds.AddrEx, err error) {
dfData, addr, err := u.Receive()
if err != nil {
return 0, nil, err
}
host, port, err := net.SplitHostPort(addr)
if err != nil {
return 0, nil, err
}
portInt, err := strconv.Atoi(port)
if err != nil {
return 0, nil, err
}
src = &outbounds.AddrEx{
Host: host,
Port: uint16(portInt),
}
if err != nil {
return 0, src, err
}
n = copy(b, dfData)
return n, src, nil
}
func (u *udpConn) WriteTo(b []byte, dst *outbounds.AddrEx) (int, error) {
err := u.Send(b, dst.String())
if err != nil {
return 0, err
}
return len(b), nil
}
func (u *udpConn) Receive() ([]byte, string, error) { func (u *udpConn) Receive() ([]byte, string, error) {
for { for {
msg := <-u.ReceiveCh msg := <-u.ReceiveCh
@ -142,7 +177,7 @@ func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
} }
// NewUDP creates a new UDP session. // NewUDP creates a new UDP session.
func (m *udpSessionManager) NewUDP() (HyUDPConn, error) { func (m *udpSessionManager) NewUDP() (outbounds.UDPConn, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()

View file

@ -3,6 +3,8 @@ package client
import ( import (
"errors" "errors"
io2 "io" io2 "io"
"net"
"strconv"
"testing" "testing"
"time" "time"
@ -12,6 +14,7 @@ import (
coreErrs "github.com/apernet/hysteria/core/errors" coreErrs "github.com/apernet/hysteria/core/errors"
"github.com/apernet/hysteria/core/internal/protocol" "github.com/apernet/hysteria/core/internal/protocol"
"github.com/apernet/hysteria/extras/outbounds"
) )
func TestUDPSessionManager(t *testing.T) { func TestUDPSessionManager(t *testing.T) {
@ -40,8 +43,16 @@ func TestUDPSessionManager(t *testing.T) {
Addr: "random.site.com:9000", Addr: "random.site.com:9000",
Data: []byte("hello friend"), Data: []byte("hello friend"),
} }
host, port, err := net.SplitHostPort(msg1.Addr)
assert.NoError(t, err)
portInt, err := strconv.Atoi(port)
assert.NoError(t, err)
addr := &outbounds.AddrEx{
Host: host,
Port: uint16(portInt),
}
io.EXPECT().SendMessage(mock.Anything, msg1).Return(nil).Once() io.EXPECT().SendMessage(mock.Anything, msg1).Return(nil).Once()
err = udpConn1.Send(msg1.Data, msg1.Addr) _, err = udpConn1.WriteTo(msg1.Data, addr)
assert.NoError(t, err) assert.NoError(t, err)
msg2 := &protocol.UDPMessage{ msg2 := &protocol.UDPMessage{
@ -52,8 +63,16 @@ func TestUDPSessionManager(t *testing.T) {
Addr: "another.site.org:8000", Addr: "another.site.org:8000",
Data: []byte("mr robot"), Data: []byte("mr robot"),
} }
host, port, err = net.SplitHostPort(msg2.Addr)
assert.NoError(t, err)
portInt, err = strconv.Atoi(port)
assert.NoError(t, err)
addr = &outbounds.AddrEx{
Host: host,
Port: uint16(portInt),
}
io.EXPECT().SendMessage(mock.Anything, msg2).Return(nil).Once() io.EXPECT().SendMessage(mock.Anything, msg2).Return(nil).Once()
err = udpConn2.Send(msg2.Data, msg2.Addr) _, err = udpConn2.WriteTo(msg2.Data, addr)
assert.NoError(t, err) assert.NoError(t, err)
respMsg1 := &protocol.UDPMessage{ respMsg1 := &protocol.UDPMessage{
@ -65,9 +84,10 @@ func TestUDPSessionManager(t *testing.T) {
Data: []byte("goodbye captain price"), Data: []byte("goodbye captain price"),
} }
receiveCh <- respMsg1 receiveCh <- respMsg1
data, addr, err := udpConn1.Receive() buf := make([]byte, udpBufferSize)
n, addr, err := udpConn1.ReadFrom(buf)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, data, respMsg1.Data) assert.Equal(t, buf[:n], respMsg1.Data)
assert.Equal(t, addr, respMsg1.Addr) assert.Equal(t, addr, respMsg1.Addr)
respMsg2 := &protocol.UDPMessage{ respMsg2 := &protocol.UDPMessage{
@ -79,9 +99,9 @@ func TestUDPSessionManager(t *testing.T) {
Data: []byte("white rose"), Data: []byte("white rose"),
} }
receiveCh <- respMsg2 receiveCh <- respMsg2
data, addr, err = udpConn2.Receive() n, addr, err = udpConn2.ReadFrom(buf)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, data, respMsg2.Data) assert.Equal(t, buf[:n], respMsg2.Data)
assert.Equal(t, addr, respMsg2.Addr) assert.Equal(t, addr, respMsg2.Addr)
respMsg3 := &protocol.UDPMessage{ respMsg3 := &protocol.UDPMessage{
@ -98,7 +118,8 @@ func TestUDPSessionManager(t *testing.T) {
// Test close UDP connection unblocks Receive() // Test close UDP connection unblocks Receive()
errChan := make(chan error, 1) errChan := make(chan error, 1)
go func() { go func() {
_, _, err := udpConn1.Receive() buf := make([]byte, udpBufferSize)
_, _, err := udpConn1.ReadFrom(buf)
errChan <- err errChan <- err
}() }()
assert.NoError(t, udpConn1.Close()) assert.NoError(t, udpConn1.Close())
@ -107,7 +128,8 @@ func TestUDPSessionManager(t *testing.T) {
// Test close IO unblocks Receive() and blocks new UDP creation // Test close IO unblocks Receive() and blocks new UDP creation
errChan = make(chan error, 1) errChan = make(chan error, 1)
go func() { go func() {
_, _, err := udpConn2.Receive() buf := make([]byte, udpBufferSize)
_, _, err := udpConn2.ReadFrom(buf)
errChan <- err errChan <- err
}() }()
close(receiveCh) close(receiveCh)

View file

@ -148,7 +148,7 @@ func TestClientServerUDPIdleTimeout(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
// Client sends 4 packets // Client sends 4 packets
for i := 0; i < 4; i++ { for i := 0; i < 4; i++ {
err = cu.Send([]byte("happy"), addr) _, err = cu.WriteTo([]byte("happy"), addr)
assert.NoError(t, err) assert.NoError(t, err)
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
} }
@ -159,10 +159,11 @@ func TestClientServerUDPIdleTimeout(t *testing.T) {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
} }
}() }()
buf := make([]byte, udpBufferSize)
for i := 0; i < 4; i++ { for i := 0; i < 4; i++ {
bs, rAddr, err := cu.Receive() n, rAddr, err := cu.ReadFrom(buf)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "sad", string(bs)) assert.Equal(t, "sad", string(buf[:n]))
assert.Equal(t, addr, rAddr) assert.Equal(t, addr, rAddr)
} }
// Now we wait for 3 seconds, the server should close the UDP session. // Now we wait for 3 seconds, the server should close the UDP session.

View file

@ -173,12 +173,13 @@ func TestClientServerUDPEcho(t *testing.T) {
defer conn.Close() defer conn.Close()
// Send and receive data // Send and receive data
buf := make([]byte, udpBufferSize)
sData := []byte("hello world") sData := []byte("hello world")
err = conn.Send(sData, echoAddr) _, err = conn.WriteTo(sData, echoAddr)
assert.NoError(t, err) assert.NoError(t, err)
rData, rAddr, err := conn.Receive() n, rAddr, err := conn.ReadFrom(buf)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, sData, rData) assert.Equal(t, sData, buf[:n])
assert.Equal(t, echoAddr, rAddr) assert.Equal(t, echoAddr, rAddr)
} }

View file

@ -63,7 +63,7 @@ func (s *tcpStressor) Run(t *testing.T) {
} }
type udpStressor struct { type udpStressor struct {
ListenFunc func() (client.HyUDPConn, error) ListenFunc func() (client.UDPConn, error)
ServerAddr string ServerAddr string
Size int Size int
Count int Count int
@ -100,19 +100,20 @@ func (s *udpStressor) Run(t *testing.T) {
// Sending routine // Sending routine
for i := 0; i < s.Count; i++ { for i := 0; i < s.Count; i++ {
_ = limiter.WaitN(context.Background(), len(sData)) _ = limiter.WaitN(context.Background(), len(sData))
_ = conn.Send(sData, s.ServerAddr) _, _ = conn.WriteTo(sData, s.ServerAddr)
} }
}() }()
minCount := s.Count * 8 / 10 // Tolerate 20% packet loss minCount := s.Count * 8 / 10 // Tolerate 20% packet loss
buf := make([]byte, udpBufferSize)
for i := 0; i < minCount; i++ { for i := 0; i < minCount; i++ {
rData, _, err := conn.Receive() n, _, err := conn.ReadFrom(buf)
if err != nil { if err != nil {
errChan <- err errChan <- err
return return
} }
if len(rData) != len(sData) { if len(buf[:n]) != len(sData) {
errChan <- fmt.Errorf("incomplete data received: %d/%d bytes", len(rData), len(sData)) errChan <- fmt.Errorf("incomplete data received: %d/%d bytes", len(buf[:n]), len(sData))
return return
} }
} }

View file

@ -145,25 +145,26 @@ func TestClientServerTrafficLoggerUDP(t *testing.T) {
// Client writes to server // Client writes to server
trafficLogger.EXPECT().Log("nobody", uint64(9), uint64(0)).Return(true).Once() trafficLogger.EXPECT().Log("nobody", uint64(9), uint64(0)).Return(true).Once()
sobConn.EXPECT().WriteTo([]byte("small sad"), addr).Return(9, nil).Once() sobConn.EXPECT().WriteTo([]byte("small sad"), addr).Return(9, nil).Once()
err = conn.Send([]byte("small sad"), addr) _, err = conn.WriteTo([]byte("small sad"), addr)
assert.NoError(t, err) assert.NoError(t, err)
time.Sleep(1 * time.Second) // Need some time for the server to receive the data time.Sleep(1 * time.Second) // Need some time for the server to receive the data
buf := make([]byte, udpBufferSize)
// Client reads from server // Client reads from server
trafficLogger.EXPECT().Log("nobody", uint64(0), uint64(7)).Return(true).Once() trafficLogger.EXPECT().Log("nobody", uint64(0), uint64(7)).Return(true).Once()
sobConnCh <- []byte("big mad") sobConnCh <- []byte("big mad")
bs, rAddr, err := conn.Receive() n, rAddr, err := conn.ReadFrom(buf)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, rAddr, addr) assert.Equal(t, rAddr, addr)
assert.Equal(t, "big mad", string(bs)) assert.Equal(t, "big mad", string(buf[:n]))
// Client reads from server again but blocked // Client reads from server again but blocked
trafficLogger.EXPECT().Log("nobody", uint64(0), uint64(4)).Return(false).Once() trafficLogger.EXPECT().Log("nobody", uint64(0), uint64(4)).Return(false).Once()
sobConnCh <- []byte("nope") sobConnCh <- []byte("nope")
bs, rAddr, err = conn.Receive() n, rAddr, err = conn.ReadFrom(buf)
assert.Equal(t, err, io.EOF) assert.Equal(t, err, io.EOF)
assert.Empty(t, rAddr) assert.Empty(t, rAddr)
assert.Empty(t, bs) assert.Empty(t, buf[:n])
// The client should be disconnected // The client should be disconnected
_, err = c.UDP() _, err = c.UDP()

View file

@ -11,6 +11,7 @@ import (
// This file provides utilities for the integration tests. // This file provides utilities for the integration tests.
const ( const (
udpBufferSize = 4096
testCertFile = "test.crt" testCertFile = "test.crt"
testKeyFile = "test.key" testKeyFile = "test.key"
) )

View file

@ -97,14 +97,14 @@ func (a *PluggableOutboundAdapter) UDP(reqAddr string) (server.UDPConn, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &udpConnAdapter{conn}, nil return &UdpConnAdapter{conn}, nil
} }
type udpConnAdapter struct { type UdpConnAdapter struct {
UDPConn UDPConn
} }
func (u *udpConnAdapter) ReadFrom(b []byte) (int, string, error) { func (u *UdpConnAdapter) ReadFrom(b []byte) (int, string, error) {
n, addr, err := u.UDPConn.ReadFrom(b) n, addr, err := u.UDPConn.ReadFrom(b)
if addr != nil { if addr != nil {
return n, addr.String(), err return n, addr.String(), err
@ -113,7 +113,7 @@ func (u *udpConnAdapter) ReadFrom(b []byte) (int, string, error) {
} }
} }
func (u *udpConnAdapter) WriteTo(b []byte, addr string) (int, error) { func (u *UdpConnAdapter) WriteTo(b []byte, addr string) (int, error) {
host, port, err := net.SplitHostPort(addr) host, port, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
return 0, err return 0, err
@ -128,6 +128,6 @@ func (u *udpConnAdapter) WriteTo(b []byte, addr string) (int, error) {
}) })
} }
func (u *udpConnAdapter) Close() error { func (u *UdpConnAdapter) Close() error {
return u.UDPConn.Close() return u.UDPConn.Close()
} }