feat: client side ACL

This commit is contained in:
WoaShieShei 2024-01-31 11:12:03 +00:00
parent 6f88b9dcf7
commit 3a15aa522b
18 changed files with 389 additions and 133 deletions

View file

@ -58,7 +58,6 @@ type clientConfig struct {
QUIC clientConfigQUIC `mapstructure:"quic"`
Bandwidth clientConfigBandwidth `mapstructure:"bandwidth"`
ACL clientConfigACL `mapstructure:"acl"`
Resolver clientConfigResolver `mapstructure:"resolver"`
FastOpen bool `mapstructure:"fastOpen"`
Lazy bool `mapstructure:"lazy"`
SOCKS5 *socks5Config `mapstructure:"socks5"`
@ -145,38 +144,6 @@ type clientConfigOutboundHTTP struct {
Insecure bool `mapstructure:"insecure"`
}
type clientConfigResolverTCP struct {
Addr string `mapstructure:"addr"`
Timeout time.Duration `mapstructure:"timeout"`
}
type clientConfigResolverUDP struct {
Addr string `mapstructure:"addr"`
Timeout time.Duration `mapstructure:"timeout"`
}
type clientConfigResolverTLS struct {
Addr string `mapstructure:"addr"`
Timeout time.Duration `mapstructure:"timeout"`
SNI string `mapstructure:"sni"`
Insecure bool `mapstructure:"insecure"`
}
type clientConfigResolverHTTPS struct {
Addr string `mapstructure:"addr"`
Timeout time.Duration `mapstructure:"timeout"`
SNI string `mapstructure:"sni"`
Insecure bool `mapstructure:"insecure"`
}
type clientConfigResolver struct {
Type string `mapstructure:"type"`
TCP clientConfigResolverTCP `mapstructure:"tcp"`
UDP clientConfigResolverUDP `mapstructure:"udp"`
TLS clientConfigResolverTLS `mapstructure:"tls"`
HTTPS clientConfigResolverHTTPS `mapstructure:"https"`
}
type socks5Config struct {
Listen string `mapstructure:"listen"`
Username string `mapstructure:"username"`
@ -354,6 +321,76 @@ func (c *clientConfig) fillFastOpen(hyConfig *client.Config) error {
return nil
}
func (c *clientConfig) fillOutbounds(hyConfig *client.Config) error {
var obs []outbounds.OutboundEntry
if len(c.Outbounds) == 0 {
// The items for which 'outbound' is set to 'nil'
// will have their traffic proxied by remote hysteria server
obs = []outbounds.OutboundEntry{
{
Name: "direct",
Outbound: outbounds.NewDirectOutboundSimple(outbounds.DirectOutboundModeAuto),
},
{
Name: "default",
Outbound: nil,
},
}
} else {
obs = make([]outbounds.OutboundEntry, len(c.Outbounds))
for i, entry := range c.Outbounds {
if entry.Name == "" {
return configError{Field: "outbounds.name", Err: errors.New("empty outbound name")}
}
var ob outbounds.PluggableOutbound
var err error
switch strings.ToLower(entry.Type) {
case "default":
case "hysteria":
ob, err = nil, nil
case "direct":
ob, err = clientConfigOutboundDirectToOutbound(entry.Direct)
case "socks5":
ob, err = clientConfigOutboundSOCKS5ToOutbound(entry.SOCKS5)
case "http":
ob, err = clientConfigOutboundHTTPToOutbound(entry.HTTP)
default:
err = configError{Field: "outbounds.type", Err: errors.New("unsupported outbound type")}
}
if err != nil {
return err
}
obs[i] = outbounds.OutboundEntry{Name: entry.Name, Outbound: ob}
}
}
hyConfig.Outbounds = obs
return nil
}
func (c *clientConfig) fillACLs(hyConfig *client.Config) error {
// ACL
if c.ACL.File != "" && len(c.ACL.Inline) > 0 {
return configError{Field: "acl", Err: errors.New("cannot set both acl.file and acl.inline")}
}
hyConfig.GeoLoader = &utils.GeoLoader{
GeoIPFilename: c.ACL.GeoIP,
GeoSiteFilename: c.ACL.GeoSite,
UpdateInterval: c.ACL.GeoUpdateInterval,
DownloadFunc: geoDownloadFunc,
DownloadErrFunc: geoDownloadErrFunc,
}
if c.ACL.File != "" {
bs, err := os.ReadFile(c.ACL.File)
if err != nil {
return configError{Field: "acl", Err: errors.New("cannot load acl file")}
}
hyConfig.ACLs = string(bs)
} else if len(c.ACL.Inline) > 0 {
hyConfig.ACLs = strings.Join(c.ACL.Inline, "\n")
}
return nil
}
// URI generates a URI for sharing the config with others.
// Note that only the bare minimum of information required to
// connect to the server is included in the URI, specifically:
@ -500,6 +537,8 @@ func (c *clientConfig) Config() (*client.Config, error) {
c.fillQUICConfig,
c.fillBandwidthConfig,
c.fillFastOpen,
c.fillOutbounds,
c.fillACLs,
}
for _, f := range fillers {
if err := f(hyConfig); err != nil {

View file

@ -35,23 +35,25 @@ func (t *atomicTime) Get() time.Time {
}
type sessionEntry struct {
HyConn client.HyUDPConn
HyConn client.UDPConn
Last *atomicTime
Timeout bool // true if the session is closed due to timeout
}
func (e *sessionEntry) Feed(data []byte, addr string) error {
e.Last.Set(time.Now())
return e.HyConn.Send(data, addr)
_, err := e.HyConn.WriteTo(data, addr)
return err
}
func (e *sessionEntry) ReceiveLoop(pc net.PacketConn, addr net.Addr) error {
buf := make([]byte, udpBufferSize)
for {
data, _, err := e.HyConn.Receive()
n, _, err := e.HyConn.ReadFrom(buf)
if err != nil {
return err
}
_, err = pc.WriteTo(data, addr)
_, err = pc.WriteTo(buf[:n], addr)
if err != nil {
return err
}

View file

@ -24,7 +24,7 @@ func (c *mockHyClient) TCP(addr string) (net.Conn, error) {
return net.Dial("tcp", addr)
}
func (c *mockHyClient) UDP() (client.HyUDPConn, error) {
func (c *mockHyClient) UDP() (client.UDPConn, error) {
return nil, errors.New("not implemented")
}
@ -32,6 +32,14 @@ func (c *mockHyClient) Close() error {
return nil
}
func (rc *mockHyClient) Outbound() *client.Hy2ClientOutbound {
return nil
}
func (c *mockHyClient) Config() *client.Config {
return nil
}
func TestServer(t *testing.T) {
// Start the server
l, err := net.Listen("tcp", "127.0.0.1:18080")

View file

@ -220,7 +220,7 @@ func (s *Server) handleUDP(conn net.Conn, req *socks5.Request) {
closeErr = <-errChan
}
func (s *Server) udpServer(udpConn *net.UDPConn, hyUDP client.HyUDPConn) error {
func (s *Server) udpServer(udpConn *net.UDPConn, hyUDP client.UDPConn) error {
var clientAddr *net.UDPAddr
buf := make([]byte, udpBufferSize)
// local -> remote
@ -243,8 +243,9 @@ func (s *Server) udpServer(udpConn *net.UDPConn, hyUDP client.HyUDPConn) error {
// Now that we know the client's address, we can start the
// remote -> local direction.
go func() {
buf := make([]byte, udpBufferSize)
for {
bs, from, err := hyUDP.Receive()
n, from, err := hyUDP.ReadFrom(buf)
if err != nil {
// Close the UDP conn so that the local -> remote direction will exit
_ = udpConn.Close()
@ -260,7 +261,7 @@ func (s *Server) udpServer(udpConn *net.UDPConn, hyUDP client.HyUDPConn) error {
// So we must remove it here.
addr = addr[1:]
}
d := socks5.NewDatagram(atyp, addr, port, bs)
d := socks5.NewDatagram(atyp, addr, port, buf[:n])
_, _ = udpConn.WriteToUDP(d.Bytes(), clientAddr)
}
}()
@ -269,7 +270,7 @@ func (s *Server) udpServer(udpConn *net.UDPConn, hyUDP client.HyUDPConn) error {
continue
}
// Send to remote
_ = hyUDP.Send(d.Data, d.Address())
_, _ = hyUDP.WriteTo(d.Data, d.Address())
}
}

View file

@ -68,7 +68,7 @@ func (r *UDPTProxy) newPair(srcAddr, dstAddr *net.UDPAddr, initPkt []byte) {
return
}
// Send the first packet
err = hyConn.Send(initPkt, dstAddr.String())
_, err = hyConn.WriteTo(initPkt, dstAddr.String())
if err != nil {
_ = conn.Close()
_ = hyConn.Close()
@ -91,17 +91,18 @@ func (r *UDPTProxy) newPair(srcAddr, dstAddr *net.UDPAddr, initPkt []byte) {
}()
}
func (r *UDPTProxy) forwarding(conn *net.UDPConn, hyConn client.HyUDPConn, dst string) error {
func (r *UDPTProxy) forwarding(conn *net.UDPConn, hyConn client.UDPConn, dst string) error {
errChan := make(chan error, 2)
// Local <- Remote
go func() {
buf := make([]byte, udpBufferSize)
for {
bs, _, err := hyConn.Receive()
n, _, err := hyConn.ReadFrom(buf)
if err != nil {
errChan <- err
return
}
_, err = conn.Write(bs)
_, err = conn.Write(buf[:n])
if err != nil {
errChan <- err
return
@ -116,7 +117,7 @@ func (r *UDPTProxy) forwarding(conn *net.UDPConn, hyConn client.HyUDPConn, dst s
_ = r.updateConnDeadline(conn)
n, err := conn.Read(buf)
if n > 0 {
err := hyConn.Send(buf[:n], dst)
_, err := hyConn.WriteTo(buf[:n], dst)
if err != nil {
errChan <- err
return

View file

@ -6,9 +6,11 @@ import (
"fmt"
"net"
"net/http"
"strconv"
"time"
"github.com/apernet/hysteria/core/client"
"github.com/apernet/hysteria/extras/outbounds"
)
const (
@ -55,7 +57,13 @@ func NewClientUpdateChecker(currentVersion, platform, architecture, channel stri
Transport: &http.Transport{
DialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
// Unfortunately HyClient doesn't support context for now
return hyClient.TCP(addr)
host, port, _ := net.SplitHostPort(addr)
portInt, _ := strconv.Atoi(port)
return hyClient.Outbound().TCP(&outbounds.AddrEx{
Host: host,
Port: uint16(portInt),
})
},
},
},

View file

@ -16,7 +16,7 @@ func (c *MockEchoHyClient) TCP(addr string) (net.Conn, error) {
}, nil
}
func (c *MockEchoHyClient) UDP() (client.HyUDPConn, error) {
func (c *MockEchoHyClient) UDP() (client.UDPConn, error) {
return &mockEchoUDPConn{
BufChan: make(chan mockEchoUDPPacket, 10),
}, nil
@ -26,6 +26,14 @@ func (c *MockEchoHyClient) Close() error {
return nil
}
func (rc *MockEchoHyClient) Outbound() *client.Hy2ClientOutbound {
return nil
}
func (c *MockEchoHyClient) Config() *client.Config {
return nil
}
type mockEchoTCPConn struct {
BufChan chan []byte
}
@ -83,21 +91,22 @@ type mockEchoUDPConn struct {
BufChan chan mockEchoUDPPacket
}
func (c *mockEchoUDPConn) Receive() ([]byte, string, error) {
func (c *mockEchoUDPConn) ReadFrom(d []byte) (int, string, error) {
p := <-c.BufChan
if p.Data == nil {
// EOF
return nil, "", io.EOF
return 0, "", io.EOF
}
return p.Data, p.Addr, nil
copy(d, p.Data)
return len(d), p.Addr, nil
}
func (c *mockEchoUDPConn) Send(bytes []byte, s string) error {
func (c *mockEchoUDPConn) WriteTo(bytes []byte, s string) (int, error) {
c.BufChan <- mockEchoUDPPacket{
Data: bytes,
Addr: s,
}
return nil
return len(bytes), nil
}
func (c *mockEchoUDPConn) Close() error {

View file

@ -3,6 +3,7 @@ package client
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
@ -12,6 +13,7 @@ import (
"github.com/apernet/hysteria/core/internal/congestion"
"github.com/apernet/hysteria/core/internal/protocol"
"github.com/apernet/hysteria/core/internal/utils"
"github.com/apernet/hysteria/extras/outbounds"
"github.com/apernet/quic-go"
"github.com/apernet/quic-go/http3"
@ -24,13 +26,9 @@ const (
type Client interface {
TCP(addr string) (net.Conn, error)
UDP() (HyUDPConn, error)
Close() error
}
type HyUDPConn interface {
Receive() ([]byte, string, error)
Send([]byte, string) error
UDP() (UDPConn, error)
Config() *Config
Outbound() *Hy2ClientOutbound
Close() error
}
@ -43,8 +41,10 @@ func NewClient(config *Config) (Client, *HandshakeInfo, error) {
if err := config.verifyAndFill(); err != nil {
return nil, nil, err
}
c := &clientImpl{
config: config,
c := &ClientImpl{
Hy2ClientOutbound{
config: config,
},
}
info, err := c.connect()
if err != nil {
@ -53,7 +53,7 @@ func NewClient(config *Config) (Client, *HandshakeInfo, error) {
return c, info, nil
}
type clientImpl struct {
type Hy2ClientOutbound struct {
config *Config
pktConn net.PacketConn
@ -62,26 +62,26 @@ type clientImpl struct {
udpSM *udpSessionManager
}
func (c *clientImpl) connect() (*HandshakeInfo, error) {
pktConn, err := c.config.ConnFactory.New(c.config.ServerAddr)
func (ob *Hy2ClientOutbound) connect() (*HandshakeInfo, error) {
pktConn, err := ob.config.ConnFactory.New(ob.config.ServerAddr)
if err != nil {
return nil, err
}
// Convert config to TLS config & QUIC config
tlsConfig := &tls.Config{
ServerName: c.config.TLSConfig.ServerName,
InsecureSkipVerify: c.config.TLSConfig.InsecureSkipVerify,
VerifyPeerCertificate: c.config.TLSConfig.VerifyPeerCertificate,
RootCAs: c.config.TLSConfig.RootCAs,
ServerName: ob.config.TLSConfig.ServerName,
InsecureSkipVerify: ob.config.TLSConfig.InsecureSkipVerify,
VerifyPeerCertificate: ob.config.TLSConfig.VerifyPeerCertificate,
RootCAs: ob.config.TLSConfig.RootCAs,
}
quicConfig := &quic.Config{
InitialStreamReceiveWindow: c.config.QUICConfig.InitialStreamReceiveWindow,
MaxStreamReceiveWindow: c.config.QUICConfig.MaxStreamReceiveWindow,
InitialConnectionReceiveWindow: c.config.QUICConfig.InitialConnectionReceiveWindow,
MaxConnectionReceiveWindow: c.config.QUICConfig.MaxConnectionReceiveWindow,
MaxIdleTimeout: c.config.QUICConfig.MaxIdleTimeout,
KeepAlivePeriod: c.config.QUICConfig.KeepAlivePeriod,
DisablePathMTUDiscovery: c.config.QUICConfig.DisablePathMTUDiscovery,
InitialStreamReceiveWindow: ob.config.QUICConfig.InitialStreamReceiveWindow,
MaxStreamReceiveWindow: ob.config.QUICConfig.MaxStreamReceiveWindow,
InitialConnectionReceiveWindow: ob.config.QUICConfig.InitialConnectionReceiveWindow,
MaxConnectionReceiveWindow: ob.config.QUICConfig.MaxConnectionReceiveWindow,
MaxIdleTimeout: ob.config.QUICConfig.MaxIdleTimeout,
KeepAlivePeriod: ob.config.QUICConfig.KeepAlivePeriod,
DisablePathMTUDiscovery: ob.config.QUICConfig.DisablePathMTUDiscovery,
EnableDatagrams: true,
}
// Prepare RoundTripper
@ -91,7 +91,7 @@ func (c *clientImpl) connect() (*HandshakeInfo, error) {
TLSClientConfig: tlsConfig,
QuicConfig: quicConfig,
Dial: func(ctx context.Context, _ string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
qc, err := quic.DialEarly(ctx, pktConn, c.config.ServerAddr, tlsCfg, cfg)
qc, err := quic.DialEarly(ctx, pktConn, ob.config.ServerAddr, tlsCfg, cfg)
if err != nil {
return nil, err
}
@ -110,8 +110,8 @@ func (c *clientImpl) connect() (*HandshakeInfo, error) {
Header: make(http.Header),
}
protocol.AuthRequestToHeader(req.Header, protocol.AuthRequest{
Auth: c.config.Auth,
Rx: c.config.BandwidthConfig.MaxRx,
Auth: ob.config.Auth,
Rx: ob.config.BandwidthConfig.MaxRx,
})
resp, err := rt.RoundTrip(req)
if err != nil {
@ -136,9 +136,9 @@ func (c *clientImpl) connect() (*HandshakeInfo, error) {
} else {
// actualTx = min(serverRx, clientTx)
actualTx = authResp.Rx
if actualTx == 0 || actualTx > c.config.BandwidthConfig.MaxTx {
if actualTx == 0 || actualTx > ob.config.BandwidthConfig.MaxTx {
// Server doesn't have a limit, or our clientTx is smaller than serverRx
actualTx = c.config.BandwidthConfig.MaxTx
actualTx = ob.config.BandwidthConfig.MaxTx
}
if actualTx > 0 {
congestion.UseBrutal(conn, actualTx)
@ -149,10 +149,37 @@ func (c *clientImpl) connect() (*HandshakeInfo, error) {
}
_ = resp.Body.Close()
c.pktConn = pktConn
c.conn = conn
if ob.config.Outbound == nil {
var uOb outbounds.PluggableOutbound // "unified" outbound
for n, entry := range ob.config.Outbounds {
if entry.Outbound == nil {
ob.config.Outbounds[n].Outbound = ob
}
}
// we use the first entry of the outbound by default
uOb = ob.config.Outbounds[0].Outbound
// ACL
if ob.config.ACLs != "" {
acl, err := outbounds.NewACLEngineFromString(ob.config.ACLs, ob.config.Outbounds, ob.config.GeoLoader)
if err == nil {
uOb = acl
} else {
panic(err)
}
}
fmt.Println(ob.config.ACLs)
ob.config.Outbound = &PluggableClientOutboundAdapter{PluggableOutbound: uOb}
}
ob.pktConn = pktConn
ob.conn = conn
if authResp.UDPEnabled {
c.udpSM = newUDPSessionManager(&udpIOImpl{Conn: conn})
ob.udpSM = newUDPSessionManager(&udpIOImpl{Conn: conn})
}
return &HandshakeInfo{
UDPEnabled: authResp.UDPEnabled,
@ -161,33 +188,33 @@ func (c *clientImpl) connect() (*HandshakeInfo, error) {
}
// openStream wraps the stream with QStream, which handles Close() properly
func (c *clientImpl) openStream() (quic.Stream, error) {
stream, err := c.conn.OpenStream()
func (ob *Hy2ClientOutbound) openStream() (quic.Stream, error) {
stream, err := ob.conn.OpenStream()
if err != nil {
return nil, err
}
return &utils.QStream{Stream: stream}, nil
}
func (c *clientImpl) TCP(addr string) (net.Conn, error) {
stream, err := c.openStream()
func (ob *Hy2ClientOutbound) TCP(reqAddr *outbounds.AddrEx) (net.Conn, error) {
stream, err := ob.openStream()
if err != nil {
return nil, wrapIfConnectionClosed(err)
}
// Send request
err = protocol.WriteTCPRequest(stream, addr)
err = protocol.WriteTCPRequest(stream, reqAddr.String())
if err != nil {
_ = stream.Close()
return nil, wrapIfConnectionClosed(err)
}
if c.config.FastOpen {
if ob.config.FastOpen {
// Don't wait for the response when fast open is enabled.
// Return the connection immediately, defer the response handling
// to the first Read() call.
return &tcpConn{
Orig: stream,
PseudoLocalAddr: c.conn.LocalAddr(),
PseudoRemoteAddr: c.conn.RemoteAddr(),
PseudoLocalAddr: ob.conn.LocalAddr(),
PseudoRemoteAddr: ob.conn.RemoteAddr(),
Established: false,
}, nil
}
@ -203,25 +230,54 @@ func (c *clientImpl) TCP(addr string) (net.Conn, error) {
}
return &tcpConn{
Orig: stream,
PseudoLocalAddr: c.conn.LocalAddr(),
PseudoRemoteAddr: c.conn.RemoteAddr(),
PseudoLocalAddr: ob.conn.LocalAddr(),
PseudoRemoteAddr: ob.conn.RemoteAddr(),
Established: true,
}, nil
}
func (c *clientImpl) UDP() (HyUDPConn, error) {
if c.udpSM == nil {
func (ob *Hy2ClientOutbound) UDP(reqAddr *outbounds.AddrEx) (outbounds.UDPConn, error) {
if ob.udpSM == nil {
return nil, coreErrs.DialError{Message: "UDP not enabled"}
}
return c.udpSM.NewUDP()
return ob.udpSM.NewUDP()
}
func (c *clientImpl) Close() error {
_ = c.conn.CloseWithError(closeErrCodeOK, "")
_ = c.pktConn.Close()
func (ob *Hy2ClientOutbound) Close() error {
_ = ob.conn.CloseWithError(closeErrCodeOK, "")
_ = ob.pktConn.Close()
return nil
}
type ClientImpl struct {
ob Hy2ClientOutbound
}
// Outbound implements Client.
func (c *ClientImpl) Outbound() *Hy2ClientOutbound {
return &c.ob
}
func (c *ClientImpl) connect() (*HandshakeInfo, error) {
return c.ob.connect()
}
func (c *ClientImpl) Config() *Config {
return c.ob.config
}
func (c *ClientImpl) TCP(addr string) (net.Conn, error) {
return c.ob.config.Outbound.TCP(addr)
}
func (c *ClientImpl) UDP() (UDPConn, error) {
return c.ob.config.Outbound.UDP("localhost:0")
}
func (c *ClientImpl) Close() error {
return c.ob.Close()
}
// wrapIfConnectionClosed checks if the error returned by quic-go
// indicates that the QUIC connection has been permanently closed,
// and if so, wraps the error with coreErrs.ClosedError.

View file

@ -3,19 +3,77 @@ package client
import (
"crypto/x509"
"net"
"strconv"
"time"
"github.com/apernet/hysteria/core/errors"
"github.com/apernet/hysteria/core/internal/pmtud"
"github.com/apernet/hysteria/extras/outbounds"
"github.com/apernet/hysteria/extras/outbounds/acl"
)
const (
udpBufferSize = 4096
defaultStreamReceiveWindow = 8388608 // 8MB
defaultConnReceiveWindow = defaultStreamReceiveWindow * 5 / 2 // 20MB
defaultMaxIdleTimeout = 30 * time.Second
defaultKeepAlivePeriod = 10 * time.Second
)
// Outbound provides the implementation of how the server should connect to remote servers.
// Although UDP includes a reqAddr, the implementation does not necessarily have to use it
// to make a "connected" UDP connection that does not accept packets from other addresses.
// In fact, the default implementation simply uses net.ListenUDP for a "full-cone" behavior.
type Outbound interface {
TCP(reqAddr string) (net.Conn, error)
UDP(reqAddr string) (UDPConn, error)
}
// UDPConn is like net.PacketConn, but uses string for addresses.
type UDPConn interface {
ReadFrom(b []byte) (int, string, error)
WriteTo(b []byte, addr string) (int, error)
Close() error
}
type PluggableClientOutboundAdapter struct {
outbounds.PluggableOutbound
}
func (a *PluggableClientOutboundAdapter) TCP(reqAddr string) (net.Conn, error) {
host, port, err := net.SplitHostPort(reqAddr)
if err != nil {
return nil, err
}
portInt, err := strconv.Atoi(port)
if err != nil {
return nil, err
}
return a.PluggableOutbound.TCP(&outbounds.AddrEx{
Host: host,
Port: uint16(portInt),
})
}
func (a *PluggableClientOutboundAdapter) UDP(reqAddr string) (UDPConn, error) {
host, port, err := net.SplitHostPort(reqAddr)
if err != nil {
return nil, err
}
portInt, err := strconv.Atoi(port)
if err != nil {
return nil, err
}
conn, err := a.PluggableOutbound.UDP(&outbounds.AddrEx{
Host: host,
Port: uint16(portInt),
})
if err != nil {
return nil, err
}
return &outbounds.UdpConnAdapter{UDPConn: conn}, nil
}
type Config struct {
ConnFactory ConnFactory
ServerAddr net.Addr
@ -24,6 +82,10 @@ type Config struct {
QUICConfig QUICConfig
BandwidthConfig BandwidthConfig
FastOpen bool
Outbound Outbound
Outbounds []outbounds.OutboundEntry
GeoLoader acl.GeoLoader
ACLs string
filled bool // whether the fields have been verified and filled
}

View file

@ -18,6 +18,14 @@ type reconnectableClientImpl struct {
closed bool // permanent close
}
func (rc *reconnectableClientImpl) Outbound() *Hy2ClientOutbound {
return rc.client.Outbound()
}
func (rc *reconnectableClientImpl) Config() *Config {
return rc.client.Config()
}
// NewReconnectableClient creates a reconnectable client.
// If lazy is true, the client will not connect until the first call to TCP() or UDP().
// We use a function for config mainly to delay config evaluation
@ -99,13 +107,13 @@ func (rc *reconnectableClientImpl) TCP(addr string) (net.Conn, error) {
}
}
func (rc *reconnectableClientImpl) UDP() (HyUDPConn, error) {
func (rc *reconnectableClientImpl) UDP() (UDPConn, error) {
if c, err := rc.clientDo(func(client Client) (interface{}, error) {
return client.UDP()
}); err != nil {
return nil, err
} else {
return c.(HyUDPConn), nil
return c.(UDPConn), nil
}
}

View file

@ -4,6 +4,8 @@ import (
"errors"
"io"
"math/rand"
"net"
"strconv"
"sync"
"github.com/apernet/quic-go"
@ -11,6 +13,7 @@ import (
coreErrs "github.com/apernet/hysteria/core/errors"
"github.com/apernet/hysteria/core/internal/frag"
"github.com/apernet/hysteria/core/internal/protocol"
"github.com/apernet/hysteria/extras/outbounds"
)
const (
@ -32,6 +35,38 @@ type udpConn struct {
Closed bool
}
func (u *udpConn) ReadFrom(b []byte) (n int, src *outbounds.AddrEx, err error) {
dfData, addr, err := u.Receive()
if err != nil {
return 0, nil, err
}
host, port, err := net.SplitHostPort(addr)
if err != nil {
return 0, nil, err
}
portInt, err := strconv.Atoi(port)
if err != nil {
return 0, nil, err
}
src = &outbounds.AddrEx{
Host: host,
Port: uint16(portInt),
}
if err != nil {
return 0, src, err
}
n = copy(b, dfData)
return n, src, nil
}
func (u *udpConn) WriteTo(b []byte, dst *outbounds.AddrEx) (int, error) {
err := u.Send(b, dst.String())
if err != nil {
return 0, err
}
return len(b), nil
}
func (u *udpConn) Receive() ([]byte, string, error) {
for {
msg := <-u.ReceiveCh
@ -142,7 +177,7 @@ func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
}
// NewUDP creates a new UDP session.
func (m *udpSessionManager) NewUDP() (HyUDPConn, error) {
func (m *udpSessionManager) NewUDP() (outbounds.UDPConn, error) {
m.mutex.Lock()
defer m.mutex.Unlock()

View file

@ -3,6 +3,8 @@ package client
import (
"errors"
io2 "io"
"net"
"strconv"
"testing"
"time"
@ -12,6 +14,7 @@ import (
coreErrs "github.com/apernet/hysteria/core/errors"
"github.com/apernet/hysteria/core/internal/protocol"
"github.com/apernet/hysteria/extras/outbounds"
)
func TestUDPSessionManager(t *testing.T) {
@ -40,8 +43,16 @@ func TestUDPSessionManager(t *testing.T) {
Addr: "random.site.com:9000",
Data: []byte("hello friend"),
}
host, port, err := net.SplitHostPort(msg1.Addr)
assert.NoError(t, err)
portInt, err := strconv.Atoi(port)
assert.NoError(t, err)
addr := &outbounds.AddrEx{
Host: host,
Port: uint16(portInt),
}
io.EXPECT().SendMessage(mock.Anything, msg1).Return(nil).Once()
err = udpConn1.Send(msg1.Data, msg1.Addr)
_, err = udpConn1.WriteTo(msg1.Data, addr)
assert.NoError(t, err)
msg2 := &protocol.UDPMessage{
@ -52,8 +63,16 @@ func TestUDPSessionManager(t *testing.T) {
Addr: "another.site.org:8000",
Data: []byte("mr robot"),
}
host, port, err = net.SplitHostPort(msg2.Addr)
assert.NoError(t, err)
portInt, err = strconv.Atoi(port)
assert.NoError(t, err)
addr = &outbounds.AddrEx{
Host: host,
Port: uint16(portInt),
}
io.EXPECT().SendMessage(mock.Anything, msg2).Return(nil).Once()
err = udpConn2.Send(msg2.Data, msg2.Addr)
_, err = udpConn2.WriteTo(msg2.Data, addr)
assert.NoError(t, err)
respMsg1 := &protocol.UDPMessage{
@ -65,9 +84,10 @@ func TestUDPSessionManager(t *testing.T) {
Data: []byte("goodbye captain price"),
}
receiveCh <- respMsg1
data, addr, err := udpConn1.Receive()
buf := make([]byte, udpBufferSize)
n, addr, err := udpConn1.ReadFrom(buf)
assert.NoError(t, err)
assert.Equal(t, data, respMsg1.Data)
assert.Equal(t, buf[:n], respMsg1.Data)
assert.Equal(t, addr, respMsg1.Addr)
respMsg2 := &protocol.UDPMessage{
@ -79,9 +99,9 @@ func TestUDPSessionManager(t *testing.T) {
Data: []byte("white rose"),
}
receiveCh <- respMsg2
data, addr, err = udpConn2.Receive()
n, addr, err = udpConn2.ReadFrom(buf)
assert.NoError(t, err)
assert.Equal(t, data, respMsg2.Data)
assert.Equal(t, buf[:n], respMsg2.Data)
assert.Equal(t, addr, respMsg2.Addr)
respMsg3 := &protocol.UDPMessage{
@ -98,7 +118,8 @@ func TestUDPSessionManager(t *testing.T) {
// Test close UDP connection unblocks Receive()
errChan := make(chan error, 1)
go func() {
_, _, err := udpConn1.Receive()
buf := make([]byte, udpBufferSize)
_, _, err := udpConn1.ReadFrom(buf)
errChan <- err
}()
assert.NoError(t, udpConn1.Close())
@ -107,7 +128,8 @@ func TestUDPSessionManager(t *testing.T) {
// Test close IO unblocks Receive() and blocks new UDP creation
errChan = make(chan error, 1)
go func() {
_, _, err := udpConn2.Receive()
buf := make([]byte, udpBufferSize)
_, _, err := udpConn2.ReadFrom(buf)
errChan <- err
}()
close(receiveCh)

View file

@ -148,7 +148,7 @@ func TestClientServerUDPIdleTimeout(t *testing.T) {
assert.NoError(t, err)
// Client sends 4 packets
for i := 0; i < 4; i++ {
err = cu.Send([]byte("happy"), addr)
_, err = cu.WriteTo([]byte("happy"), addr)
assert.NoError(t, err)
time.Sleep(1 * time.Second)
}
@ -159,10 +159,11 @@ func TestClientServerUDPIdleTimeout(t *testing.T) {
time.Sleep(1 * time.Second)
}
}()
buf := make([]byte, udpBufferSize)
for i := 0; i < 4; i++ {
bs, rAddr, err := cu.Receive()
n, rAddr, err := cu.ReadFrom(buf)
assert.NoError(t, err)
assert.Equal(t, "sad", string(bs))
assert.Equal(t, "sad", string(buf[:n]))
assert.Equal(t, addr, rAddr)
}
// Now we wait for 3 seconds, the server should close the UDP session.

View file

@ -173,12 +173,13 @@ func TestClientServerUDPEcho(t *testing.T) {
defer conn.Close()
// Send and receive data
buf := make([]byte, udpBufferSize)
sData := []byte("hello world")
err = conn.Send(sData, echoAddr)
_, err = conn.WriteTo(sData, echoAddr)
assert.NoError(t, err)
rData, rAddr, err := conn.Receive()
n, rAddr, err := conn.ReadFrom(buf)
assert.NoError(t, err)
assert.Equal(t, sData, rData)
assert.Equal(t, sData, buf[:n])
assert.Equal(t, echoAddr, rAddr)
}

View file

@ -63,7 +63,7 @@ func (s *tcpStressor) Run(t *testing.T) {
}
type udpStressor struct {
ListenFunc func() (client.HyUDPConn, error)
ListenFunc func() (client.UDPConn, error)
ServerAddr string
Size int
Count int
@ -100,19 +100,20 @@ func (s *udpStressor) Run(t *testing.T) {
// Sending routine
for i := 0; i < s.Count; i++ {
_ = limiter.WaitN(context.Background(), len(sData))
_ = conn.Send(sData, s.ServerAddr)
_, _ = conn.WriteTo(sData, s.ServerAddr)
}
}()
minCount := s.Count * 8 / 10 // Tolerate 20% packet loss
buf := make([]byte, udpBufferSize)
for i := 0; i < minCount; i++ {
rData, _, err := conn.Receive()
n, _, err := conn.ReadFrom(buf)
if err != nil {
errChan <- err
return
}
if len(rData) != len(sData) {
errChan <- fmt.Errorf("incomplete data received: %d/%d bytes", len(rData), len(sData))
if len(buf[:n]) != len(sData) {
errChan <- fmt.Errorf("incomplete data received: %d/%d bytes", len(buf[:n]), len(sData))
return
}
}

View file

@ -145,25 +145,26 @@ func TestClientServerTrafficLoggerUDP(t *testing.T) {
// Client writes to server
trafficLogger.EXPECT().Log("nobody", uint64(9), uint64(0)).Return(true).Once()
sobConn.EXPECT().WriteTo([]byte("small sad"), addr).Return(9, nil).Once()
err = conn.Send([]byte("small sad"), addr)
_, err = conn.WriteTo([]byte("small sad"), addr)
assert.NoError(t, err)
time.Sleep(1 * time.Second) // Need some time for the server to receive the data
buf := make([]byte, udpBufferSize)
// Client reads from server
trafficLogger.EXPECT().Log("nobody", uint64(0), uint64(7)).Return(true).Once()
sobConnCh <- []byte("big mad")
bs, rAddr, err := conn.Receive()
n, rAddr, err := conn.ReadFrom(buf)
assert.NoError(t, err)
assert.Equal(t, rAddr, addr)
assert.Equal(t, "big mad", string(bs))
assert.Equal(t, "big mad", string(buf[:n]))
// Client reads from server again but blocked
trafficLogger.EXPECT().Log("nobody", uint64(0), uint64(4)).Return(false).Once()
sobConnCh <- []byte("nope")
bs, rAddr, err = conn.Receive()
n, rAddr, err = conn.ReadFrom(buf)
assert.Equal(t, err, io.EOF)
assert.Empty(t, rAddr)
assert.Empty(t, bs)
assert.Empty(t, buf[:n])
// The client should be disconnected
_, err = c.UDP()

View file

@ -11,6 +11,7 @@ import (
// This file provides utilities for the integration tests.
const (
udpBufferSize = 4096
testCertFile = "test.crt"
testKeyFile = "test.key"
)

View file

@ -97,14 +97,14 @@ func (a *PluggableOutboundAdapter) UDP(reqAddr string) (server.UDPConn, error) {
if err != nil {
return nil, err
}
return &udpConnAdapter{conn}, nil
return &UdpConnAdapter{conn}, nil
}
type udpConnAdapter struct {
type UdpConnAdapter struct {
UDPConn
}
func (u *udpConnAdapter) ReadFrom(b []byte) (int, string, error) {
func (u *UdpConnAdapter) ReadFrom(b []byte) (int, string, error) {
n, addr, err := u.UDPConn.ReadFrom(b)
if addr != nil {
return n, addr.String(), err
@ -113,7 +113,7 @@ func (u *udpConnAdapter) ReadFrom(b []byte) (int, string, error) {
}
}
func (u *udpConnAdapter) WriteTo(b []byte, addr string) (int, error) {
func (u *UdpConnAdapter) WriteTo(b []byte, addr string) (int, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return 0, err
@ -128,6 +128,6 @@ func (u *udpConnAdapter) WriteTo(b []byte, addr string) (int, error) {
})
}
func (u *udpConnAdapter) Close() error {
func (u *UdpConnAdapter) Close() error {
return u.UDPConn.Close()
}