Split host & port in the protocol, and make each domain resolves only once even when ACL is enabled, improving performance and ensuring consistency of connection destinations

This commit is contained in:
Toby 2021-04-19 00:20:22 -07:00
parent 7b841aa203
commit b09880a050
8 changed files with 196 additions and 136 deletions

View file

@ -52,39 +52,47 @@ func LoadFromFile(filename string) (*Engine, error) {
}, nil
}
func (e *Engine) Lookup(domain string, ip net.IP) (Action, string) {
if len(domain) > 0 {
func (e *Engine) ResolveAndMatch(host string) (Action, string, *net.IPAddr, error) {
ip, zone := parseIPZone(host)
if ip == nil {
// Domain
if v, ok := e.Cache.Get(domain); ok {
ipAddr, err := net.ResolveIPAddr("ip", host)
if v, ok := e.Cache.Get(host); ok {
// Cache hit
ce := v.(cacheEntry)
return ce.Action, ce.Arg
return ce.Action, ce.Arg, ipAddr, err
}
ips, _ := net.LookupIP(domain)
for _, entry := range e.Entries {
if entry.MatchDomain(domain) || (len(ips) > 0 && entry.MatchIPs(ips)) {
e.Cache.Add(domain, cacheEntry{entry.Action, entry.ActionArg})
return entry.Action, entry.ActionArg
if entry.MatchDomain(host) || (ipAddr != nil && entry.MatchIP(ipAddr.IP)) {
e.Cache.Add(host, cacheEntry{entry.Action, entry.ActionArg})
return entry.Action, entry.ActionArg, ipAddr, err
}
}
e.Cache.Add(domain, cacheEntry{e.DefaultAction, ""})
return e.DefaultAction, ""
} else if ip != nil {
e.Cache.Add(host, cacheEntry{e.DefaultAction, ""})
return e.DefaultAction, "", ipAddr, err
} else {
// IP
if v, ok := e.Cache.Get(ip.String()); ok {
// Cache hit
ce := v.(cacheEntry)
return ce.Action, ce.Arg
return ce.Action, ce.Arg, &net.IPAddr{
IP: ip,
Zone: zone,
}, nil
}
for _, entry := range e.Entries {
if entry.MatchIP(ip) {
e.Cache.Add(ip.String(), cacheEntry{entry.Action, entry.ActionArg})
return entry.Action, entry.ActionArg
return entry.Action, entry.ActionArg, &net.IPAddr{
IP: ip,
Zone: zone,
}, nil
}
}
e.Cache.Add(ip.String(), cacheEntry{e.DefaultAction, ""})
return e.DefaultAction, ""
} else {
return e.DefaultAction, ""
return e.DefaultAction, "", &net.IPAddr{
IP: ip,
Zone: zone,
}, nil
}
}

View file

@ -6,7 +6,7 @@ import (
"testing"
)
func TestEngine_Lookup(t *testing.T) {
func TestEngine_ResolveAndMatch(t *testing.T) {
cache, _ := lru.NewARC(4)
e := &Engine{
DefaultAction: ActionDirect,
@ -49,61 +49,65 @@ func TestEngine_Lookup(t *testing.T) {
},
Cache: cache,
}
type args struct {
domain string
ip net.IP
}
tests := []struct {
name string
args args
want Action
want1 string
name string
addr string
want Action
want1 string
wantErr bool
}{
{
name: "domain direct",
args: args{"google.com", nil},
addr: "google.com",
want: ActionProxy,
want1: "",
},
{
name: "domain suffix 1",
args: args{"evil.corp", nil},
want: ActionHijack,
want1: "good.org",
name: "domain suffix 1",
addr: "evil.corp",
want: ActionHijack,
want1: "good.org",
wantErr: true,
},
{
name: "domain suffix 2",
args: args{"notevil.corp", nil},
want: ActionBlock,
want1: "",
name: "domain suffix 2",
addr: "notevil.corp",
want: ActionBlock,
want1: "",
wantErr: true,
},
{
name: "domain suffix 3",
args: args{"im.real.evil.corp", nil},
want: ActionHijack,
want1: "good.org",
name: "domain suffix 3",
addr: "im.real.evil.corp",
want: ActionHijack,
want1: "good.org",
wantErr: true,
},
{
name: "ip match",
args: args{"", net.ParseIP("10.2.3.4")},
addr: "10.2.3.4",
want: ActionProxy,
want1: "",
},
{
name: "ip mismatch",
args: args{"", net.ParseIP("100.5.6.0")},
addr: "100.5.6.0",
want: ActionBlock,
want1: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1 := e.Lookup(tt.args.domain, tt.args.ip)
got, got1, _, err := e.ResolveAndMatch(tt.addr)
if (err != nil) != tt.wantErr {
t.Errorf("ResolveAndMatch() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("Lookup() got = %v, want %v", got, tt.want)
t.Errorf("ResolveAndMatch() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("Lookup() got1 = %v, want %v", got1, tt.want1)
t.Errorf("ResolveAndMatch() got1 = %v, want %v", got1, tt.want1)
}
})
}

View file

@ -50,20 +50,6 @@ func (e Entry) MatchIP(ip net.IP) bool {
return false
}
func (e Entry) MatchIPs(ips []net.IP) bool {
if e.All {
return true
}
if e.Net != nil && len(ips) > 0 {
for _, ip := range ips {
if e.Net.Contains(ip) {
return true
}
}
}
return false
}
// Format: action cond_type cond arg
// Examples:
// proxy domain-suffix google.com

View file

@ -10,6 +10,7 @@ import (
"github.com/lucas-clemente/quic-go/congestion"
"github.com/lunixbochs/struc"
"net"
"strconv"
"sync"
"time"
)
@ -187,14 +188,19 @@ func (c *Client) openStreamWithReconnect() (quic.Session, quic.Stream, error) {
}
func (c *Client) DialTCP(addr string) (net.Conn, error) {
host, port, err := splitHostPort(addr)
if err != nil {
return nil, err
}
session, stream, err := c.openStreamWithReconnect()
if err != nil {
return nil, err
}
// Send request
err = struc.Pack(stream, &clientRequest{
UDP: false,
Address: addr,
UDP: false,
Host: host,
Port: port,
})
if err != nil {
_ = stream.Close()
@ -349,14 +355,19 @@ func (c *quicPktConn) ReadFrom() ([]byte, string, error) {
// Closed
return nil, "", ErrClosed
}
return msg.Data, msg.Address, nil
return msg.Data, net.JoinHostPort(msg.Host, strconv.Itoa(int(msg.Port))), nil
}
func (c *quicPktConn) WriteTo(p []byte, addr string) error {
host, port, err := splitHostPort(addr)
if err != nil {
return err
}
var msgBuf bytes.Buffer
_ = struc.Pack(&msgBuf, &udpMessage{
SessionID: c.UDPSessionID,
Address: addr,
Host: host,
Port: port,
Data: p,
})
return c.Session.SendMessage(msgBuf.Bytes())
@ -366,3 +377,15 @@ func (c *quicPktConn) Close() error {
c.CloseFunc()
return c.Stream.Close()
}
func splitHostPort(hostport string) (string, uint16, error) {
host, port, err := net.SplitHostPort(hostport)
if err != nil {
return "", 0, err
}
portUint, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return "", 0, err
}
return host, uint16(portUint), err
}

View file

@ -32,9 +32,10 @@ type serverHello struct {
}
type clientRequest struct {
UDP bool
AddressLen uint16 `struc:"sizeof=Address"`
Address string
UDP bool
HostLen uint16 `struc:"sizeof=Host"`
Host string
Port uint16
}
type serverResponse struct {
@ -45,9 +46,10 @@ type serverResponse struct {
}
type udpMessage struct {
SessionID uint32
AddressLen uint16 `struc:"sizeof=Address"`
Address string
DataLen uint16 `struc:"sizeof=Data"`
Data []byte
SessionID uint32
HostLen uint16 `struc:"sizeof=Host"`
Host string
Port uint16
DataLen uint16 `struc:"sizeof=Data"`
Data []byte
}

View file

@ -10,6 +10,7 @@ import (
"github.com/tobyxdd/hysteria/pkg/acl"
"github.com/tobyxdd/hysteria/pkg/utils"
"net"
"strconv"
"sync"
)
@ -88,7 +89,7 @@ func (c *serverClient) handleStream(stream quic.Stream) {
}
if !req.UDP {
// TCP connection
c.handleTCP(stream, req.Address)
c.handleTCP(stream, req.Host, req.Port)
} else if !c.DisableUDP {
// UDP connection
c.handleUDP(stream)
@ -112,32 +113,30 @@ func (c *serverClient) handleMessage(msg []byte) {
c.udpSessionMutex.RUnlock()
if ok {
// Session found, send the message
host, port, err := net.SplitHostPort(udpMsg.Address)
action, arg := acl.ActionDirect, ""
var ipAddr *net.IPAddr
if c.ACLEngine != nil {
action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(udpMsg.Host)
} else {
ipAddr, err = net.ResolveIPAddr("ip", udpMsg.Host)
}
if err != nil {
return
}
action, arg := acl.ActionDirect, ""
if c.ACLEngine != nil {
ip := net.ParseIP(host)
if ip != nil {
// IP request, clear host for ACL engine
host = ""
}
action, arg = c.ACLEngine.Lookup(host, ip)
}
switch action {
case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side
addr, err := net.ResolveUDPAddr("udp", udpMsg.Address)
if err == nil {
_, _ = conn.WriteToUDP(udpMsg.Data, addr)
if c.UpCounter != nil {
c.UpCounter.Add(float64(len(udpMsg.Data)))
}
_, _ = conn.WriteToUDP(udpMsg.Data, &net.UDPAddr{
IP: ipAddr.IP,
Port: int(udpMsg.Port),
Zone: ipAddr.Zone,
})
if c.UpCounter != nil {
c.UpCounter.Add(float64(len(udpMsg.Data)))
}
case acl.ActionBlock:
// Do nothing
case acl.ActionHijack:
hijackAddr := net.JoinHostPort(arg, port)
hijackAddr := net.JoinHostPort(arg, strconv.Itoa(int(udpMsg.Port)))
addr, err := net.ResolveUDPAddr("udp", hijackAddr)
if err == nil {
_, _ = conn.WriteToUDP(udpMsg.Data, addr)
@ -151,37 +150,40 @@ func (c *serverClient) handleMessage(msg []byte) {
}
}
func (c *serverClient) handleTCP(stream quic.Stream, reqAddr string) {
host, port, err := net.SplitHostPort(reqAddr)
func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) {
addrStr := net.JoinHostPort(host, strconv.Itoa(int(port)))
action, arg := acl.ActionDirect, ""
var ipAddr *net.IPAddr
var err error
if c.ACLEngine != nil {
action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(host)
} else {
ipAddr, err = net.ResolveIPAddr("ip", host)
}
if err != nil {
_ = struc.Pack(stream, &serverResponse{
OK: false,
Message: "invalid address",
Message: "host resolution failure",
})
c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err)
c.CTCPErrorFunc(c.ClientAddr, c.Auth, addrStr, err)
return
}
action, arg := acl.ActionDirect, ""
if c.ACLEngine != nil {
ip := net.ParseIP(host)
if ip != nil {
// IP request, clear host for ACL engine
host = ""
}
action, arg = c.ACLEngine.Lookup(host, ip)
}
c.CTCPRequestFunc(c.ClientAddr, c.Auth, reqAddr, action, arg)
c.CTCPRequestFunc(c.ClientAddr, c.Auth, addrStr, action, arg)
var conn net.Conn // Connection to be piped
switch action {
case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side
conn, err = net.DialTimeout("tcp", reqAddr, dialTimeout)
conn, err = net.DialTCP("tcp", nil, &net.TCPAddr{
IP: ipAddr.IP,
Port: int(port),
Zone: ipAddr.Zone,
})
if err != nil {
_ = struc.Pack(stream, &serverResponse{
OK: false,
Message: err.Error(),
})
c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err)
c.CTCPErrorFunc(c.ClientAddr, c.Auth, addrStr, err)
return
}
case acl.ActionBlock:
@ -191,14 +193,14 @@ func (c *serverClient) handleTCP(stream quic.Stream, reqAddr string) {
})
return
case acl.ActionHijack:
hijackAddr := net.JoinHostPort(arg, port)
conn, err = net.DialTimeout("tcp", hijackAddr, dialTimeout)
hijackAddr := net.JoinHostPort(arg, strconv.Itoa(int(port)))
conn, err = net.Dial("tcp", hijackAddr)
if err != nil {
_ = struc.Pack(stream, &serverResponse{
OK: false,
Message: err.Error(),
})
c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err)
c.CTCPErrorFunc(c.ClientAddr, c.Auth, addrStr, err)
return
}
default:
@ -227,7 +229,7 @@ func (c *serverClient) handleTCP(stream quic.Stream, reqAddr string) {
} else {
err = utils.Pipe2Way(stream, conn, nil)
}
c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err)
c.CTCPErrorFunc(c.ClientAddr, c.Auth, addrStr, err)
}
func (c *serverClient) handleUDP(stream quic.Stream) {
@ -268,7 +270,8 @@ func (c *serverClient) handleUDP(stream quic.Stream) {
var msgBuf bytes.Buffer
_ = struc.Pack(&msgBuf, &udpMessage{
SessionID: id,
Address: rAddr.String(),
Host: rAddr.IP.String(),
Port: uint16(rAddr.Port),
Data: buf[:n],
})
_ = c.CS.SendMessage(msgBuf.Bytes())

View file

@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/http"
"strconv"
"time"
"github.com/elazarl/goproxy/ext/auth"
@ -27,20 +28,30 @@ func NewProxyHTTPServer(hyClient *core.Client, idleTimeout time.Duration, aclEng
if err != nil {
return nil, err
}
portUint, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return nil, err
}
// ACL
action, arg := acl.ActionProxy, ""
var ipAddr *net.IPAddr
var resErr error
if aclEngine != nil {
ip := net.ParseIP(host)
if ip != nil {
host = ""
}
action, arg = aclEngine.Lookup(host, ip)
action, arg, ipAddr, resErr = aclEngine.ResolveAndMatch(host)
// Doesn't always matter if the resolution fails, as we may send it through HyClient
}
newDialFunc(addr, action, arg)
// Handle according to the action
switch action {
case acl.ActionDirect:
return net.Dial(network, addr)
if resErr != nil {
return nil, resErr
}
return net.DialTCP(network, nil, &net.TCPAddr{
IP: ipAddr.IP,
Port: int(portUint),
Zone: ipAddr.Zone,
})
case acl.ActionProxy:
return hyClient.DialTCP(addr)
case acl.ActionBlock:

View file

@ -162,10 +162,13 @@ func (s *Server) handle(c *net.TCPConn, r *socks5.Request) error {
}
func (s *Server) handleTCP(c *net.TCPConn, r *socks5.Request) error {
domain, ip, port, addr := parseRequestAddress(r)
host, port, addr := parseRequestAddress(r)
action, arg := acl.ActionProxy, ""
var ipAddr *net.IPAddr
var resErr error
if s.ACLEngine != nil {
action, arg = s.ACLEngine.Lookup(domain, ip)
action, arg, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host)
// Doesn't always matter if the resolution fails, as we may send it through HyClient
}
s.TCPRequestFunc(c.RemoteAddr(), addr, action, arg)
var closeErr error
@ -175,7 +178,16 @@ func (s *Server) handleTCP(c *net.TCPConn, r *socks5.Request) error {
// Handle according to the action
switch action {
case acl.ActionDirect:
rc, err := net.Dial("tcp", addr)
if resErr != nil {
_ = sendReply(c, socks5.RepHostUnreachable)
closeErr = resErr
return resErr
}
rc, err := net.DialTCP("tcp", nil, &net.TCPAddr{
IP: ipAddr.IP,
Port: int(port),
Zone: ipAddr.Zone,
})
if err != nil {
_ = sendReply(c, socks5.RepHostUnreachable)
closeErr = err
@ -201,7 +213,7 @@ func (s *Server) handleTCP(c *net.TCPConn, r *socks5.Request) error {
closeErr = errors.New("blocked in ACL")
return nil
case acl.ActionHijack:
rc, err := net.Dial("tcp", net.JoinHostPort(arg, port))
rc, err := net.Dial("tcp", net.JoinHostPort(arg, strconv.Itoa(int(port))))
if err != nil {
_ = sendReply(c, socks5.RepHostUnreachable)
closeErr = err
@ -299,13 +311,15 @@ func (s *Server) udpServer(clientConn *net.UDPConn, localRelayConn *net.UDPConn,
// Start remote to local
go func() {
for {
bs, _, err := hyUDP.ReadFrom()
bs, from, err := hyUDP.ReadFrom()
if err != nil {
break
}
// RFC 1928 is very ambiguous on how to properly use DST.ADDR and DST.PORT in reply packets
// So we just fill in zeros for now. Works fine for all the SOCKS5 clients I tested
d := socks5.NewDatagram(socks5.ATYPIPv4, []byte{0x00, 0x00, 0x00, 0x00}, []byte{0x00, 0x00}, bs)
atyp, addr, port, err := socks5.ParseAddress(from)
if err != nil {
continue
}
d := socks5.NewDatagram(atyp, addr, port, bs)
_, _ = clientConn.WriteToUDP(d.Bytes(), clientAddr)
}
}()
@ -329,24 +343,31 @@ func (s *Server) udpServer(clientConn *net.UDPConn, localRelayConn *net.UDPConn,
// Not our client, bye
continue
}
domain, ip, port, addr := parseDatagramRequestAddress(d)
host, port, addr := parseDatagramRequestAddress(d)
action, arg := acl.ActionProxy, ""
var ipAddr *net.IPAddr
var resErr error
if s.ACLEngine != nil && localRelayConn != nil {
action, arg = s.ACLEngine.Lookup(domain, ip)
action, arg, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host)
// Doesn't always matter if the resolution fails, as we may send it through HyClient
}
// Handle according to the action
switch action {
case acl.ActionDirect:
rAddr, err := net.ResolveUDPAddr("udp", addr)
if err == nil {
_, _ = localRelayConn.WriteToUDP(d.Data, rAddr)
if resErr != nil {
return
}
_, _ = localRelayConn.WriteToUDP(d.Data, &net.UDPAddr{
IP: ipAddr.IP,
Port: int(port),
Zone: ipAddr.Zone,
})
case acl.ActionProxy:
_ = hyUDP.WriteTo(d.Data, addr)
case acl.ActionBlock:
// Do nothing
case acl.ActionHijack:
hijackAddr := net.JoinHostPort(arg, port)
hijackAddr := net.JoinHostPort(arg, net.JoinHostPort(arg, strconv.Itoa(int(port))))
rAddr, err := net.ResolveUDPAddr("udp", hijackAddr)
if err == nil {
_, _ = localRelayConn.WriteToUDP(d.Data, rAddr)
@ -363,22 +384,24 @@ func sendReply(conn *net.TCPConn, rep byte) error {
return err
}
func parseRequestAddress(r *socks5.Request) (domain string, ip net.IP, port string, addr string) {
p := strconv.Itoa(int(binary.BigEndian.Uint16(r.DstPort)))
func parseRequestAddress(r *socks5.Request) (host string, port uint16, addr string) {
p := binary.BigEndian.Uint16(r.DstPort)
if r.Atyp == socks5.ATYPDomain {
d := string(r.DstAddr[1:])
return d, nil, p, net.JoinHostPort(d, p)
return d, p, net.JoinHostPort(d, strconv.Itoa(int(p)))
} else {
return "", r.DstAddr, p, net.JoinHostPort(net.IP(r.DstAddr).String(), p)
ipStr := net.IP(r.DstAddr).String()
return ipStr, p, net.JoinHostPort(ipStr, strconv.Itoa(int(p)))
}
}
func parseDatagramRequestAddress(r *socks5.Datagram) (domain string, ip net.IP, port string, addr string) {
p := strconv.Itoa(int(binary.BigEndian.Uint16(r.DstPort)))
func parseDatagramRequestAddress(r *socks5.Datagram) (host string, port uint16, addr string) {
p := binary.BigEndian.Uint16(r.DstPort)
if r.Atyp == socks5.ATYPDomain {
d := string(r.DstAddr[1:])
return d, nil, p, net.JoinHostPort(d, p)
return d, p, net.JoinHostPort(d, strconv.Itoa(int(p)))
} else {
return "", r.DstAddr, p, net.JoinHostPort(net.IP(r.DstAddr).String(), p)
ipStr := net.IP(r.DstAddr).String()
return ipStr, p, net.JoinHostPort(ipStr, strconv.Itoa(int(p)))
}
}