Simplify code

This commit is contained in:
Toby 2020-10-02 18:23:47 -07:00
parent 2df70dafca
commit 05a34f8f92
16 changed files with 99 additions and 146 deletions

View file

@ -92,7 +92,7 @@ func proxyServer(args []string) {
"up": sSend / mbpsToBps,
"down": sRecv / mbpsToBps,
}).Info("Client connected")
return core.AuthSuccess, ""
return core.AuthResultSuccess, ""
} else {
// Need auth
ok, err := checkAuth(config.AuthFile, username, password)
@ -102,7 +102,7 @@ func proxyServer(args []string) {
"addr": addr.String(),
"username": username,
}).Error("Client authentication error")
return core.AuthInternalError, "Server auth error"
return core.AuthResultInternalError, "Server auth error"
}
if ok {
logrus.WithFields(logrus.Fields{
@ -111,7 +111,7 @@ func proxyServer(args []string) {
"up": sSend / mbpsToBps,
"down": sRecv / mbpsToBps,
}).Info("Client authenticated")
return core.AuthSuccess, ""
return core.AuthResultSuccess, ""
} else {
logrus.WithFields(logrus.Fields{
"addr": addr.String(),
@ -119,7 +119,7 @@ func proxyServer(args []string) {
"up": sSend / mbpsToBps,
"down": sRecv / mbpsToBps,
}).Info("Client rejected due to invalid credential")
return core.AuthInvalidCred, "Invalid credential"
return core.AuthResultInvalidCred, "Invalid credential"
}
}
},
@ -130,13 +130,14 @@ func proxyServer(args []string) {
"username": username,
}).Info("Client disconnected")
},
func(addr net.Addr, username string, id int, packet bool, reqAddr string) (core.ConnectResult, string, io.ReadWriteCloser) {
func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string) (core.ConnectResult, string, io.ReadWriteCloser) {
packet := reqType == core.ConnectionTypePacket
if packet && config.DisableUDP {
return core.ConnBlocked, "UDP disabled", nil
return core.ConnectResultBlocked, "UDP disabled", nil
}
host, port, err := net.SplitHostPort(reqAddr)
if err != nil {
return core.ConnFailed, err.Error(), nil
return core.ConnectResultFailed, err.Error(), nil
}
ip := net.ParseIP(host)
if ip != nil {
@ -163,9 +164,9 @@ func proxyServer(args []string) {
"error": err,
"dst": reqAddr,
}).Error("TCP error")
return core.ConnFailed, err.Error(), nil
return core.ConnectResultFailed, err.Error(), nil
}
return core.ConnSuccess, "", conn
return core.ConnectResultSuccess, "", conn
} else {
// UDP
logrus.WithFields(logrus.Fields{
@ -180,9 +181,9 @@ func proxyServer(args []string) {
"error": err,
"dst": reqAddr,
}).Error("UDP error")
return core.ConnFailed, err.Error(), nil
return core.ConnectResultFailed, err.Error(), nil
}
return core.ConnSuccess, "", conn
return core.ConnectResultSuccess, "", conn
}
case acl.ActionBlock:
if !packet {
@ -193,7 +194,7 @@ func proxyServer(args []string) {
"src": addr.String(),
"dst": reqAddr,
}).Debug("New TCP request")
return core.ConnBlocked, "blocked by ACL", nil
return core.ConnectResultBlocked, "blocked by ACL", nil
} else {
// UDP
logrus.WithFields(logrus.Fields{
@ -202,7 +203,7 @@ func proxyServer(args []string) {
"src": addr.String(),
"dst": reqAddr,
}).Debug("New UDP request")
return core.ConnBlocked, "blocked by ACL", nil
return core.ConnectResultBlocked, "blocked by ACL", nil
}
case acl.ActionHijack:
hijackAddr := net.JoinHostPort(arg, port)
@ -221,9 +222,9 @@ func proxyServer(args []string) {
"error": err,
"dst": hijackAddr,
}).Error("TCP error")
return core.ConnFailed, err.Error(), nil
return core.ConnectResultFailed, err.Error(), nil
}
return core.ConnSuccess, "", conn
return core.ConnectResultSuccess, "", conn
} else {
// UDP
logrus.WithFields(logrus.Fields{
@ -239,15 +240,16 @@ func proxyServer(args []string) {
"error": err,
"dst": hijackAddr,
}).Error("UDP error")
return core.ConnFailed, err.Error(), nil
return core.ConnectResultFailed, err.Error(), nil
}
return core.ConnSuccess, "", conn
return core.ConnectResultSuccess, "", conn
}
default:
return core.ConnFailed, "server ACL error", nil
return core.ConnectResultFailed, "server ACL error", nil
}
},
func(addr net.Addr, username string, id int, packet bool, reqAddr string, err error) {
func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string, err error) {
packet := reqType == core.ConnectionTypePacket
if !packet {
logrus.WithFields(logrus.Fields{
"error": err,

View file

@ -6,10 +6,10 @@ import (
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/sirupsen/logrus"
"github.com/tobyxdd/hysteria/internal/utils"
hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion"
"github.com/tobyxdd/hysteria/pkg/core"
"github.com/tobyxdd/hysteria/pkg/obfs"
"github.com/tobyxdd/hysteria/pkg/utils"
"io/ioutil"
"net"
"os/user"
@ -99,7 +99,7 @@ func relayClient(args []string) {
}
}
func relayClientHandleConn(conn net.Conn, client core.Client) {
func relayClientHandleConn(conn net.Conn, client *core.Client) {
logrus.WithField("src", conn.RemoteAddr().String()).Debug("New connection")
var closeErr error
defer func() {

View file

@ -72,7 +72,7 @@ func relayServer(args []string) {
"up": sSend / mbpsToBps,
"down": sRecv / mbpsToBps,
}).Info("Client connected")
return core.AuthSuccess, ""
return core.AuthResultSuccess, ""
},
func(addr net.Addr, username string, err error) {
logrus.WithFields(logrus.Fields{
@ -81,14 +81,15 @@ func relayServer(args []string) {
"username": username,
}).Info("Client disconnected")
},
func(addr net.Addr, username string, id int, packet bool, reqAddr string) (core.ConnectResult, string, io.ReadWriteCloser) {
func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string) (core.ConnectResult, string, io.ReadWriteCloser) {
packet := reqType == core.ConnectionTypePacket
logrus.WithFields(logrus.Fields{
"username": username,
"src": addr.String(),
"id": id,
}).Debug("New stream")
if packet {
return core.ConnBlocked, "unsupported", nil
return core.ConnectResultBlocked, "unsupported", nil
}
conn, err := net.DialTimeout("tcp", config.RemoteAddr, dialTimeout)
if err != nil {
@ -96,11 +97,11 @@ func relayServer(args []string) {
"error": err,
"dst": config.RemoteAddr,
}).Error("TCP error")
return core.ConnFailed, err.Error(), nil
return core.ConnectResultFailed, err.Error(), nil
}
return core.ConnSuccess, "", conn
return core.ConnectResultSuccess, "", conn
},
func(addr net.Addr, username string, id int, packet bool, reqAddr string, err error) {
func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string, err error) {
logrus.WithFields(logrus.Fields{
"error": err,
"username": username,

View file

@ -6,7 +6,8 @@ import (
"errors"
"fmt"
"github.com/lucas-clemente/quic-go"
"github.com/tobyxdd/hysteria/internal/utils"
"github.com/tobyxdd/hysteria/pkg/core/pb"
"github.com/tobyxdd/hysteria/pkg/utils"
"net"
"sync"
"sync/atomic"
@ -56,11 +57,11 @@ func (c *Client) Dial(packet bool, addr string) (net.Conn, error) {
return nil, err
}
// Send request
req := &ClientConnectRequest{Address: addr}
req := &pb.ClientConnectRequest{Address: addr}
if packet {
req.Type = ConnectionType_Packet
req.Type = pb.ConnectionType_Packet
} else {
req.Type = ConnectionType_Stream
req.Type = pb.ConnectionType_Stream
}
err = writeClientConnectRequest(stream, req)
if err != nil {
@ -73,7 +74,7 @@ func (c *Client) Dial(packet bool, addr string) (net.Conn, error) {
_ = stream.Close()
return nil, err
}
if resp.Result != ConnectResult_CONN_SUCCESS {
if resp.Result != pb.ConnectResult_CONN_SUCCESS {
_ = stream.Close()
return nil, fmt.Errorf("server rejected the connection %s (msg: %s)",
resp.Result.String(), resp.Message)
@ -135,7 +136,7 @@ func (c *Client) connectToServer() error {
_ = qs.CloseWithError(closeErrorCodeProtocolFailure, "control stream handling error")
return err
}
if result != AuthResult_AUTH_SUCCESS {
if result != pb.AuthResult_AUTH_SUCCESS {
_ = qs.CloseWithError(closeErrorCodeProtocolFailure, "authentication failure")
return fmt.Errorf("authentication failure %s (msg: %s)", result.String(), msg)
}
@ -144,13 +145,13 @@ func (c *Client) connectToServer() error {
return nil
}
func (c *Client) handleControlStream(qs quic.Session, stream quic.Stream) (AuthResult, string, error) {
err := writeClientAuthRequest(stream, &ClientAuthRequest{
Credential: &Credential{
func (c *Client) handleControlStream(qs quic.Session, stream quic.Stream) (pb.AuthResult, string, error) {
err := writeClientAuthRequest(stream, &pb.ClientAuthRequest{
Credential: &pb.Credential{
Username: c.username,
Password: c.password,
},
Speed: &Speed{
Speed: &pb.Speed{
SendBps: c.sendBPS,
ReceiveBps: c.recvBPS,
},
@ -164,7 +165,7 @@ func (c *Client) handleControlStream(qs quic.Session, stream quic.Stream) (AuthR
return 0, "", err
}
// Set the congestion accordingly
if resp.Result == AuthResult_AUTH_SUCCESS && c.congestionFactory != nil {
if resp.Result == pb.AuthResult_AUTH_SUCCESS && c.congestionFactory != nil {
qs.SetCongestion(c.congestionFactory(resp.Speed.ReceiveBps))
}
return resp.Result, resp.Message, nil

View file

@ -3,6 +3,7 @@ package core
import (
"encoding/binary"
"github.com/golang/protobuf/proto"
"github.com/tobyxdd/hysteria/pkg/core/pb"
"io"
)
@ -30,17 +31,17 @@ func writeDataBlock(w io.Writer, data []byte) error {
return err
}
func readClientAuthRequest(r io.Reader) (*ClientAuthRequest, error) {
func readClientAuthRequest(r io.Reader) (*pb.ClientAuthRequest, error) {
bs, err := readDataBlock(r)
if err != nil {
return nil, err
}
var req ClientAuthRequest
var req pb.ClientAuthRequest
err = proto.Unmarshal(bs, &req)
return &req, err
}
func writeClientAuthRequest(w io.Writer, req *ClientAuthRequest) error {
func writeClientAuthRequest(w io.Writer, req *pb.ClientAuthRequest) error {
bs, err := proto.Marshal(req)
if err != nil {
return err
@ -48,17 +49,17 @@ func writeClientAuthRequest(w io.Writer, req *ClientAuthRequest) error {
return writeDataBlock(w, bs)
}
func readServerAuthResponse(r io.Reader) (*ServerAuthResponse, error) {
func readServerAuthResponse(r io.Reader) (*pb.ServerAuthResponse, error) {
bs, err := readDataBlock(r)
if err != nil {
return nil, err
}
var resp ServerAuthResponse
var resp pb.ServerAuthResponse
err = proto.Unmarshal(bs, &resp)
return &resp, err
}
func writeServerAuthResponse(w io.Writer, resp *ServerAuthResponse) error {
func writeServerAuthResponse(w io.Writer, resp *pb.ServerAuthResponse) error {
bs, err := proto.Marshal(resp)
if err != nil {
return err
@ -66,17 +67,17 @@ func writeServerAuthResponse(w io.Writer, resp *ServerAuthResponse) error {
return writeDataBlock(w, bs)
}
func readClientConnectRequest(r io.Reader) (*ClientConnectRequest, error) {
func readClientConnectRequest(r io.Reader) (*pb.ClientConnectRequest, error) {
bs, err := readDataBlock(r)
if err != nil {
return nil, err
}
var req ClientConnectRequest
var req pb.ClientConnectRequest
err = proto.Unmarshal(bs, &req)
return &req, err
}
func writeClientConnectRequest(w io.Writer, req *ClientConnectRequest) error {
func writeClientConnectRequest(w io.Writer, req *pb.ClientConnectRequest) error {
bs, err := proto.Marshal(req)
if err != nil {
return err
@ -84,17 +85,17 @@ func writeClientConnectRequest(w io.Writer, req *ClientConnectRequest) error {
return writeDataBlock(w, bs)
}
func readServerConnectResponse(r io.Reader) (*ServerConnectResponse, error) {
func readServerConnectResponse(r io.Reader) (*pb.ServerConnectResponse, error) {
bs, err := readDataBlock(r)
if err != nil {
return nil, err
}
var resp ServerConnectResponse
var resp pb.ServerConnectResponse
err = proto.Unmarshal(bs, &resp)
return &resp, err
}
func writeServerConnectResponse(w io.Writer, resp *ServerConnectResponse) error {
func writeServerConnectResponse(w io.Writer, resp *pb.ServerConnectResponse) error {
bs, err := proto.Marshal(resp)
if err != nil {
return err

View file

@ -1,74 +0,0 @@
package core
import (
"crypto/tls"
"github.com/lucas-clemente/quic-go"
"github.com/tobyxdd/hysteria/internal/core"
"io"
"net"
)
type AuthResult int32
const (
AuthSuccess = AuthResult(iota)
AuthInvalidCred
AuthInternalError
)
type ConnectResult int32
const (
ConnSuccess = ConnectResult(iota)
ConnFailed
ConnBlocked
)
type CongestionFactory core.CongestionFactory
type Obfuscator core.Obfuscator
type ClientAuthFunc func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (AuthResult, string)
type ClientDisconnectedFunc core.ClientDisconnectedFunc
type HandleRequestFunc func(addr net.Addr, username string, id int, packet bool, reqAddr string) (ConnectResult, string, io.ReadWriteCloser)
type RequestClosedFunc func(addr net.Addr, username string, id int, packet bool, reqAddr string, err error)
type Server interface {
Serve() error
Stats() (inbound uint64, outbound uint64)
Close() error
}
func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config,
sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory,
obfuscator Obfuscator,
clientAuthFunc ClientAuthFunc,
clientDisconnectedFunc ClientDisconnectedFunc,
handleRequestFunc HandleRequestFunc,
requestClosedFunc RequestClosedFunc) (Server, error) {
return core.NewServer(addr, tlsConfig, quicConfig, sendBPS, recvBPS, core.CongestionFactory(congestionFactory),
core.Obfuscator(obfuscator),
func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (core.AuthResult, string) {
r, msg := clientAuthFunc(addr, username, password, sSend, sRecv)
return core.AuthResult(r), msg
},
core.ClientDisconnectedFunc(clientDisconnectedFunc),
func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string) (core.ConnectResult, string, io.ReadWriteCloser) {
r, msg, conn := handleRequestFunc(addr, username, id, reqType == core.ConnectionType_Packet, reqAddr)
return core.ConnectResult(r), msg, conn
},
func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string, err error) {
requestClosedFunc(addr, username, id, reqType == core.ConnectionType_Packet, reqAddr, err)
})
}
type Client interface {
Dial(packet bool, addr string) (net.Conn, error)
Stats() (inbound uint64, outbound uint64)
Close() error
}
func NewClient(serverAddr string, username string, password string,
tlsConfig *tls.Config, quicConfig *quic.Config, sendBPS uint64, recvBPS uint64,
congestionFactory CongestionFactory, obfuscator Obfuscator) (Client, error) {
return core.NewClient(serverAddr, username, password, tlsConfig, quicConfig, sendBPS, recvBPS,
core.CongestionFactory(congestionFactory), core.Obfuscator(obfuscator))
}

View file

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: control.proto
package core
package pb
import (
fmt "fmt"

View file

@ -1,3 +1,3 @@
package core
package pb
//go:generate protoc --go_out=. control.proto

View file

@ -6,12 +6,34 @@ import (
"errors"
"fmt"
"github.com/lucas-clemente/quic-go"
"github.com/tobyxdd/hysteria/internal/utils"
"github.com/tobyxdd/hysteria/pkg/core/pb"
"github.com/tobyxdd/hysteria/pkg/utils"
"io"
"net"
"sync/atomic"
)
type AuthResult int32
type ConnectionType int32
type ConnectResult int32
const (
AuthResultSuccess AuthResult = iota
AuthResultInvalidCred
AuthResultInternalError
)
const (
ConnectionTypeStream ConnectionType = iota
ConnectionTypePacket
)
const (
ConnectResultSuccess ConnectResult = iota
ConnectResultFailed
ConnectResultBlocked
)
type ClientAuthFunc func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (AuthResult, string)
type ClientDisconnectedFunc func(addr net.Addr, username string, err error)
type HandleRequestFunc func(addr net.Addr, username string, id int, reqType ConnectionType, reqAddr string) (ConnectResult, string, io.ReadWriteCloser)
@ -140,10 +162,10 @@ func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) (strin
authResult, msg := s.clientAuthFunc(cs.RemoteAddr(), req.Credential.Username, req.Credential.Password,
serverSendBPS, serverReceiveBPS)
// Response
err = writeServerAuthResponse(stream, &ServerAuthResponse{
Result: authResult,
err = writeServerAuthResponse(stream, &pb.ServerAuthResponse{
Result: pb.AuthResult(authResult),
Message: msg,
Speed: &Speed{
Speed: &pb.Speed{
SendBps: serverSendBPS,
ReceiveBps: serverReceiveBPS,
},
@ -152,10 +174,10 @@ func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) (strin
return "", false, err
}
// Set the congestion accordingly
if authResult == AuthResult_AUTH_SUCCESS && s.congestionFactory != nil {
if authResult == AuthResultSuccess && s.congestionFactory != nil {
cs.SetCongestion(s.congestionFactory(serverSendBPS))
}
return req.Credential.Username, authResult == AuthResult_AUTH_SUCCESS, nil
return req.Credential.Username, authResult == AuthResultSuccess, nil
}
func (s *Server) handleStream(localAddr net.Addr, remoteAddr net.Addr, username string, stream quic.Stream) {
@ -166,30 +188,30 @@ func (s *Server) handleStream(localAddr net.Addr, remoteAddr net.Addr, username
return
}
// Create connection with the handler
result, msg, conn := s.handleRequestFunc(remoteAddr, username, int(stream.StreamID()), req.Type, req.Address)
result, msg, conn := s.handleRequestFunc(remoteAddr, username, int(stream.StreamID()), ConnectionType(req.Type), req.Address)
defer func() {
if conn != nil {
_ = conn.Close()
}
}()
// Send response
err = writeServerConnectResponse(stream, &ServerConnectResponse{
Result: result,
err = writeServerConnectResponse(stream, &pb.ServerConnectResponse{
Result: pb.ConnectResult(result),
Message: msg,
})
if err != nil {
s.requestClosedFunc(remoteAddr, username, int(stream.StreamID()), req.Type, req.Address, err)
s.requestClosedFunc(remoteAddr, username, int(stream.StreamID()), ConnectionType(req.Type), req.Address, err)
return
}
if result != ConnectResult_CONN_SUCCESS {
s.requestClosedFunc(remoteAddr, username, int(stream.StreamID()), req.Type, req.Address,
fmt.Errorf("handler returned an unsuccessful state %s (msg: %s)", result.String(), msg))
if result != ConnectResultSuccess {
s.requestClosedFunc(remoteAddr, username, int(stream.StreamID()), ConnectionType(req.Type), req.Address,
fmt.Errorf("handler returned an unsuccessful state %d (msg: %s)", result, msg))
return
}
switch req.Type {
case ConnectionType_Stream:
case pb.ConnectionType_Stream:
err = utils.PipePair(stream, conn, &s.outboundBytes, &s.inboundBytes)
case ConnectionType_Packet:
case pb.ConnectionType_Packet:
err = utils.PipePair(&utils.PacketWrapperConn{Orig: &utils.QUICStreamWrapperConn{
Orig: stream,
PseudoLocalAddr: localAddr,
@ -198,5 +220,5 @@ func (s *Server) handleStream(localAddr net.Addr, remoteAddr net.Addr, username
default:
err = fmt.Errorf("unsupported connection type %s", req.Type.String())
}
s.requestClosedFunc(remoteAddr, username, int(stream.StreamID()), req.Type, req.Address, err)
s.requestClosedFunc(remoteAddr, username, int(stream.StreamID()), ConnectionType(req.Type), req.Address, err)
}

View file

@ -14,7 +14,7 @@ import (
"github.com/tobyxdd/hysteria/pkg/core"
)
func NewProxyHTTPServer(hyClient core.Client, idleTimeout time.Duration, aclEngine *acl.Engine,
func NewProxyHTTPServer(hyClient *core.Client, idleTimeout time.Duration, aclEngine *acl.Engine,
newDialFunc func(reqAddr string, action acl.Action, arg string),
basicAuthFunc func(user, password string) bool) (*goproxy.ProxyHttpServer, error) {
proxy := goproxy.NewProxyHttpServer()

View file

@ -4,9 +4,9 @@ import (
"encoding/binary"
"errors"
"fmt"
"github.com/tobyxdd/hysteria/internal/utils"
"github.com/tobyxdd/hysteria/pkg/acl"
"github.com/tobyxdd/hysteria/pkg/core"
"github.com/tobyxdd/hysteria/pkg/utils"
"io"
"strconv"
)
@ -23,7 +23,7 @@ var (
)
type Server struct {
HyClient core.Client
HyClient *core.Client
AuthFunc func(username, password string) bool
Method byte
TCPAddr *net.TCPAddr
@ -41,7 +41,7 @@ type Server struct {
tcpListener *net.TCPListener
}
func NewServer(hyClient core.Client, addr string, authFunc func(username, password string) bool, tcpDeadline int,
func NewServer(hyClient *core.Client, addr string, authFunc func(username, password string) bool, tcpDeadline int,
aclEngine *acl.Engine, disableUDP bool,
newReqFunc func(addr net.Addr, reqAddr string, action acl.Action, arg string),
reqClosedFunc func(addr net.Addr, reqAddr string, err error),