package core import ( "context" "crypto/tls" "errors" "fmt" "github.com/lucas-clemente/quic-go" "github.com/tobyxdd/hysteria/internal/utils" "io" "net" "sync/atomic" ) type ClientAuthFunc func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (AuthResult, string) type ClientDisconnectedFunc func(addr net.Addr, username string, err error) type HandleRequestFunc func(addr net.Addr, username string, id int, reqType ConnectionType, reqAddr string) (ConnectResult, string, io.ReadWriteCloser) type RequestClosedFunc func(addr net.Addr, username string, id int, reqType ConnectionType, reqAddr string, err error) type Server struct { inboundBytes, outboundBytes uint64 // atomic listener quic.Listener sendBPS, recvBPS uint64 congestionFactory CongestionFactory clientAuthFunc ClientAuthFunc clientDisconnectedFunc ClientDisconnectedFunc handleRequestFunc HandleRequestFunc requestClosedFunc RequestClosedFunc } func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config, sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, obfuscator Obfuscator, clientAuthFunc ClientAuthFunc, clientDisconnectedFunc ClientDisconnectedFunc, handleRequestFunc HandleRequestFunc, requestClosedFunc RequestClosedFunc) (*Server, error) { packetConn, err := net.ListenPacket("udp", addr) if err != nil { return nil, err } if obfuscator != nil { // Wrap PacketConn with obfuscator packetConn = &obfsPacketConn{ Orig: packetConn, Obfuscator: obfuscator, } } listener, err := quic.Listen(packetConn, tlsConfig, quicConfig) if err != nil { return nil, err } s := &Server{ listener: listener, sendBPS: sendBPS, recvBPS: recvBPS, congestionFactory: congestionFactory, clientAuthFunc: clientAuthFunc, clientDisconnectedFunc: clientDisconnectedFunc, handleRequestFunc: handleRequestFunc, requestClosedFunc: requestClosedFunc, } return s, nil } func (s *Server) Serve() error { for { cs, err := s.listener.Accept(context.Background()) if err != nil { return err } go s.handleClient(cs) } } func (s *Server) Stats() (uint64, uint64) { return atomic.LoadUint64(&s.inboundBytes), atomic.LoadUint64(&s.outboundBytes) } func (s *Server) Close() error { return s.listener.Close() } func (s *Server) handleClient(cs quic.Session) { // Expect the client to create a control stream to send its own information ctx, ctxCancel := context.WithTimeout(context.Background(), controlStreamTimeout) ctlStream, err := cs.AcceptStream(ctx) ctxCancel() if err != nil { _ = cs.CloseWithError(closeErrorCodeProtocolFailure, "control stream error") return } // Handle the control stream username, ok, err := s.handleControlStream(cs, ctlStream) if err != nil { _ = cs.CloseWithError(closeErrorCodeProtocolFailure, "control stream handling error") return } if !ok { _ = cs.CloseWithError(closeErrorCodeGeneric, "authentication failure") return } // Start accepting streams var closeErr error for { stream, err := cs.AcceptStream(context.Background()) if err != nil { closeErr = err break } go s.handleStream(cs.RemoteAddr(), username, stream) } s.clientDisconnectedFunc(cs.RemoteAddr(), username, closeErr) _ = cs.CloseWithError(closeErrorCodeGeneric, "generic") } // Auth & negotiate speed func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) (string, bool, error) { req, err := readClientAuthRequest(stream) if err != nil { return "", false, err } // Speed if req.Speed == nil || req.Speed.SendBps == 0 || req.Speed.ReceiveBps == 0 { return "", false, errors.New("incorrect speed provided by the client") } serverSendBPS, serverReceiveBPS := req.Speed.ReceiveBps, req.Speed.SendBps if s.sendBPS > 0 && serverSendBPS > s.sendBPS { serverSendBPS = s.sendBPS } if s.recvBPS > 0 && serverReceiveBPS > s.recvBPS { serverReceiveBPS = s.recvBPS } // Auth if req.Credential == nil { return "", false, errors.New("incorrect credential provided by the client") } authResult, msg := s.clientAuthFunc(cs.RemoteAddr(), req.Credential.Username, req.Credential.Password, serverSendBPS, serverReceiveBPS) // Response err = writeServerAuthResponse(stream, &ServerAuthResponse{ Result: authResult, Message: msg, Speed: &Speed{ SendBps: serverSendBPS, ReceiveBps: serverReceiveBPS, }, }) if err != nil { return "", false, err } // Set the congestion accordingly if authResult == AuthResult_AUTH_SUCCESS && s.congestionFactory != nil { cs.SetCongestion(s.congestionFactory(serverSendBPS)) } return req.Credential.Username, authResult == AuthResult_AUTH_SUCCESS, nil } func (s *Server) handleStream(addr net.Addr, username string, stream quic.Stream) { defer stream.Close() // Read request req, err := readClientConnectRequest(stream) if err != nil { return } // Create connection with the handler result, msg, conn := s.handleRequestFunc(addr, username, int(stream.StreamID()), req.Type, req.Address) defer func() { if conn != nil { _ = conn.Close() } }() // Send response err = writeServerConnectResponse(stream, &ServerConnectResponse{ Result: result, Message: msg, }) if err != nil { s.requestClosedFunc(addr, username, int(stream.StreamID()), req.Type, req.Address, err) return } if result != ConnectResult_CONN_SUCCESS { s.requestClosedFunc(addr, username, int(stream.StreamID()), req.Type, req.Address, fmt.Errorf("handler returned an unsuccessful state %s (msg: %s)", result.String(), msg)) return } switch req.Type { case ConnectionType_Stream: err = utils.PipePair(stream, conn, &s.outboundBytes, &s.inboundBytes) case ConnectionType_Packet: err = utils.PipePair(&utils.PacketReadWriteCloser{Orig: stream}, conn, &s.outboundBytes, &s.inboundBytes) default: err = fmt.Errorf("unsupported connection type %s", req.Type.String()) } s.requestClosedFunc(addr, username, int(stream.StreamID()), req.Type, req.Address, err) }