hysteria/pkg/socks5/server.go
2022-11-16 22:57:35 -08:00

442 lines
11 KiB
Go

package socks5
import (
"encoding/binary"
"errors"
"fmt"
"strconv"
"github.com/apernet/hysteria/pkg/acl"
"github.com/apernet/hysteria/pkg/core"
"github.com/apernet/hysteria/pkg/transport"
"github.com/apernet/hysteria/pkg/utils"
)
import (
"net"
"time"
"github.com/txthinking/socks5"
)
// udpBufferSize is the size, in bytes, of the buffers udpServer allocates for
// reading UDP datagrams.
const udpBufferSize = 4096
var (
// ErrUnsupportedCmd is returned by handle when the request command is neither
// CONNECT nor UDP ASSOCIATE, or when it is UDP ASSOCIATE and UDP is disabled.
ErrUnsupportedCmd = errors.New("unsupported command")
// ErrUserPassAuth is returned by negotiate when AuthFunc rejects the
// username/password pair supplied by the client.
ErrUserPassAuth = errors.New("invalid username or password")
)
// Server is a local SOCKS5 server. Each request is either relayed through
// HyClient or handled locally through Transport, depending on the action the
// ACL engine (if any) returns for the destination.
type Server struct {
// HyClient dials TCP connections and UDP sessions for the proxy action.
HyClient *core.Client
// Transport dials, resolves and listens locally for the direct and hijack actions.
Transport *transport.ClientTransport
// AuthFunc validates a username/password pair during negotiation. NewServer
// selects the "no authentication" method when it is nil.
AuthFunc func(username, password string) bool
// Method is the single SOCKS5 authentication method the server accepts.
Method byte
// TCPAddr is the address ListenAndServe listens on. Its IP and Zone are also
// used for the per-association UDP socket.
TCPAddr *net.TCPAddr
// TCPTimeout is applied as a deadline on accepted connections and passed to
// the TCP pipe helper. No deadline is set when it is zero.
TCPTimeout time.Duration
// ACLEngine decides the action per destination. When nil, every request is proxied.
ACLEngine *acl.Engine
// DisableUDP makes the server reject UDP ASSOCIATE requests.
DisableUDP bool
// TCPRequestFunc is called for every CONNECT request once its action is known.
TCPRequestFunc func(addr net.Addr, reqAddr string, action acl.Action, arg string)
// TCPErrorFunc is called when handling of a CONNECT request ends; err may be nil.
TCPErrorFunc func(addr net.Addr, reqAddr string, err error)
// UDPAssociateFunc is called when handling of a UDP ASSOCIATE request starts.
UDPAssociateFunc func(addr net.Addr)
// UDPErrorFunc is called when handling of a UDP ASSOCIATE request ends; err may be nil.
UDPErrorFunc func(addr net.Addr, err error)
// tcpListener is the listener created by ListenAndServe.
tcpListener *net.TCPListener
}
// NewServer builds a Server that will listen on addr.
//
// addr is resolved as a TCP address up front; a resolution failure is returned
// as-is. When authFunc is nil the server requires no authentication, otherwise
// it requires the username/password method. All remaining arguments are stored
// on the returned Server unchanged.
func NewServer(hyClient *core.Client, transport *transport.ClientTransport, addr string,
	authFunc func(username, password string) bool, tcpTimeout time.Duration,
	aclEngine *acl.Engine, disableUDP bool,
	tcpReqFunc func(addr net.Addr, reqAddr string, action acl.Action, arg string),
	tcpErrorFunc func(addr net.Addr, reqAddr string, err error),
	udpAssocFunc func(addr net.Addr), udpErrorFunc func(addr net.Addr, err error),
) (*Server, error) {
	listenAddr, err := net.ResolveTCPAddr("tcp", addr)
	if err != nil {
		return nil, err
	}

	srv := &Server{
		HyClient:         hyClient,
		Transport:        transport,
		AuthFunc:         authFunc,
		Method:           socks5.MethodNone,
		TCPAddr:          listenAddr,
		TCPTimeout:       tcpTimeout,
		ACLEngine:        aclEngine,
		DisableUDP:       disableUDP,
		TCPRequestFunc:   tcpReqFunc,
		TCPErrorFunc:     tcpErrorFunc,
		UDPAssociateFunc: udpAssocFunc,
		UDPErrorFunc:     udpErrorFunc,
	}
	// An auth callback switches the server to username/password authentication.
	if authFunc != nil {
		srv.Method = socks5.MethodUsernamePassword
	}
	return srv, nil
}
// negotiate runs the SOCKS5 method negotiation on c and, when the server
// requires it, the username/password sub-negotiation.
//
// It returns nil once the client is allowed to send a request. It returns a
// read/write error from the connection, an error when the client did not offer
// the method the server requires, or ErrUserPassAuth when AuthFunc rejects the
// credentials. In the two rejection cases the matching failure reply has
// already been written to the client.
func (s *Server) negotiate(c *net.TCPConn) error {
	rq, err := socks5.NewNegotiationRequestFrom(c)
	if err != nil {
		return err
	}

	// The server accepts exactly one method; look for it among the client's offers.
	supported := false
	for _, m := range rq.Methods {
		if m == s.Method {
			supported = true
			break
		}
	}
	if !supported {
		rp := socks5.NewNegotiationReply(socks5.MethodUnsupportAll)
		if _, err := rp.WriteTo(c); err != nil {
			return err
		}
		// Stop here: after a "no acceptable methods" reply the handshake is over.
		// Falling through would send a second, contradictory method selection.
		return errors.New("no acceptable authentication method")
	}

	rp := socks5.NewNegotiationReply(s.Method)
	if _, err := rp.WriteTo(c); err != nil {
		return err
	}

	if s.Method == socks5.MethodUsernamePassword {
		urq, err := socks5.NewUserPassNegotiationRequestFrom(c)
		if err != nil {
			return err
		}
		if !s.AuthFunc(string(urq.Uname), string(urq.Passwd)) {
			urp := socks5.NewUserPassNegotiationReply(socks5.UserPassStatusFailure)
			if _, err := urp.WriteTo(c); err != nil {
				return err
			}
			return ErrUserPassAuth
		}
		urp := socks5.NewUserPassNegotiationReply(socks5.UserPassStatusSuccess)
		if _, err := urp.WriteTo(c); err != nil {
			return err
		}
	}
	return nil
}
// ListenAndServe listens on s.TCPAddr and serves SOCKS5 clients until the
// listener fails.
//
// Every accepted connection is handled in its own goroutine: the optional
// TCPTimeout deadline is applied, the negotiation is run, one request is read
// and dispatched to handle. Per-connection failures only end that connection.
// The error from listening or accepting is returned to the caller.
func (s *Server) ListenAndServe() error {
	var err error
	if s.tcpListener, err = net.ListenTCP("tcp", s.TCPAddr); err != nil {
		return err
	}
	defer s.tcpListener.Close()

	for {
		accepted, err := s.tcpListener.AcceptTCP()
		if err != nil {
			return err
		}
		go func(conn *net.TCPConn) {
			defer conn.Close()
			// Bound the handshake when a timeout is configured.
			if s.TCPTimeout != 0 && conn.SetDeadline(time.Now().Add(s.TCPTimeout)) != nil {
				return
			}
			if s.negotiate(conn) != nil {
				return
			}
			req, err := socks5.NewRequestFrom(conn)
			if err == nil {
				_ = s.handle(conn, req)
			}
		}(accepted)
	}
}
// handle dispatches a parsed SOCKS5 request by command.
//
// CONNECT goes to handleTCP, and UDP ASSOCIATE goes to handleUDP unless UDP is
// disabled. Anything else is answered with "command not supported" and
// reported as ErrUnsupportedCmd.
func (s *Server) handle(c *net.TCPConn, r *socks5.Request) error {
	switch {
	case r.Cmd == socks5.CmdConnect:
		return s.handleTCP(c, r)
	case r.Cmd == socks5.CmdUDP && !s.DisableUDP:
		return s.handleUDP(c, r)
	default:
		// Unknown command, or UDP ASSOCIATE while UDP is disabled.
		_ = sendReply(c, socks5.RepCommandNotSupported)
		return ErrUnsupportedCmd
	}
}
// handleTCP serves a CONNECT request on c.
//
// The destination is matched against the ACL engine when one is configured;
// otherwise the request is proxied. TCPRequestFunc is called once the action is
// known, and TCPErrorFunc is called on return with the error that ended the
// request (nil if there was none). The returned error covers only failures
// before relaying starts; errors from the relay itself and the block/unknown
// outcomes are reported through TCPErrorFunc and return nil.
func (s *Server) handleTCP(c *net.TCPConn, r *socks5.Request) error {
	host, port, addr := parseRequestAddress(r)

	var (
		action = acl.ActionProxy
		arg    = ""
		ipAddr *net.IPAddr
		resErr error
	)
	if s.ACLEngine != nil {
		// A failed resolution is not always fatal: the proxy action sends the
		// unresolved address through HyClient.
		action, arg, _, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host, port, false)
	}
	s.TCPRequestFunc(c.RemoteAddr(), addr, action, arg)

	var closeErr error
	defer func() {
		s.TCPErrorFunc(c.RemoteAddr(), addr, closeErr)
	}()

	// unreachable replies "host unreachable", records err for TCPErrorFunc and
	// hands err back so callers can return it directly.
	unreachable := func(err error) error {
		_ = sendReply(c, socks5.RepHostUnreachable)
		closeErr = err
		return err
	}

	switch action {
	case acl.ActionDirect:
		// Dial the resolved address locally.
		if resErr != nil {
			return unreachable(resErr)
		}
		target := &net.TCPAddr{IP: ipAddr.IP, Port: int(port), Zone: ipAddr.Zone}
		rc, err := s.Transport.DialTCP(target)
		if err != nil {
			return unreachable(err)
		}
		defer rc.Close()
		_ = sendReply(c, socks5.RepSuccess)
		closeErr = utils.PipePairWithTimeout(c, rc, s.TCPTimeout)
		return nil

	case acl.ActionProxy:
		// Relay through the Hysteria client using the requested address.
		rc, err := s.HyClient.DialTCP(addr)
		if err != nil {
			return unreachable(err)
		}
		defer rc.Close()
		_ = sendReply(c, socks5.RepSuccess)
		closeErr = utils.PipePairWithTimeout(c, rc, s.TCPTimeout)
		return nil

	case acl.ActionBlock:
		_ = sendReply(c, socks5.RepHostUnreachable)
		closeErr = errors.New("blocked in ACL")
		return nil

	case acl.ActionHijack:
		// Dial the ACL-supplied host locally, keeping the requested port.
		hijackIPAddr, err := s.Transport.ResolveIPAddr(arg)
		if err != nil {
			return unreachable(err)
		}
		target := &net.TCPAddr{IP: hijackIPAddr.IP, Port: int(port), Zone: hijackIPAddr.Zone}
		rc, err := s.Transport.DialTCP(target)
		if err != nil {
			return unreachable(err)
		}
		defer rc.Close()
		_ = sendReply(c, socks5.RepSuccess)
		closeErr = utils.PipePairWithTimeout(c, rc, s.TCPTimeout)
		return nil

	default:
		_ = sendReply(c, socks5.RepServerFailure)
		closeErr = fmt.Errorf("unknown action %d", action)
		return nil
	}
}
// handleUDP serves a UDP ASSOCIATE request on c.
//
// It opens a UDP socket for the client on the server's listen IP, a local
// relay socket when an ACL engine is configured, and a Hysteria UDP session.
// It then tells the client where to send datagrams, starts udpServer, and
// blocks reading from c until the TCP connection ends. Everything opened here
// is closed on return. UDPAssociateFunc is called on entry and UDPErrorFunc on
// return with the error that ended the association.
func (s *Server) handleUDP(c *net.TCPConn, r *socks5.Request) error {
	s.UDPAssociateFunc(c.RemoteAddr())
	var closeErr error
	defer func() {
		s.UDPErrorFunc(c.RemoteAddr(), closeErr)
	}()

	// fail replies "server failure", records err for UDPErrorFunc and hands
	// err back so callers can return it directly.
	fail := func(err error) error {
		_ = sendReply(c, socks5.RepServerFailure)
		closeErr = err
		return err
	}

	// UDP socket the SOCKS5 client talks to.
	udpConn, err := net.ListenUDP("udp", &net.UDPAddr{
		IP:   s.TCPAddr.IP,
		Zone: s.TCPAddr.Zone,
	})
	if err != nil {
		return fail(err)
	}
	defer udpConn.Close()

	// Local relay socket, only needed when the ACL may pick direct/hijack.
	var localRelayConn *net.UDPConn
	if s.ACLEngine != nil {
		localRelayConn, err = s.Transport.ListenUDP()
		if err != nil {
			return fail(err)
		}
		defer localRelayConn.Close()
	}

	// Hysteria UDP session for proxied datagrams.
	hyUDP, err := s.HyClient.DialUDP()
	if err != nil {
		return fail(err)
	}
	defer hyUDP.Close()

	// Advertise the relay address: the IP the TCP connection arrived on,
	// combined with the port of the freshly opened UDP socket.
	localIP := c.LocalAddr().(*net.TCPAddr).IP
	var (
		atyp byte
		addr []byte
	)
	ip4, ip16 := localIP.To4(), localIP.To16()
	switch {
	case ip4 != nil:
		atyp, addr = socks5.ATYPIPv4, ip4
	case ip16 != nil:
		atyp, addr = socks5.ATYPIPv6, ip16
	default:
		return fail(errors.New("invalid local addr"))
	}
	port := make([]byte, 2)
	binary.BigEndian.PutUint16(port, uint16(udpConn.LocalAddr().(*net.UDPAddr).Port))
	_, _ = socks5.NewReply(socks5.RepSuccess, atyp, addr, port).WriteTo(c)

	// The datagram relay runs on its own; this goroutine just holds the TCP
	// connection open.
	go s.udpServer(udpConn, localRelayConn, hyUDP)

	if s.TCPTimeout != 0 {
		// The handshake deadline must not expire the holder connection.
		_ = c.SetDeadline(time.Time{})
	}

	// Read and discard until the TCP connection fails or closes; returning
	// then closes the UDP sockets and the Hysteria session via the defers.
	scratch := make([]byte, 1024)
	for closeErr == nil {
		_, closeErr = c.Read(scratch)
	}
	return nil
}
// udpServer relays datagrams for one UDP association.
//
// clientConn is the socket the SOCKS5 client sends to, localRelayConn is the
// socket used for the direct and hijack actions (nil when no ACL engine is
// configured), and hyUDP is the Hysteria session used for the proxy action.
//
// The calling goroutine runs the client-to-remote direction. The sender of the
// first valid datagram becomes the client; at that point the remote-to-client
// goroutines are started and datagrams from any other source are ignored. The
// function returns when reading from clientConn fails, which happens once
// handleUDP closes it.
func (s *Server) udpServer(clientConn *net.UDPConn, localRelayConn *net.UDPConn, hyUDP core.HyUDPConn) {
	var clientAddr *net.UDPAddr
	buf := make([]byte, udpBufferSize)
	// Local to remote
	for {
		n, cAddr, err := clientConn.ReadFromUDP(buf)
		if err != nil {
			break
		}
		d, err := socks5.NewDatagramFromBytes(buf[:n])
		if err != nil || d.Frag != 0 {
			// Ignore bad packets; fragmented datagrams are not supported
			continue
		}
		if clientAddr == nil {
			// Whoever sends the first valid packet is our client
			clientAddr = cAddr
			// Start remote to local (Hysteria session -> client)
			go func() {
				for {
					bs, from, err := hyUDP.ReadFrom()
					if err != nil {
						break
					}
					atyp, addr, port, err := socks5.ParseAddress(from)
					if err != nil {
						continue
					}
					d := socks5.NewDatagram(atyp, addr, port, bs)
					_, _ = clientConn.WriteToUDP(d.Bytes(), clientAddr)
				}
			}()
			if localRelayConn != nil {
				// Start remote to local (local relay socket -> client)
				go func() {
					buf := make([]byte, udpBufferSize)
					for {
						n, from, err := localRelayConn.ReadFrom(buf)
						if n > 0 {
							atyp, addr, port, err := socks5.ParseAddress(from.String())
							if err != nil {
								continue
							}
							d := socks5.NewDatagram(atyp, addr, port, buf[:n])
							_, _ = clientConn.WriteToUDP(d.Bytes(), clientAddr)
						}
						if err != nil {
							break
						}
					}
				}()
			}
		} else if cAddr.String() != clientAddr.String() {
			// Not our client, bye
			continue
		}
		host, port, addr := parseDatagramRequestAddress(d)
		action, arg := acl.ActionProxy, ""
		var ipAddr *net.IPAddr
		var resErr error
		if s.ACLEngine != nil && localRelayConn != nil {
			action, arg, _, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host, port, true)
			// Doesn't always matter if the resolution fails, as we may send it through HyClient
		}
		// Handle according to the action
		switch action {
		case acl.ActionDirect:
			if resErr != nil {
				// Drop only this datagram. Returning here would stop relaying
				// for the whole association because one destination failed to
				// resolve.
				continue
			}
			_, _ = localRelayConn.WriteToUDP(d.Data, &net.UDPAddr{
				IP:   ipAddr.IP,
				Port: int(port),
				Zone: ipAddr.Zone,
			})
		case acl.ActionProxy:
			_ = hyUDP.WriteTo(d.Data, addr)
		case acl.ActionBlock:
			// Do nothing
		case acl.ActionHijack:
			// Send to the ACL-supplied host, keeping the requested port
			hijackIPAddr, err := s.Transport.ResolveIPAddr(arg)
			if err == nil {
				_, _ = localRelayConn.WriteToUDP(d.Data, &net.UDPAddr{
					IP:   hijackIPAddr.IP,
					Port: int(port),
					Zone: hijackIPAddr.Zone,
				})
			}
		default:
			// Do nothing
		}
	}
}
// sendReply writes a SOCKS5 reply carrying the status rep to conn. The bound
// address in the reply is always the all-zero IPv4 address with port 0.
// It returns the error from writing to conn, if any.
func sendReply(conn *net.TCPConn, rep byte) error {
	zeroAddr := make([]byte, net.IPv4len)
	zeroPort := make([]byte, 2)
	reply := socks5.NewReply(rep, socks5.ATYPIPv4, zeroAddr, zeroPort)
	_, err := reply.WriteTo(conn)
	return err
}
// parseRequestAddress extracts the destination of a SOCKS5 request.
//
// It returns the host (a domain name or the textual form of the IP), the port
// decoded from the big-endian DstPort, and the joined "host:port" address.
// For domain destinations the first byte of DstAddr is skipped.
func parseRequestAddress(r *socks5.Request) (host string, port uint16, addr string) {
	port = binary.BigEndian.Uint16(r.DstPort)
	if r.Atyp == socks5.ATYPDomain {
		host = string(r.DstAddr[1:])
	} else {
		host = net.IP(r.DstAddr).String()
	}
	addr = net.JoinHostPort(host, strconv.Itoa(int(port)))
	return host, port, addr
}
// parseDatagramRequestAddress extracts the destination of a SOCKS5 UDP datagram.
//
// It returns the host (a domain name or the textual form of the IP), the port
// decoded from the big-endian DstPort, and the joined "host:port" address.
// For domain destinations the first byte of DstAddr is skipped.
func parseDatagramRequestAddress(r *socks5.Datagram) (host string, port uint16, addr string) {
	port = binary.BigEndian.Uint16(r.DstPort)
	if r.Atyp == socks5.ATYPDomain {
		host = string(r.DstAddr[1:])
	} else {
		host = net.IP(r.DstAddr).String()
	}
	addr = net.JoinHostPort(host, strconv.Itoa(int(port)))
	return host, port, addr
}