diff --git a/Makefile.am b/Makefile.am index 1a63b8f4..a642e343 100644 --- a/Makefile.am +++ b/Makefile.am @@ -1,4 +1,4 @@ -SUBDIRS = po intl lib src test doc +SUBDIRS = po intl lib deps src test doc ACLOCAL_AMFLAGS = -I m4 --install diff --git a/configure.ac b/configure.ac index a617a8f6..281bc12a 100644 --- a/configure.ac +++ b/configure.ac @@ -2,7 +2,9 @@ # Process this file with autoconf to produce a configure script. # AC_PREREQ([2.67]) +LT_PREREQ([2.2.6]) AC_INIT([aria2],[1.14.2],[t-tujikawa@users.sourceforge.net],[aria2],[http://aria2.sourceforge.net/]) +LT_INIT() AC_CANONICAL_HOST AC_CANONICAL_TARGET AM_INIT_AUTOMAKE() @@ -44,7 +46,6 @@ AC_PROG_CXX AC_PROG_CC AC_PROG_INSTALL AC_PROG_MKDIR_P -AC_PROG_RANLIB AC_PROG_YACC AC_CHECK_TOOL([AR], [ar], [:]) @@ -526,6 +527,10 @@ if test "x$have_option_const_name" = "xyes"; then AC_DEFINE([HAVE_OPTION_CONST_NAME], [1], [Define 1 if struct option.name is const char*]) fi +AC_CONFIG_SUBDIRS([deps/wslay]) +LIBS="\$(top_builddir)/deps/wslay/lib/libwslay.la $LIBS" +CPPFLAGS="-I\$(top_builddir)/deps/wslay/lib/includes $CPPFLAGS" + AC_CONFIG_FILES([Makefile src/Makefile test/Makefile @@ -533,7 +538,8 @@ AC_CONFIG_FILES([Makefile intl/Makefile lib/Makefile doc/Makefile - doc/ru/Makefile]) + doc/ru/Makefile + deps/Makefile]) AC_OUTPUT echo " " diff --git a/deps/Makefile.am b/deps/Makefile.am new file mode 100644 index 00000000..e42b32ee --- /dev/null +++ b/deps/Makefile.am @@ -0,0 +1 @@ +SUBDIRS = wslay diff --git a/src/AbstractHttpServerResponseCommand.cc b/src/AbstractHttpServerResponseCommand.cc new file mode 100644 index 00000000..d60e81c4 --- /dev/null +++ b/src/AbstractHttpServerResponseCommand.cc @@ -0,0 +1,101 @@ +/* */ +#include "AbstractHttpServerResponseCommand.h" +#include "SocketCore.h" +#include "DownloadEngine.h" +#include "HttpServer.h" +#include "Logger.h" +#include "LogFactory.h" +#include "HttpServerCommand.h" +#include "RequestGroupMan.h" +#include "RecoverableException.h" +#include "wallclock.h" +#include "util.h" +#include "fmt.h" + +namespace aria2 { + +AbstractHttpServerResponseCommand::AbstractHttpServerResponseCommand +(cuid_t cuid, + const SharedHandle& httpServer, + DownloadEngine* e, + const SharedHandle& socket) + : Command(cuid), + e_(e), + socket_(socket), + httpServer_(httpServer) +{ + setStatus(Command::STATUS_ONESHOT_REALTIME); + e_->addSocketForWriteCheck(socket_, this); +} + +AbstractHttpServerResponseCommand::~AbstractHttpServerResponseCommand() +{ + e_->deleteSocketForWriteCheck(socket_, this); +} + +bool AbstractHttpServerResponseCommand::execute() +{ + if(e_->getRequestGroupMan()->downloadFinished() || e_->isHaltRequested()) { + return true; + } + try { + httpServer_->sendResponse(); + } catch(RecoverableException& e) { + A2_LOG_INFO_EX + (fmt("CUID#%lld - Error occurred while transmitting response body.", + getCuid()), + e); + return true; + } + if(httpServer_->sendBufferIsEmpty()) { + A2_LOG_INFO(fmt("CUID#%lld - HttpServer: all response transmitted.", + getCuid())); + afterSend(httpServer_, e_); + return true; + } else { + if(timeoutTimer_.difference(global::wallclock()) >= 10) { + A2_LOG_INFO(fmt("CUID#%lld - HttpServer: Timeout while trasmitting" + " response.", + getCuid())); + return true; + } else { + e_->addCommand(this); + return false; + } + } +} + +} // namespace aria2 diff --git a/src/AbstractHttpServerResponseCommand.h b/src/AbstractHttpServerResponseCommand.h new file mode 100644 index 00000000..7dbc71dd --- /dev/null +++ b/src/AbstractHttpServerResponseCommand.h @@ -0,0 +1,75 @@ +/* */ +#ifndef D_ABSTRACT_HTTP_SERVER_RESPONSE_COMMAND_H +#define D_ABSTRACT_HTTP_SERVER_RESPONSE_COMMAND_H + +#include "Command.h" +#include "SharedHandle.h" +#include "TimerA2.h" + +namespace aria2 { + +class DownloadEngine; +class SocketCore; +class HttpServer; + +class AbstractHttpServerResponseCommand : public Command { +private: + DownloadEngine* e_; + SharedHandle socket_; + SharedHandle httpServer_; + Timer timeoutTimer_; +protected: + DownloadEngine* getDownloadEngine() + { + return e_; + } + // Called after content body is completely sent. + virtual void afterSend(const SharedHandle& httpServer, + DownloadEngine* e) = 0; +public: + AbstractHttpServerResponseCommand(cuid_t cuid, + const SharedHandle& httpServer, + DownloadEngine* e, + const SharedHandle& socket); + + virtual ~AbstractHttpServerResponseCommand(); + + virtual bool execute(); +}; + +} // namespace aria2 + +#endif // D_ABSTRACT_HTTP_SERVER_RESPONSE_COMMAND_H diff --git a/src/DefaultPieceStorage.cc b/src/DefaultPieceStorage.cc index a338b6bd..ff0cc1bb 100644 --- a/src/DefaultPieceStorage.cc +++ b/src/DefaultPieceStorage.cc @@ -63,6 +63,8 @@ #include "PieceStatMan.h" #include "wallclock.h" #include "bitfield.h" +#include "SingletonHolder.h" +#include "Notifier.h" #ifdef ENABLE_BITTORRENT # include "bittorrent_helper.h" #endif // ENABLE_BITTORRENT @@ -479,6 +481,9 @@ void DefaultPieceStorage::completePiece(const SharedHandle& piece) if(!torrentAttrs->metadata.empty()) { util::executeHookByOptName(downloadContext_->getOwnerRequestGroup(), option_, PREF_ON_BT_DOWNLOAD_COMPLETE); + SingletonHolder::instance()-> + notifyDownloadEvent(Notifier::ON_BT_DOWNLOAD_COMPLETE, + downloadContext_->getOwnerRequestGroup()); } } #endif // ENABLE_BITTORRENT diff --git a/src/HttpServer.cc b/src/HttpServer.cc index e797dbff..b48ecebd 100644 --- a/src/HttpServer.cc +++ b/src/HttpServer.cc @@ -212,6 +212,20 @@ void HttpServer::feedResponse(const std::string& status, socketBuffer_.pushStr(text); } +void HttpServer::feedUpgradeResponse(const std::string& protocol, + const std::string& headers) +{ + std::string header= fmt("HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: %s\r\n" + "Connection: Upgrade\r\n" + "%s" + "\r\n", + protocol.c_str(), + headers.c_str()); + A2_LOG_DEBUG(fmt("HTTP Server sends upgrade response:\n%s", header.c_str())); + socketBuffer_.pushStr(header); +} + ssize_t HttpServer::sendResponse() { return socketBuffer_.send(); diff --git a/src/HttpServer.h b/src/HttpServer.h index f5ebb4ee..e7d7c2dc 100644 --- a/src/HttpServer.h +++ b/src/HttpServer.h @@ -91,6 +91,12 @@ public: std::string& text, const std::string& contentType); + // Feeds "101 Switching Protocols" response. The |protocol| will + // appear in Upgrade header field. The |headers| is zero or more + // lines of HTTP header field and each line must end with "\r\n". + void feedUpgradeResponse(const std::string& protocol, + const std::string& headers); + bool authenticate(); void setUsernamePassword @@ -129,6 +135,11 @@ public: { allowOrigin_ = allowOrigin; } + + const SharedHandle& getSocket() const + { + return socket_; + } }; } // namespace aria2 diff --git a/src/HttpServerBodyCommand.cc b/src/HttpServerBodyCommand.cc index 9e0db7a9..48581532 100644 --- a/src/HttpServerBodyCommand.cc +++ b/src/HttpServerBodyCommand.cc @@ -86,21 +86,6 @@ HttpServerBodyCommand::~HttpServerBodyCommand() e_->deleteSocketForReadCheck(socket_, this); } -namespace { -rpc::RpcResponse -createJsonRpcErrorResponse -(int code, - const std::string& msg, - const SharedHandle& id) -{ - SharedHandle params = Dict::g(); - params->put("code", Integer::g(code)); - params->put("message", msg); - rpc::RpcResponse res(code, params, id); - return res; -} -} // namespace - namespace { std::string getJsonRpcContentType(bool script) { @@ -156,42 +141,6 @@ void HttpServerBodyCommand::addHttpServerResponseCommand() e_->setNoWait(true); } -rpc::RpcResponse -HttpServerBodyCommand::processJsonRpcRequest(const Dict* jsondict) -{ - - SharedHandle id = jsondict->get("id"); - if(!id) { - return createJsonRpcErrorResponse(-32600, "Invalid Request.", Null::g()); - } - const String* methodName = downcast(jsondict->get("method")); - if(!methodName) { - return createJsonRpcErrorResponse(-32600, "Invalid Request.", id); - } - SharedHandle params; - const SharedHandle& tempParams = jsondict->get("params"); - if(downcast(tempParams)) { - params = static_pointer_cast(tempParams); - } else if(!tempParams) { - params = List::g(); - } else { - // TODO No support for Named params - return createJsonRpcErrorResponse(-32602, "Invalid params.", id); - } - rpc::RpcRequest req(methodName->s(), params, id); - req.jsonRpc = true; - SharedHandle method; - try { - method = rpc::RpcMethodFactory::create(req.methodName); - } catch(RecoverableException& e) { - A2_LOG_INFO_EX(EX_EXCEPTION_CAUGHT, e); - return createJsonRpcErrorResponse(-32601, "Method not found.", id); - } - A2_LOG_INFO(fmt("Executing RPC method %s", req.methodName.c_str())); - rpc::RpcResponse res = method->execute(req, e_); - return res; -} - bool HttpServerBodyCommand::execute() { if(e_->getRequestGroupMan()->downloadFinished() || e_->isHaltRequested()) { @@ -242,13 +191,14 @@ bool HttpServerBodyCommand::execute() getCuid()), e); rpc::RpcResponse res - (createJsonRpcErrorResponse(-32700, "Parse error.", Null::g())); + (rpc::createJsonRpcErrorResponse(-32700, "Parse error.", + Null::g())); sendJsonRpcResponse(res, callback); return true; } const Dict* jsondict = downcast(json); if(jsondict) { - rpc::RpcResponse res = processJsonRpcRequest(jsondict); + rpc::RpcResponse res = rpc::processJsonRpcRequest(jsondict, e_); sendJsonRpcResponse(res, callback); } else { const List* jsonlist = downcast(json); @@ -259,14 +209,14 @@ bool HttpServerBodyCommand::execute() eoi = jsonlist->end(); i != eoi; ++i) { const Dict* jsondict = downcast(*i); if(jsondict) { - rpc::RpcResponse r = processJsonRpcRequest(jsondict); + rpc::RpcResponse r = rpc::processJsonRpcRequest(jsondict, e_); results.push_back(r); } } sendJsonRpcBatchResponse(results, callback); } else { rpc::RpcResponse res - (createJsonRpcErrorResponse + (rpc::createJsonRpcErrorResponse (-32600, "Invalid Request.", Null::g())); sendJsonRpcResponse(res, callback); } diff --git a/src/HttpServerBodyCommand.h b/src/HttpServerBodyCommand.h index f8fcc881..30ed7851 100644 --- a/src/HttpServerBodyCommand.h +++ b/src/HttpServerBodyCommand.h @@ -65,7 +65,6 @@ private: void sendJsonRpcBatchResponse (const std::vector& results, const std::string& callback); - rpc::RpcResponse processJsonRpcRequest(const Dict* jsondict); void addHttpServerResponseCommand(); public: HttpServerBodyCommand(cuid_t cuid, diff --git a/src/HttpServerCommand.cc b/src/HttpServerCommand.cc index d5653c92..18aae01c 100644 --- a/src/HttpServerCommand.cc +++ b/src/HttpServerCommand.cc @@ -43,6 +43,7 @@ #include "RequestGroupMan.h" #include "HttpServerBodyCommand.h" #include "HttpServerResponseCommand.h" +#include "WebSocketResponseCommand.h" #include "RecoverableException.h" #include "prefs.h" #include "Option.h" @@ -50,6 +51,9 @@ #include "wallclock.h" #include "fmt.h" #include "SocketRecvBuffer.h" +#include "MessageDigest.h" +#include "message_digest_helper.h" +#include "base64.h" namespace aria2 { @@ -104,6 +108,19 @@ void HttpServerCommand::checkSocketRecvBuffer() } } +// Creates server's WebSocket accept key which will be sent in +// Sec-WebSocket-Accept header field. The |clientKey| is the value +// found in Sec-WebSocket-Key header field in the request. +std::string createWebSocketServerKey(const std::string& clientKey) +{ + std::string src = clientKey; + src += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + unsigned char digest[20]; + message_digest::digest(digest, sizeof(digest), MessageDigest::sha1(), + src.c_str(), src.size()); + return base64::encode(&digest[0], &digest[sizeof(digest)]); +} + bool HttpServerCommand::execute() { if(e_->getRequestGroupMan()->downloadFinished() || e_->isHaltRequested()) { @@ -133,20 +150,41 @@ bool HttpServerCommand::execute() e_->setNoWait(true); return true; } - if(e_->getOption()->getAsInt(PREF_RPC_MAX_REQUEST_SIZE) < - httpServer_->getContentLength()) { - A2_LOG_INFO - (fmt("Request too long. ContentLength=%lld." - " See --rpc-max-request-size option to loose" - " this limitation.", - static_cast(httpServer_->getContentLength()))); + const std::string& upgradeHd = header->find("upgrade"); + const std::string& connectionHd = header->find("connection"); + if(httpServer_->getRequestPath() == "/jsonrpc" && + httpServer_->getMethod() == "GET" && + util::strieq(upgradeHd.begin(), upgradeHd.end(), "websocket") && + util::strieq(connectionHd.begin(), connectionHd.end(), "upgrade") && + header->find("sec-websocket-version") == "13" && + header->defined("sec-websocket-key")) { + std::string serverKey = + createWebSocketServerKey(header->find("sec-websocket-key")); + httpServer_->feedUpgradeResponse("websocket", + fmt("Sec-WebSocket-Accept: %s\r\n", + serverKey.c_str())); + Command* command = + new rpc::WebSocketResponseCommand(getCuid(), httpServer_, e_, + socket_); + e_->addCommand(command); + e_->setNoWait(true); + return true; + } else { + if(e_->getOption()->getAsInt(PREF_RPC_MAX_REQUEST_SIZE) < + httpServer_->getContentLength()) { + A2_LOG_INFO + (fmt("Request too long. ContentLength=%lld." + " See --rpc-max-request-size option to loose" + " this limitation.", + static_cast(httpServer_->getContentLength()))); + return true; + } + Command* command = new HttpServerBodyCommand(getCuid(), httpServer_, e_, + socket_); + e_->addCommand(command); + e_->setNoWait(true); return true; } - Command* command = new HttpServerBodyCommand(getCuid(), httpServer_, e_, - socket_); - e_->addCommand(command); - e_->setNoWait(true); - return true; } else { if(timeoutTimer_.difference(global::wallclock()) >= 30) { A2_LOG_INFO("HTTP request timeout."); diff --git a/src/HttpServerResponseCommand.cc b/src/HttpServerResponseCommand.cc index a44560e7..eabcf22b 100644 --- a/src/HttpServerResponseCommand.cc +++ b/src/HttpServerResponseCommand.cc @@ -40,9 +40,6 @@ #include "LogFactory.h" #include "HttpServerCommand.h" #include "RequestGroupMan.h" -#include "RecoverableException.h" -#include "wallclock.h" -#include "util.h" #include "fmt.h" namespace aria2 { @@ -52,54 +49,22 @@ HttpServerResponseCommand::HttpServerResponseCommand const SharedHandle& httpServer, DownloadEngine* e, const SharedHandle& socket) - : Command(cuid), - e_(e), - socket_(socket), - httpServer_(httpServer) -{ - setStatus(Command::STATUS_ONESHOT_REALTIME); - e_->addSocketForWriteCheck(socket_, this); -} + : AbstractHttpServerResponseCommand(cuid, httpServer, e, socket) +{} HttpServerResponseCommand::~HttpServerResponseCommand() -{ - e_->deleteSocketForWriteCheck(socket_, this); -} +{} -bool HttpServerResponseCommand::execute() +void HttpServerResponseCommand::afterSend +(const SharedHandle& httpServer, + DownloadEngine* e) { - if(e_->getRequestGroupMan()->downloadFinished() || e_->isHaltRequested()) { - return true; - } - try { - httpServer_->sendResponse(); - } catch(RecoverableException& e) { - A2_LOG_INFO_EX - (fmt("CUID#%lld - Error occurred while transmitting response body.", - getCuid()), - e); - return true; - } - if(httpServer_->sendBufferIsEmpty()) { - A2_LOG_INFO(fmt("CUID#%lld - HttpServer: all response transmitted.", + if(httpServer->supportsPersistentConnection()) { + A2_LOG_INFO(fmt("CUID#%lld - Persist connection.", getCuid())); - if(httpServer_->supportsPersistentConnection()) { - A2_LOG_INFO(fmt("CUID#%lld - Persist connection.", - getCuid())); - e_->addCommand - (new HttpServerCommand(getCuid(), httpServer_, e_, socket_)); - } - return true; - } else { - if(timeoutTimer_.difference(global::wallclock()) >= 10) { - A2_LOG_INFO(fmt("CUID#%lld - HttpServer: Timeout while trasmitting" - " response.", - getCuid())); - return true; - } else { - e_->addCommand(this); - return false; - } + e->addCommand + (new HttpServerCommand(getCuid(), httpServer, e, + httpServer->getSocket())); } } diff --git a/src/HttpServerResponseCommand.h b/src/HttpServerResponseCommand.h index 12bdf434..670dc286 100644 --- a/src/HttpServerResponseCommand.h +++ b/src/HttpServerResponseCommand.h @@ -35,22 +35,14 @@ #ifndef D_HTTP_SERVER_RESPONSE_COMMAND_H #define D_HTTP_SERVER_RESPONSE_COMMAND_H -#include "Command.h" -#include "SharedHandle.h" -#include "TimerA2.h" +#include "AbstractHttpServerResponseCommand.h" namespace aria2 { -class DownloadEngine; -class SocketCore; -class HttpServer; - -class HttpServerResponseCommand : public Command { -private: - DownloadEngine* e_; - SharedHandle socket_; - SharedHandle httpServer_; - Timer timeoutTimer_; +class HttpServerResponseCommand : public AbstractHttpServerResponseCommand { +protected: + virtual void afterSend(const SharedHandle& httpServer, + DownloadEngine* e); public: HttpServerResponseCommand(cuid_t cuid, const SharedHandle& httpServer, @@ -58,8 +50,6 @@ public: const SharedHandle& socket); virtual ~HttpServerResponseCommand(); - - virtual bool execute(); }; } // namespace aria2 diff --git a/src/Makefile.am b/src/Makefile.am index 359365e5..39461b11 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -208,6 +208,8 @@ SRCS = Socket.h\ HttpListenCommand.cc HttpListenCommand.h\ HttpServerCommand.cc HttpServerCommand.h\ HttpServerResponseCommand.cc HttpServerResponseCommand.h\ + AbstractHttpServerResponseCommand.cc \ + AbstractHttpServerResponseCommand.h \ HttpServer.cc HttpServer.h\ StreamPieceSelector.h\ DefaultStreamPieceSelector.cc DefaultStreamPieceSelector.h\ @@ -224,7 +226,12 @@ SRCS = Socket.h\ paramed_string.cc paramed_string.h\ rpc_helper.cc rpc_helper.h\ WatchProcessCommand.cc WatchProcessCommand.h\ - UnknownOptionException.cc UnknownOptionException.h + UnknownOptionException.cc UnknownOptionException.h\ + WebSocketSession.cc WebSocketSession.h\ + WebSocketSessionMan.cc WebSocketSessionMan.h\ + WebSocketResponseCommand.cc WebSocketResponseCommand.h\ + WebSocketInteractionCommand.cc WebSocketInteractionCommand.h\ + Notifier.cc Notifier.h if MINGW_BUILD SRCS += WinConsoleFile.cc WinConsoleFile.h diff --git a/src/MultiUrlRequestInfo.cc b/src/MultiUrlRequestInfo.cc index a754c1c6..0053a99d 100644 --- a/src/MultiUrlRequestInfo.cc +++ b/src/MultiUrlRequestInfo.cc @@ -61,6 +61,9 @@ #include "SocketCore.h" #include "OutputFile.h" #include "UriListParser.h" +#include "SingletonHolder.h" +#include "Notifier.h" +#include "WebSocketSessionMan.h" #ifdef ENABLE_SSL # include "TLSContext.h" #endif // ENABLE_SSL @@ -167,6 +170,13 @@ error_code::Value MultiUrlRequestInfo::execute() { error_code::Value returnValue = error_code::FINISHED; try { + SharedHandle wsSessionMan; + if(option_->getAsBool(PREF_ENABLE_RPC)) { + wsSessionMan.reset(new rpc::WebSocketSessionMan()); + } + Notifier notifier(wsSessionMan); + SingletonHolder::instance(¬ifier); + DownloadEngineHandle e = DownloadEngineFactory().newDownloadEngine(option_.get(), requestGroups_); @@ -290,6 +300,7 @@ error_code::Value MultiUrlRequestInfo::execute() } A2_LOG_ERROR_EX(EX_EXCEPTION_CAUGHT, e); } + SingletonHolder::instance(0); #ifdef SIGHUP util::setGlobalSignalHandler(SIGHUP, SIG_DFL, 0); #endif // SIGHUP diff --git a/src/Notifier.cc b/src/Notifier.cc new file mode 100644 index 00000000..8173db05 --- /dev/null +++ b/src/Notifier.cc @@ -0,0 +1,77 @@ +/* */ +#include "Notifier.h" +#include "RequestGroup.h" +#include "WebSocketSessionMan.h" +#include "LogFactory.h" + +namespace aria2 { + +Notifier::Notifier(const SharedHandle& wsSessionMan) + : wsSessionMan_(wsSessionMan) +{} + +Notifier::~Notifier() {} + +void Notifier::addWebSocketSession +(const SharedHandle& wsSession) +{ + A2_LOG_DEBUG("WebSocket session added."); + wsSessionMan_->addSession(wsSession); +} + +void Notifier::removeWebSocketSession +(const SharedHandle& wsSession) +{ + A2_LOG_DEBUG("WebSocket session removed."); + wsSessionMan_->removeSession(wsSession); +} + +const std::string Notifier::ON_DOWNLOAD_START = "onDownloadStart"; +const std::string Notifier::ON_DOWNLOAD_PAUSE = "onDownloadPause"; +const std::string Notifier::ON_DOWNLOAD_STOP = "onDownloadStop"; +const std::string Notifier::ON_DOWNLOAD_COMPLETE = "onDownloadComplete"; +const std::string Notifier::ON_DOWNLOAD_ERROR = "onDownloadError"; +const std::string Notifier::ON_BT_DOWNLOAD_COMPLETE = "onBtDownloadComplete"; + +void Notifier::notifyDownloadEvent +(const std::string& event, const RequestGroup* group) +{ + if(wsSessionMan_) { + wsSessionMan_->addNotification(event, group); + } +} + +} // namespace aria2 diff --git a/src/Notifier.h b/src/Notifier.h new file mode 100644 index 00000000..78ae89cf --- /dev/null +++ b/src/Notifier.h @@ -0,0 +1,82 @@ +/* */ +#ifndef D_NOTIFIER_H +#define D_NOTIFIER_H + +#include "common.h" +#include "SharedHandle.h" + +namespace aria2 { + +class RequestGroup; +class Option; +class Pref; + +namespace rpc { + +class WebSocketSessionMan; +class WebSocketSession; + +} // namespace rpc + +class Notifier { +public: + // The string constants for download events. + static const std::string ON_DOWNLOAD_START; + static const std::string ON_DOWNLOAD_PAUSE; + static const std::string ON_DOWNLOAD_STOP; + static const std::string ON_DOWNLOAD_COMPLETE; + static const std::string ON_DOWNLOAD_ERROR; + static const std::string ON_BT_DOWNLOAD_COMPLETE; + + Notifier(const SharedHandle& wsSessionMan); + ~Notifier(); + void addWebSocketSession(const SharedHandle& wsSes); + void removeWebSocketSession(const SharedHandle& wsSes); + // Notifies the download event to all listeners. + void notifyDownloadEvent(const std::string& event, const RequestGroup* group); + + void notifyDownloadEvent(const std::string& event, + const SharedHandle& group) + { + notifyDownloadEvent(event, group.get()); + } +private: + SharedHandle wsSessionMan_; +}; + +} // namespace aria2 + +#endif // D_NOTIFIER_H diff --git a/src/RequestGroupMan.cc b/src/RequestGroupMan.cc index 9f736952..241f72d1 100644 --- a/src/RequestGroupMan.cc +++ b/src/RequestGroupMan.cc @@ -76,6 +76,8 @@ #include "OutputFile.h" #include "download_helper.h" #include "UriListParser.h" +#include "SingletonHolder.h" +#include "Notifier.h" namespace aria2 { @@ -252,6 +254,16 @@ bool RequestGroupMan::removeReservedGroup(a2_gid_t gid) namespace { +void notifyDownloadEvent +(const std::string& event, const SharedHandle& group) +{ + SingletonHolder::instance()->notifyDownloadEvent(event, group); +} + +} // namespace + +namespace { + void executeStopHook (const SharedHandle& group, const Option* option, @@ -267,6 +279,14 @@ void executeStopHook } else if(!option->blank(PREF_ON_DOWNLOAD_STOP)) { util::executeHookByOptName(group, option, PREF_ON_DOWNLOAD_STOP); } + if(result == error_code::FINISHED) { + notifyDownloadEvent(Notifier::ON_DOWNLOAD_COMPLETE, group); + } else if(result != error_code::IN_PROGRESS && + result != error_code::REMOVED) { + notifyDownloadEvent(Notifier::ON_DOWNLOAD_ERROR, group); + } else { + notifyDownloadEvent(Notifier::ON_DOWNLOAD_STOP, group); + } } } // namespace @@ -353,8 +373,9 @@ public: reservedGroups_.push_front(group); group->releaseRuntimeResource(e_); group->setForceHaltRequested(false); - util::executeHookByOptName - (group, e_->getOption(), PREF_ON_DOWNLOAD_PAUSE); + util::executeHookByOptName(group, e_->getOption(), + PREF_ON_DOWNLOAD_PAUSE); + notifyDownloadEvent(Notifier::ON_DOWNLOAD_PAUSE, group); // TODO Should we have to prepend spend uris to remaining uris // in case PREF_REUSE_URI is disabed? } else { @@ -536,8 +557,9 @@ void RequestGroupMan::fillRequestGroupFromReserver(DownloadEngine* e) requestGroups_.push_back(groupToAdd); requestQueueCheck(); } - util::executeHookByOptName - (groupToAdd, e->getOption(), PREF_ON_DOWNLOAD_START); + util::executeHookByOptName(groupToAdd, e->getOption(), + PREF_ON_DOWNLOAD_START); + notifyDownloadEvent(Notifier::ON_DOWNLOAD_START, groupToAdd); } if(!temp.empty()) { reservedGroups_.insert(reservedGroups_.begin(), temp.begin(), temp.end()); diff --git a/src/SingletonHolder.h b/src/SingletonHolder.h index e6664bd0..7f8e2d69 100644 --- a/src/SingletonHolder.h +++ b/src/SingletonHolder.h @@ -40,25 +40,25 @@ namespace aria2 { template class SingletonHolder { private: - static T instance_; + static T* instance_; SingletonHolder() {} public: ~SingletonHolder() {} - static T& instance() + static T* instance() { return instance_; } - static void instance(T& instance) + static void instance(T* instance) { instance_ = instance; } }; template -T SingletonHolder::instance_; +T* SingletonHolder::instance_ = 0; } // namespace aria2 diff --git a/src/WebSocketInteractionCommand.cc b/src/WebSocketInteractionCommand.cc new file mode 100644 index 00000000..4daf1d25 --- /dev/null +++ b/src/WebSocketInteractionCommand.cc @@ -0,0 +1,111 @@ +/* */ +#include "WebSocketInteractionCommand.h" +#include "SocketCore.h" +#include "DownloadEngine.h" +#include "RequestGroupMan.h" +#include "WebSocketSession.h" +#include "Logger.h" +#include "LogFactory.h" +#include "fmt.h" +#include "SingletonHolder.h" +#include "Notifier.h" + +namespace aria2 { + +namespace rpc { + +WebSocketInteractionCommand::WebSocketInteractionCommand +(cuid_t cuid, + const SharedHandle& wsSession, + DownloadEngine* e, + const SharedHandle& socket) + : Command(cuid), + e_(e), + socket_(socket), + writeCheck_(false), + wsSession_(wsSession) +{ + SingletonHolder::instance()->addWebSocketSession(wsSession_); + e_->addSocketForReadCheck(socket_, this); +} + +WebSocketInteractionCommand::~WebSocketInteractionCommand() +{ + e_->deleteSocketForReadCheck(socket_, this); + if(writeCheck_) { + e_->deleteSocketForWriteCheck(socket_, this); + } + SingletonHolder::instance()->removeWebSocketSession(wsSession_); +} + +void WebSocketInteractionCommand::updateWriteCheck() +{ + if(wsSession_->wantWrite()) { + if(!writeCheck_) { + writeCheck_ = true; + e_->addSocketForWriteCheck(socket_, this); + } + } else if(writeCheck_) { + writeCheck_ = false; + e_->deleteSocketForWriteCheck(socket_, this); + } +} + +bool WebSocketInteractionCommand::execute() +{ + if(e_->isHaltRequested()) { + return true; + } + if(wsSession_->onReadEvent() == -1 || wsSession_->onWriteEvent() == -1) { + if(wsSession_->closeSent() || wsSession_->closeReceived()) { + A2_LOG_INFO(fmt("CUID#%lld - WebSocket session terminated.", getCuid())); + } else { + A2_LOG_INFO(fmt("CUID#%lld - WebSocket session terminated" + " (Possibly due to EOF).", getCuid())); + } + return true; + } + if(wsSession_->finish()) { + return true; + } + updateWriteCheck(); + e_->addCommand(this); + return false; +} + +} // namespace rpc + +} // namespace aria2 diff --git a/src/WebSocketInteractionCommand.h b/src/WebSocketInteractionCommand.h new file mode 100644 index 00000000..20f9896c --- /dev/null +++ b/src/WebSocketInteractionCommand.h @@ -0,0 +1,73 @@ +/* */ +#ifndef D_WEB_SOCKET_INTERACTION_COMMAND_H +#define D_WEB_SOCKET_INTERACTION_COMMAND_H + +#include "Command.h" +#include "SharedHandle.h" + +namespace aria2 { + +class DownloadEngine; +class SocketCore; + +namespace rpc { + +class WebSocketSession; + +class WebSocketInteractionCommand : public Command { +private: + DownloadEngine* e_; + SharedHandle socket_; + bool writeCheck_; + SharedHandle wsSession_; +public: + WebSocketInteractionCommand(cuid_t cuid, + const SharedHandle& wsSession, + DownloadEngine* e, + const SharedHandle& socket); + + virtual ~WebSocketInteractionCommand(); + + virtual bool execute(); + + void updateWriteCheck(); +}; + +} // namespace rpc + +} // namespace aria2 + +#endif // D_WEB_SOCKET_INTERACTION_COMMAND_H diff --git a/src/WebSocketResponseCommand.cc b/src/WebSocketResponseCommand.cc new file mode 100644 index 00000000..ba6add52 --- /dev/null +++ b/src/WebSocketResponseCommand.cc @@ -0,0 +1,72 @@ +/* */ +#include "WebSocketResponseCommand.h" +#include "SocketCore.h" +#include "DownloadEngine.h" +#include "HttpServer.h" +#include "WebSocketSession.h" +#include "WebSocketInteractionCommand.h" + +namespace aria2 { + +namespace rpc { + +WebSocketResponseCommand::WebSocketResponseCommand +(cuid_t cuid, + const SharedHandle& httpServer, + DownloadEngine* e, + const SharedHandle& socket) + : AbstractHttpServerResponseCommand(cuid, httpServer, e, socket) +{} + +WebSocketResponseCommand::~WebSocketResponseCommand() +{} + +void WebSocketResponseCommand::afterSend +(const SharedHandle& httpServer, + DownloadEngine* e) +{ + SharedHandle wsSession + (new WebSocketSession(httpServer->getSocket(), getDownloadEngine())); + WebSocketInteractionCommand* command = + new WebSocketInteractionCommand(getCuid(), wsSession, e, + wsSession->getSocket()); + wsSession->setCommand(command); + e->addCommand(command); +} + +} // namespace rpc + +} // namespace aria2 diff --git a/src/WebSocketResponseCommand.h b/src/WebSocketResponseCommand.h new file mode 100644 index 00000000..8a029686 --- /dev/null +++ b/src/WebSocketResponseCommand.h @@ -0,0 +1,61 @@ +/* */ +#ifndef D_WEB_SOCKET_RESPONSE_COMMAND_H +#define D_WEB_SOCKET_RESPONSE_COMMAND_H + +#include "AbstractHttpServerResponseCommand.h" + +namespace aria2 { + +namespace rpc { + +class WebSocketResponseCommand : public AbstractHttpServerResponseCommand { +protected: + virtual void afterSend(const SharedHandle& httpServer, + DownloadEngine* e); +public: + WebSocketResponseCommand(cuid_t cuid, + const SharedHandle& httpServer, + DownloadEngine* e, + const SharedHandle& socket); + + virtual ~WebSocketResponseCommand(); +}; + +} // namespace rpc + +} // namespace aria2 + +#endif // D_WEB_SOCKET_RESPONSE_COMMAND_H diff --git a/src/WebSocketSession.cc b/src/WebSocketSession.cc new file mode 100644 index 00000000..5fa416e0 --- /dev/null +++ b/src/WebSocketSession.cc @@ -0,0 +1,251 @@ +/* */ +#include "WebSocketSession.h" + +#include +#include + +#include "SocketCore.h" +#include "LogFactory.h" +#include "RecoverableException.h" +#include "message.h" +#include "DownloadEngine.h" +#include "rpc_helper.h" +#include "RpcResponse.h" +#include "json.h" + +namespace aria2 { + +namespace rpc { + +namespace { +ssize_t sendCallback(wslay_event_context_ptr wsctx, + const uint8_t* data, size_t len, int flags, + void* userData) +{ + WebSocketSession* session = reinterpret_cast(userData); + const SharedHandle& socket = session->getSocket(); + try { + ssize_t r = socket->writeData(data, len); + if(r == 0) { + if(socket->wantRead() || socket->wantWrite()) { + wslay_event_set_error(wsctx, WSLAY_ERR_WOULDBLOCK); + } else { + wslay_event_set_error(wsctx, WSLAY_ERR_CALLBACK_FAILURE); + } + r = -1; + } + return r; + } catch(RecoverableException& e) { + A2_LOG_DEBUG_EX(EX_EXCEPTION_CAUGHT, e); + wslay_event_set_error(wsctx, WSLAY_ERR_CALLBACK_FAILURE); + return -1; + } +} +} // namespace + +namespace { +ssize_t recvCallback(wslay_event_context_ptr wsctx, + uint8_t* buf, size_t len, int flags, + void* userData) +{ + WebSocketSession* session = reinterpret_cast(userData); + const SharedHandle& socket = session->getSocket(); + try { + ssize_t r; + socket->readData(buf, len); + if(len == 0) { + if(socket->wantRead() || socket->wantWrite()) { + wslay_event_set_error(wsctx, WSLAY_ERR_WOULDBLOCK); + } else { + wslay_event_set_error(wsctx, WSLAY_ERR_CALLBACK_FAILURE); + } + r = -1; + } else { + r = len; + } + return r; + } catch(RecoverableException& e) { + A2_LOG_DEBUG_EX(EX_EXCEPTION_CAUGHT, e); + wslay_event_set_error(wsctx, WSLAY_ERR_CALLBACK_FAILURE); + return -1; + } +} +} // namespace + +namespace { +void addResponse(WebSocketSession* wsSession, const RpcResponse& res) +{ + std::string response = toJson(res, "", false); + wsSession->addTextMessage(response); +} +} // namespace + +namespace { +void addResponse(WebSocketSession* wsSession, + const std::vector& results) +{ + std::string response = toJsonBatch(results, "", false); + wsSession->addTextMessage(response); +} +} // namespace + +namespace { +void onMsgRecvCallback(wslay_event_context_ptr wsctx, + const struct wslay_event_on_msg_recv_arg* arg, + void* userData) +{ + WebSocketSession* wsSession = reinterpret_cast(userData); + if(!wslay_is_ctrl_frame(arg->opcode)) { + // TODO Only process text frame + SharedHandle json; + try { + json = json::decode(arg->msg, arg->msg_length); + } catch(RecoverableException& e) { + A2_LOG_INFO_EX("Failed to parse JSON-RPC request", e); + RpcResponse res + (createJsonRpcErrorResponse(-32700, "Parse error.", Null::g())); + addResponse(wsSession, res); + return; + } + const Dict* jsondict = downcast(json); + if(jsondict) { + RpcResponse res = processJsonRpcRequest(jsondict, + wsSession->getDownloadEngine()); + addResponse(wsSession, res); + } else { + const List* jsonlist = downcast(json); + if(jsonlist) { + // This is batch call + std::vector results; + for(List::ValueType::const_iterator i = jsonlist->begin(), + eoi = jsonlist->end(); i != eoi; ++i) { + const Dict* jsondict = downcast(*i); + if(jsondict) { + RpcResponse r = processJsonRpcRequest + (jsondict, wsSession->getDownloadEngine()); + results.push_back(r); + } + } + addResponse(wsSession, results); + } else { + RpcResponse res(createJsonRpcErrorResponse + (-32600, "Invalid Request.", Null::g())); + addResponse(wsSession, res); + } + } + } else { + RpcResponse res(createJsonRpcErrorResponse + (-32600, "Invalid Request.", Null::g())); + addResponse(wsSession, res); + } +} +} // namespace + +WebSocketSession::WebSocketSession(const SharedHandle& socket, + DownloadEngine* e) + : socket_(socket), + e_(e) +{ + wslay_event_callbacks callbacks; + memset(&callbacks, 0, sizeof(wslay_event_callbacks)); + callbacks.recv_callback = recvCallback; + callbacks.send_callback = sendCallback; + callbacks.on_msg_recv_callback = onMsgRecvCallback; + int r = wslay_event_context_server_init(&wsctx_, &callbacks, this); + assert(r == 0); +} + +WebSocketSession::~WebSocketSession() +{ + wslay_event_context_free(wsctx_); +} + +bool WebSocketSession::wantRead() +{ + return wslay_event_want_read(wsctx_); +} + +bool WebSocketSession::wantWrite() +{ + return wslay_event_want_write(wsctx_); +} + +bool WebSocketSession::finish() +{ + return !wantRead() && !wantWrite(); +} + +int WebSocketSession::onReadEvent() +{ + if(wslay_event_recv(wsctx_) == 0) { + return 0; + } else { + return -1; + } +} + +int WebSocketSession::onWriteEvent() +{ + if(wslay_event_send(wsctx_) == 0) { + return 0; + } else { + return -1; + } +} + +void WebSocketSession::addTextMessage(const std::string& msg) +{ + // TODO Don't add text message if the size of outbound queue in + // wsctx_ exceeds certain limit. + wslay_event_msg arg = { + WSLAY_TEXT_FRAME, reinterpret_cast(msg.c_str()), msg.size() + }; + wslay_event_queue_msg(wsctx_, &arg); +} + +bool WebSocketSession::closeReceived() +{ + return wslay_event_get_close_received(wsctx_); +} + +bool WebSocketSession::closeSent() +{ + return wslay_event_get_close_sent(wsctx_); +} + +} // namespace rpc + +} // namespace aria2 diff --git a/src/WebSocketSession.h b/src/WebSocketSession.h new file mode 100644 index 00000000..44462197 --- /dev/null +++ b/src/WebSocketSession.h @@ -0,0 +1,111 @@ +/* */ +#ifndef D_WEB_SOCKET_SESSION_H +#define D_WEB_SOCKET_SESSION_H + +#include "common.h" + +#include + +#include "SharedHandle.h" + +namespace aria2 { + +class SocketCore; +class DownloadEngine; + +namespace rpc { + +class WebSocketInteractionCommand; + +class WebSocketSession { +public: + WebSocketSession(const SharedHandle& socket, + DownloadEngine* e); + ~WebSocketSession(); + // Returns true if this session object wants to read data from the + // remote endpoint. + bool wantRead(); + // Returns true if this session object wants to write data to the + // remote endpoint. + bool wantWrite(); + // Returns true if the session ended and the underlying connection + // can be closed. + bool finish(); + // Call this function when data is available to read. This function + // returns 0 if it succeeds, or -1. + int onReadEvent(); + // Call this function when data can be sent without blocking. This + // function returns 0 if it succeeds, or -1. + int onWriteEvent(); + // Adds text message |msg|. The message is queued and will be sent + // in onWriteEvent(). + void addTextMessage(const std::string& msg); + // Returns true if the close frame is received. + bool closeReceived(); + // Returns true if the close frame is sent. + bool closeSent(); + + const SharedHandle& getSocket() const + { + return socket_; + } + + DownloadEngine* getDownloadEngine() + { + return e_; + } + + WebSocketInteractionCommand* getCommand() + { + return command_; + } + + void setCommand(WebSocketInteractionCommand* command) + { + command_ = command; + } +private: + SharedHandle socket_; + DownloadEngine* e_; + wslay_event_context_ptr wsctx_; + WebSocketInteractionCommand* command_; +}; + +} // namespace rpc + +} // namespace aria2 + +#endif // D_WEB_SOCKET_SESSION_H diff --git a/src/WebSocketSessionMan.cc b/src/WebSocketSessionMan.cc new file mode 100644 index 00000000..15f6c935 --- /dev/null +++ b/src/WebSocketSessionMan.cc @@ -0,0 +1,83 @@ +/* */ +#include "WebSocketSessionMan.h" +#include "WebSocketSession.h" +#include "RequestGroup.h" +#include "json.h" +#include "util.h" +#include "WebSocketInteractionCommand.h" + +namespace aria2 { + +namespace rpc { + +WebSocketSessionMan::WebSocketSessionMan() {} + +WebSocketSessionMan::~WebSocketSessionMan() {} + +void WebSocketSessionMan::addSession +(const SharedHandle& wsSession) +{ + sessions_.insert(wsSession); +} + +void WebSocketSessionMan::removeSession +(const SharedHandle& wsSession) +{ + sessions_.erase(wsSession); +} + +void WebSocketSessionMan::addNotification +(const std::string& method, const RequestGroup* group) +{ + SharedHandle dict = Dict::g(); + dict->put("jsonrpc", "2.0"); + dict->put("method", method); + SharedHandle eventSpec = Dict::g(); + eventSpec->put("gid", util::itos(group->getGID())); + SharedHandle params = List::g(); + params->append(eventSpec); + dict->put("params", params); + std::string msg = json::encode(dict); + for(WebSocketSessions::const_iterator i = sessions_.begin(), + eoi = sessions_.end(); i != eoi; ++i) { + (*i)->addTextMessage(msg); + (*i)->getCommand()->updateWriteCheck(); + } +} + +} // namespace rpc + +} // namespace aria2 diff --git a/src/WebSocketSessionMan.h b/src/WebSocketSessionMan.h new file mode 100644 index 00000000..8004b485 --- /dev/null +++ b/src/WebSocketSessionMan.h @@ -0,0 +1,71 @@ +/* */ +#ifndef D_WEB_SOCKET_SESSION_MAN_H +#define D_WEB_SOCKET_SESSION_MAN_H + +#include "common.h" + +#include +#include + +#include "SharedHandle.h" +#include "a2functional.h" + +namespace aria2 { + +class RequestGroup; + +namespace rpc { + +class WebSocketSession; + +class WebSocketSessionMan { +public: + typedef std::set, + RefLess > WebSocketSessions; + WebSocketSessionMan(); + ~WebSocketSessionMan(); + void addSession(const SharedHandle& wsSession); + void removeSession(const SharedHandle& wsSession); + void addNotification(const std::string& method, const RequestGroup* group); +private: + WebSocketSessions sessions_; +}; + +} // namespace rpc + +} // aria2 + +#endif // D_WEB_SOCKET_SESSION_MAN_H diff --git a/src/a2functional.h b/src/a2functional.h index 9b5cab2d..d96b86e2 100644 --- a/src/a2functional.h +++ b/src/a2functional.h @@ -331,6 +331,14 @@ struct DerefEqual derefEqual(const T& t) return DerefEqual(t); } +template +struct RefLess { + bool operator()(const SharedHandle& lhs, const SharedHandle& rhs) const + { + return lhs.get() < rhs.get(); + } +}; + } // namespace aria2 #endif // D_A2_FUNCTIONAL_H diff --git a/src/json.cc b/src/json.cc index 51c75e9d..9094bcbb 100644 --- a/src/json.cc +++ b/src/json.cc @@ -50,11 +50,9 @@ namespace json { // Function prototype declaration namespace { -std::pair, std::string::const_iterator> -decode -(std::string::const_iterator first, - std::string::const_iterator last, - size_t depth); +template +std::pair, InputIterator> +decode(InputIterator first, InputIterator last, size_t depth); } // namespace namespace { @@ -64,9 +62,8 @@ const size_t MAX_STRUCTURE_DEPTH = 100; } // namespace namespace { -std::string::const_iterator skipWs -(std::string::const_iterator first, - std::string::const_iterator last) +template +InputIterator skipWs(InputIterator first, InputIterator last) { while(first != last && std::find(vbegin(WS), vend(WS), *first) != vend(WS)) { ++first; @@ -76,9 +73,8 @@ std::string::const_iterator skipWs } // namespace namespace { -void checkEof -(std::string::const_iterator first, - std::string::const_iterator last) +template +void checkEof(InputIterator first, InputIterator last) { if(first == last) { throw DL_ABORT_EX2("JSON decoding failed: unexpected EOF", @@ -88,11 +84,10 @@ void checkEof } // namespace namespace { -std::string::const_iterator -decodeKeyword -(std::string::const_iterator first, - std::string::const_iterator last, - const std::string& keyword) +template +InputIterator +decodeKeyword(InputIterator first, InputIterator last, + const std::string& keyword) { size_t len = keyword.size(); for(size_t i = 0; i < len; ++i) { @@ -109,10 +104,9 @@ decodeKeyword } // namespace namespace { -std::pair, std::string::const_iterator> -decodeTrue -(std::string::const_iterator first, - std::string::const_iterator last) +template +std::pair, InputIterator> +decodeTrue(InputIterator first, InputIterator last) { first = decodeKeyword(first, last, "true"); return std::make_pair(Bool::gTrue(), first); @@ -120,10 +114,9 @@ decodeTrue } // namespace namespace { -std::pair, std::string::const_iterator> -decodeFalse -(std::string::const_iterator first, - std::string::const_iterator last) +template +std::pair, InputIterator> +decodeFalse(InputIterator first, InputIterator last) { first = decodeKeyword(first, last, "false"); return std::make_pair(Bool::gFalse(), first); @@ -131,10 +124,9 @@ decodeFalse } // namespace namespace { -std::pair, std::string::const_iterator> -decodeNull -(std::string::const_iterator first, - std::string::const_iterator last) +template +std::pair, InputIterator> +decodeNull(InputIterator first, InputIterator last) { first = decodeKeyword(first, last, "null"); return std::make_pair(Null::g(), first); @@ -142,15 +134,14 @@ decodeNull } // namespace namespace { -std::pair, std::string::const_iterator> -decodeString -(std::string::const_iterator first, - std::string::const_iterator last) +template +std::pair, InputIterator> +decodeString(InputIterator first, InputIterator last) { // Consume first char, assuming it is '"'. ++first; std::string s; - std::string::const_iterator offset = first; + InputIterator offset = first; while(first != last) { if(*first == '"') { break; @@ -161,7 +152,7 @@ decodeString checkEof(first, last); if(*first == 'u') { ++first; - std::string::const_iterator uchars = first; + InputIterator uchars = first; for(int i = 0; i < 4; ++i, ++first) { checkEof(first, last); } @@ -182,7 +173,7 @@ decodeString error_code::JSON_PARSE_ERROR); } first += 2; - std::string::const_iterator uchars = first; + InputIterator uchars = first; for(int i = 0; i < 4; ++i, ++first) { checkEof(first, last); } @@ -242,9 +233,8 @@ decodeString } // namespace namespace { -void checkEmptyDigit -(std::string::const_iterator first, - std::string::const_iterator last) +template +void checkEmptyDigit(InputIterator first, InputIterator last) { if(std::distance(first, last) == 0) { throw DL_ABORT_EX2("JSON decoding failed: zero DIGIT.", @@ -254,9 +244,8 @@ void checkEmptyDigit } // namespace namespace { -void checkLeadingZero -(std::string::const_iterator first, - std::string::const_iterator last) +template +void checkLeadingZero(InputIterator first, InputIterator last) { if(std::distance(first, last) > 2 && *first == '0') { throw DL_ABORT_EX2("JSON decoding failed: leading zero.", @@ -266,17 +255,16 @@ void checkLeadingZero } // namespace namespace { -std::pair, std::string::const_iterator> -decodeNumber -(std::string::const_iterator first, - std::string::const_iterator last) +template +std::pair, InputIterator> +decodeNumber(InputIterator first, InputIterator last) { std::string s; if(*first == '-') { s.append(first, first+1); ++first; } - std::string::const_iterator offset = first; + InputIterator offset = first; while(first != last && in(*first, '0', '9')) { ++first; } @@ -336,11 +324,9 @@ void checkDepth(size_t depth) } // namespace namespace { -std::pair, std::string::const_iterator> -decodeArray -(std::string::const_iterator first, - std::string::const_iterator last, - size_t depth) +template +std::pair, InputIterator> +decodeArray(InputIterator first, InputIterator last, size_t depth) { checkDepth(depth); SharedHandle list = List::g(); @@ -350,7 +336,7 @@ decodeArray checkEof(first, last); if(*first != ']') { while(1) { - std::pair, std::string::const_iterator> + std::pair, InputIterator> r = decode(first, last, depth); list->append(r.first); first = r.second; @@ -372,11 +358,9 @@ decodeArray } // namespace namespace { -std::pair, std::string::const_iterator> -decodeObject -(std::string::const_iterator first, - std::string::const_iterator last, - size_t depth) +template +std::pair, InputIterator> +decodeObject(InputIterator first, InputIterator last, size_t depth) { checkDepth(depth); SharedHandle dict = Dict::g(); @@ -386,7 +370,7 @@ decodeObject checkEof(first, last); if(*first != '}') { while(1) { - std::pair, std::string::const_iterator> + std::pair, InputIterator> keyRet = decodeString(first, last); first = keyRet.second; first = skipWs(first, last); @@ -396,7 +380,7 @@ decodeObject error_code::JSON_PARSE_ERROR); } ++first; - std::pair, std::string::const_iterator> + std::pair, InputIterator> valueRet = decode(first, last, depth); dict->put(downcast(keyRet.first)->s(), valueRet.first); first = valueRet.second; @@ -420,11 +404,9 @@ decodeObject } // namespace namespace { -std::pair, std::string::const_iterator> -decode -(std::string::const_iterator first, - std::string::const_iterator last, - size_t depth) +template +std::pair, InputIterator> +decode(InputIterator first, InputIterator last, size_t depth) { first = skipWs(first, last); if(first == last) { @@ -454,17 +436,16 @@ decode } } // namespace -SharedHandle decode(const std::string& json) +template +SharedHandle decode(InputIterator first, InputIterator last) { - std::string::const_iterator first = json.begin(); - std::string::const_iterator last = json.end(); first = skipWs(first, last); if(first == last) { throw DL_ABORT_EX2("JSON decoding failed:" " Unexpected EOF in term context.", error_code::JSON_PARSE_ERROR); } - std::pair, std::string::const_iterator> r; + std::pair, InputIterator> r; if(*first == '[') { r = decodeArray(first, last, 1); } else if(*first == '{') { @@ -477,6 +458,16 @@ SharedHandle decode(const std::string& json) return r.first; } +SharedHandle decode(const std::string& json) +{ + return decode(json.begin(), json.end()); +} + +SharedHandle decode(const unsigned char* json, size_t len) +{ + return decode(json, json+len); +} + std::string jsonEscape(const std::string& s) { std::string t; diff --git a/src/json.h b/src/json.h index 3a106fdd..7538a2f2 100644 --- a/src/json.h +++ b/src/json.h @@ -45,6 +45,8 @@ namespace json { // Parses JSON text defined in RFC4627. SharedHandle decode(const std::string& json); +SharedHandle decode(const unsigned char* json, size_t len); + std::string jsonEscape(const std::string& s); template diff --git a/src/rpc_helper.cc b/src/rpc_helper.cc index a9d6e721..054c7fba 100644 --- a/src/rpc_helper.cc +++ b/src/rpc_helper.cc @@ -38,6 +38,12 @@ #include "XmlRpcRequestParserStateMachine.h" #include "message.h" #include "DlAbortEx.h" +#include "DownloadEngine.h" +#include "RpcMethod.h" +#include "RpcResponse.h" +#include "RpcMethodFactory.h" +#include "LogFactory.h" +#include "fmt.h" namespace aria2 { @@ -60,6 +66,51 @@ RpcRequest xmlParseMemory(const char* xml, size_t size) } #endif // ENABLE_XML_RPC +RpcResponse createJsonRpcErrorResponse(int code, + const std::string& msg, + const SharedHandle& id) +{ + SharedHandle params = Dict::g(); + params->put("code", Integer::g(code)); + params->put("message", msg); + rpc::RpcResponse res(code, params, id); + return res; +} + +RpcResponse processJsonRpcRequest(const Dict* jsondict, DownloadEngine* e) +{ + SharedHandle id = jsondict->get("id"); + if(!id) { + return createJsonRpcErrorResponse(-32600, "Invalid Request.", Null::g()); + } + const String* methodName = downcast(jsondict->get("method")); + if(!methodName) { + return createJsonRpcErrorResponse(-32600, "Invalid Request.", id); + } + SharedHandle params; + const SharedHandle& tempParams = jsondict->get("params"); + if(downcast(tempParams)) { + params = static_pointer_cast(tempParams); + } else if(!tempParams) { + params = List::g(); + } else { + // TODO No support for Named params + return createJsonRpcErrorResponse(-32602, "Invalid params.", id); + } + rpc::RpcRequest req(methodName->s(), params, id); + req.jsonRpc = true; + SharedHandle method; + try { + method = rpc::RpcMethodFactory::create(req.methodName); + } catch(RecoverableException& e) { + A2_LOG_INFO_EX(EX_EXCEPTION_CAUGHT, e); + return createJsonRpcErrorResponse(-32601, "Method not found.", id); + } + A2_LOG_INFO(fmt("Executing RPC method %s", req.methodName.c_str())); + rpc::RpcResponse res = method->execute(req, e); + return res; +} + } // namespace rpc } // namespace aria2 diff --git a/src/rpc_helper.h b/src/rpc_helper.h index f4c2ba69..4747a033 100644 --- a/src/rpc_helper.h +++ b/src/rpc_helper.h @@ -38,17 +38,34 @@ #include "common.h" #include +#include + +#include "SharedHandle.h" namespace aria2 { +class ValueBase; +class Dict; +class DownloadEngine; + namespace rpc { struct RpcRequest; +struct RpcResponse; #ifdef ENABLE_XML_RPC RpcRequest xmlParseMemory(const char* xml, size_t size); #endif // ENABLE_XML_RPC +// Creates error response. The |code| is the JSON-RPC error code. The +// |msg| is the error message. The |id| is the id of the request . +RpcResponse createJsonRpcErrorResponse(int code, + const std::string& msg, + const SharedHandle& id); + +// Processes JSON-RPC request |jsondict| and returns the result. +RpcResponse processJsonRpcRequest(const Dict* jsondict, DownloadEngine* e); + } // namespace rpc } // namespace aria2 diff --git a/test/SingletonHolderTest.cc b/test/SingletonHolderTest.cc index b2dc2c84..cb62bf4d 100644 --- a/test/SingletonHolderTest.cc +++ b/test/SingletonHolderTest.cc @@ -1,8 +1,11 @@ #include "SingletonHolder.h" -#include "SharedHandle.h" + #include + #include +#include "SharedHandle.h" + namespace aria2 { class SingletonHolderTest : public CppUnit::TestFixture { @@ -35,26 +38,16 @@ public: } }; -typedef SharedHandle MHandle; -typedef SharedHandle IntHandle; - void SingletonHolderTest::testInstance() { - MHandle m(new M("Hello world.")); - SingletonHolder::instance(m); - - std::cerr << SingletonHolder::instance()->greeting() << std::endl; - - SingletonHolder::instance()->greeting("Yes, it worked!"); - - std::cerr << SingletonHolder::instance()->greeting() << std::endl; - - IntHandle i(new int(100)); - SingletonHolder::instance(i); - std::cerr << *SingletonHolder::instance() << std::endl; - - std::cerr << SingletonHolder::instance()->greeting() << std::endl; + M m("Hello world."); + SingletonHolder::instance(&m); + CPPUNIT_ASSERT_EQUAL(std::string("Hello world."), + SingletonHolder::instance()->greeting()); + SingletonHolder::instance()->greeting("Yes, it worked!"); + CPPUNIT_ASSERT_EQUAL(std::string("Yes, it worked!"), + SingletonHolder::instance()->greeting()); } } // namespace aria2