diff --git a/src/BtAnnounce.h b/src/BtAnnounce.h index 6cdd6804..47523af2 100644 --- a/src/BtAnnounce.h +++ b/src/BtAnnounce.h @@ -40,9 +40,12 @@ #include #include "a2time.h" +#include "SharedHandle.h" namespace aria2 { +class UDPTrackerRequest; + class BtAnnounce { public: virtual ~BtAnnounce() {} @@ -65,6 +68,10 @@ public: */ virtual std::string getAnnounceUrl() = 0; + virtual SharedHandle + createUDPTrackerRequest(const std::string& remoteAddr, uint16_t remotePort, + uint16_t localPort) = 0; + /** * Tells that the announce process has just started. */ @@ -96,6 +103,9 @@ public: virtual void processAnnounceResponse(const unsigned char* trackerResponse, size_t trackerResponseLength) = 0; + virtual void processUDPTrackerResponse + (const SharedHandle& req) = 0; + /** * Returns true if no more announce is needed. */ diff --git a/src/BtRegistry.cc b/src/BtRegistry.cc index 1189a652..273f0130 100644 --- a/src/BtRegistry.cc +++ b/src/BtRegistry.cc @@ -42,12 +42,14 @@ #include "BtProgressInfoFile.h" #include "bittorrent_helper.h" #include "LpdMessageReceiver.h" +#include "UDPTrackerClient.h" #include "NullHandle.h" namespace aria2 { BtRegistry::BtRegistry() - : tcpPort_(0) + : tcpPort_(0), + udpPort_(0) {} BtRegistry::~BtRegistry() {} @@ -107,6 +109,12 @@ void BtRegistry::setLpdMessageReceiver lpdMessageReceiver_ = receiver; } +void BtRegistry::setUDPTrackerClient +(const SharedHandle& tracker) +{ + udpTrackerClient_ = tracker; +} + BtObject::BtObject (const SharedHandle& downloadContext, const SharedHandle& pieceStorage, diff --git a/src/BtRegistry.h b/src/BtRegistry.h index a4791210..06c0bfa4 100644 --- a/src/BtRegistry.h +++ b/src/BtRegistry.h @@ -51,6 +51,7 @@ class BtRuntime; class BtProgressInfoFile; class DownloadContext; class LpdMessageReceiver; +class UDPTrackerClient; struct BtObject { SharedHandle downloadContext; @@ -80,7 +81,10 @@ class BtRegistry { private: std::map > pool_; uint16_t tcpPort_; + // This is IPv4 port for DHT and UDP tracker. No IPv6 udpPort atm. + uint16_t udpPort_; SharedHandle lpdMessageReceiver_; + SharedHandle udpTrackerClient_; public: BtRegistry(); ~BtRegistry(); @@ -118,11 +122,26 @@ public: return tcpPort_; } + void setUdpPort(uint16_t port) + { + udpPort_ = port; + } + uint16_t getUdpPort() const + { + return udpPort_; + } + void setLpdMessageReceiver(const SharedHandle& receiver); const SharedHandle& getLpdMessageReceiver() const { return lpdMessageReceiver_; } + + void setUDPTrackerClient(const SharedHandle& tracker); + const SharedHandle& getUDPTrackerClient() const + { + return udpTrackerClient_; + } }; } // namespace aria2 diff --git a/src/BtSetup.cc b/src/BtSetup.cc index 1a8b6403..ce7f3bf8 100644 --- a/src/BtSetup.cc +++ b/src/BtSetup.cc @@ -67,6 +67,7 @@ #include "DHTMessageReceiver.h" #include "DHTMessageFactory.h" #include "DHTMessageCallback.h" +#include "UDPTrackerClient.h" #include "BtProgressInfoFile.h" #include "BtAnnounce.h" #include "BtRuntime.h" diff --git a/src/DHTInteractionCommand.cc b/src/DHTInteractionCommand.cc index 4fc3f371..096812a8 100644 --- a/src/DHTInteractionCommand.cc +++ b/src/DHTInteractionCommand.cc @@ -46,10 +46,16 @@ #include "LogFactory.h" #include "DHTMessageCallback.h" #include "DHTNode.h" +#include "DHTConnection.h" +#include "UDPTrackerClient.h" +#include "UDPTrackerRequest.h" #include "fmt.h" +#include "wallclock.h" namespace aria2 { +// TODO This name of this command is misleading, because now it also +// handles UDP trackers as well as DHT. DHTInteractionCommand::DHTInteractionCommand(cuid_t cuid, DownloadEngine* e) : Command(cuid), e_(e) @@ -77,23 +83,60 @@ void DHTInteractionCommand::disableReadCheckSocket(const SharedHandlegetRequestGroupMan()->downloadFinished() || e_->isHaltRequested()) { + // We need to keep this command alive while TrackerWatcherCommand + // needs this. + if(e_->getRequestGroupMan()->downloadFinished() || + (e_->isHaltRequested() && udpTrackerClient_->getNumWatchers() == 0)) { + return true; + } else if(e_->isForceHaltRequested()) { + udpTrackerClient_->failAll(); return true; } taskQueue_->executeTask(); - while(1) { - SharedHandle m = receiver_->receiveMessage(); - if(!m) { - break; + std::string remoteAddr; + uint16_t remotePort; + unsigned char data[64*1024]; + try { + while(1) { + ssize_t length = connection_->receiveMessage(data, sizeof(data), + remoteAddr, remotePort); + if(length <= 0) { + break; + } + if(data[0] == 'd') { + // udp tracker response does not start with 'd', so assume + // this message belongs to DHT. nothrow. + receiver_->receiveMessage(remoteAddr, remotePort, data, length); + } else { + // this may be udp tracker response. nothrow. + udpTrackerClient_->receiveReply(data, length, remoteAddr, remotePort, + global::wallclock()); + } } + } catch(RecoverableException& e) { + A2_LOG_INFO_EX("Exception thrown while receiving UDP message.", e); } receiver_->handleTimeout(); - try { - dispatcher_->sendMessages(); - } catch(RecoverableException& e) { - A2_LOG_ERROR_EX(EX_EXCEPTION_CAUGHT, e); + udpTrackerClient_->handleTimeout(global::wallclock()); + dispatcher_->sendMessages(); + while(!udpTrackerClient_->getPendingRequests().empty()) { + // no throw + ssize_t length = udpTrackerClient_->createRequest(data, sizeof(data), + remoteAddr, remotePort, + global::wallclock()); + if(length == -1) { + break; + } + try { + // throw + connection_->sendMessage(data, length, remoteAddr, remotePort); + udpTrackerClient_->requestSent(global::wallclock()); + } catch(RecoverableException& e) { + A2_LOG_INFO_EX("Exception thrown while sending UDP tracker request.", e); + udpTrackerClient_->requestFail(UDPT_ERR_NETWORK); + } } e_->addCommand(this); return false; @@ -114,4 +157,16 @@ void DHTInteractionCommand::setTaskQueue(const SharedHandle& taskQ taskQueue_ = taskQueue; } +void DHTInteractionCommand::setConnection +(const SharedHandle& connection) +{ + connection_ = connection; +} + +void DHTInteractionCommand::setUDPTrackerClient +(const SharedHandle& udpTrackerClient) +{ + udpTrackerClient_ = udpTrackerClient; +} + } // namespace aria2 diff --git a/src/DHTInteractionCommand.h b/src/DHTInteractionCommand.h index f00df78a..e90a1dd6 100644 --- a/src/DHTInteractionCommand.h +++ b/src/DHTInteractionCommand.h @@ -45,6 +45,8 @@ class DHTMessageReceiver; class DHTTaskQueue; class DownloadEngine; class SocketCore; +class DHTConnection; +class UDPTrackerClient; class DHTInteractionCommand:public Command { private: @@ -53,6 +55,8 @@ private: SharedHandle receiver_; SharedHandle taskQueue_; SharedHandle readCheckSocket_; + SharedHandle connection_; + SharedHandle udpTrackerClient_; public: DHTInteractionCommand(cuid_t cuid, DownloadEngine* e); @@ -69,6 +73,11 @@ public: void setMessageReceiver(const SharedHandle& receiver); void setTaskQueue(const SharedHandle& taskQueue); + + void setConnection(const SharedHandle& connection); + + void setUDPTrackerClient + (const SharedHandle& udpTrackerClient); }; } // namespace aria2 diff --git a/src/DHTMessageReceiver.cc b/src/DHTMessageReceiver.cc index e5558d9d..36e40639 100644 --- a/src/DHTMessageReceiver.cc +++ b/src/DHTMessageReceiver.cc @@ -63,18 +63,11 @@ DHTMessageReceiver::DHTMessageReceiver DHTMessageReceiver::~DHTMessageReceiver() {} -SharedHandle DHTMessageReceiver::receiveMessage() +SharedHandle DHTMessageReceiver::receiveMessage +(const std::string& remoteAddr, uint16_t remotePort, unsigned char *data, + size_t length) { - std::string remoteAddr; - uint16_t remotePort; - unsigned char data[64*1024]; try { - ssize_t length = connection_->receiveMessage(data, sizeof(data), - remoteAddr, - remotePort); - if(length <= 0) { - return SharedHandle(); - } bool isReply = false; SharedHandle decoded = bencode2::decode(data, length); const Dict* dict = downcast(decoded); @@ -87,13 +80,13 @@ SharedHandle DHTMessageReceiver::receiveMessage() } else { A2_LOG_INFO(fmt("Malformed DHT message. Missing 'y' key. From:%s:%u", remoteAddr.c_str(), remotePort)); - return handleUnknownMessage(data, sizeof(data), remoteAddr, remotePort); + return handleUnknownMessage(data, length, remoteAddr, remotePort); } } else { A2_LOG_INFO(fmt("Malformed DHT message. This is not a bencoded directory." " From:%s:%u", remoteAddr.c_str(), remotePort)); - return handleUnknownMessage(data, sizeof(data), remoteAddr, remotePort); + return handleUnknownMessage(data, length, remoteAddr, remotePort); } if(isReply) { std::pair, @@ -101,7 +94,7 @@ SharedHandle DHTMessageReceiver::receiveMessage() tracker_->messageArrived(dict, remoteAddr, remotePort); if(!p.first) { // timeout or malicious? message - return handleUnknownMessage(data, sizeof(data), remoteAddr, remotePort); + return handleUnknownMessage(data, length, remoteAddr, remotePort); } onMessageReceived(p.first); if(p.second) { @@ -114,14 +107,14 @@ SharedHandle DHTMessageReceiver::receiveMessage() if(*message->getLocalNode() == *message->getRemoteNode()) { // drop message from localnode A2_LOG_INFO("Received DHT message from localnode."); - return handleUnknownMessage(data, sizeof(data), remoteAddr, remotePort); + return handleUnknownMessage(data, length, remoteAddr, remotePort); } onMessageReceived(message); return message; } } catch(RecoverableException& e) { A2_LOG_INFO_EX("Exception thrown while receiving DHT message.", e); - return handleUnknownMessage(data, sizeof(data), remoteAddr, remotePort); + return handleUnknownMessage(data, length, remoteAddr, remotePort); } } diff --git a/src/DHTMessageReceiver.h b/src/DHTMessageReceiver.h index 4e2e15fe..f8ed5018 100644 --- a/src/DHTMessageReceiver.h +++ b/src/DHTMessageReceiver.h @@ -69,7 +69,9 @@ public: ~DHTMessageReceiver(); - SharedHandle receiveMessage(); + SharedHandle receiveMessage + (const std::string& remoteAddr, uint16_t remotePort, unsigned char *data, + size_t length); void handleTimeout(); diff --git a/src/DHTSetup.cc b/src/DHTSetup.cc index 17bd035d..145e0316 100644 --- a/src/DHTSetup.cc +++ b/src/DHTSetup.cc @@ -61,6 +61,8 @@ #include "DHTRegistry.h" #include "DHTBucketRefreshTask.h" #include "DHTMessageCallback.h" +#include "UDPTrackerClient.h" +#include "BtRegistry.h" #include "prefs.h" #include "Option.h" #include "SocketCore.h" @@ -176,6 +178,8 @@ void DHTSetup::setup factory->setTokenTracker(tokenTracker.get()); factory->setLocalNode(localNode); + // For now, UDPTrackerClient was enabled along with DHT + SharedHandle udpTrackerClient(new UDPTrackerClient()); // assign them into DHTRegistry if(family == AF_INET) { DHTRegistry::getMutableData().localNode = localNode; @@ -187,6 +191,8 @@ void DHTSetup::setup DHTRegistry::getMutableData().messageDispatcher = dispatcher; DHTRegistry::getMutableData().messageReceiver = receiver; DHTRegistry::getMutableData().messageFactory = factory; + e->getBtRegistry()->setUDPTrackerClient(udpTrackerClient); + e->getBtRegistry()->setUdpPort(localNode->getPort()); } else { DHTRegistry::getMutableData6().localNode = localNode; DHTRegistry::getMutableData6().routingTable = routingTable; @@ -244,6 +250,8 @@ void DHTSetup::setup command->setMessageReceiver(receiver); command->setTaskQueue(taskQueue); command->setReadCheckSocket(connection->getSocket()); + command->setConnection(connection); + command->setUDPTrackerClient(udpTrackerClient); tempCommands->push_back(command); } { @@ -282,12 +290,15 @@ void DHTSetup::setup } commands.insert(commands.end(), tempCommands->begin(), tempCommands->end()); tempCommands->clear(); - } catch(RecoverableException& e) { + } catch(RecoverableException& ex) { A2_LOG_ERROR_EX(fmt("Exception caught while initializing DHT functionality." " DHT is disabled."), - e); + ex); if(family == AF_INET) { DHTRegistry::clearData(); + e->getBtRegistry()->setUDPTrackerClient + (SharedHandle()); + e->getBtRegistry()->setUdpPort(0); } else { DHTRegistry::clearData6(); } diff --git a/src/DefaultBtAnnounce.cc b/src/DefaultBtAnnounce.cc index 6a9cc713..a41f002d 100644 --- a/src/DefaultBtAnnounce.cc +++ b/src/DefaultBtAnnounce.cc @@ -52,6 +52,8 @@ #include "bittorrent_helper.h" #include "wallclock.h" #include "uri.h" +#include "UDPTrackerRequest.h" +#include "SocketCore.h" namespace aria2 { @@ -115,7 +117,7 @@ bool uriHasQuery(const std::string& uri) } } // namespace -std::string DefaultBtAnnounce::getAnnounceUrl() { +bool DefaultBtAnnounce::adjustAnnounceList() { if(isStoppedAnnounceReady()) { if(!announceList_.currentTierAcceptsStoppedEvent()) { announceList_.moveToStoppedAllowedTier(); @@ -135,6 +137,13 @@ std::string DefaultBtAnnounce::getAnnounceUrl() { announceList_.setEvent(AnnounceTier::STARTED_AFTER_COMPLETION); } } else { + return false; + } + return true; +} + +std::string DefaultBtAnnounce::getAnnounceUrl() { + if(!adjustAnnounceList()) { return A2STR::NIL; } int numWant = 50; @@ -193,6 +202,60 @@ std::string DefaultBtAnnounce::getAnnounceUrl() { return uri; } +SharedHandle DefaultBtAnnounce::createUDPTrackerRequest +(const std::string& remoteAddr, uint16_t remotePort, uint16_t localPort) +{ + if(!adjustAnnounceList()) { + return SharedHandle(); + } + NetStat& stat = downloadContext_->getNetStat(); + int64_t left = + pieceStorage_->getTotalLength()-pieceStorage_->getCompletedLength(); + SharedHandle req(new UDPTrackerRequest()); + req->remoteAddr = remoteAddr; + req->remotePort = remotePort; + req->action = UDPT_ACT_ANNOUNCE; + req->infohash = bittorrent::getTorrentAttrs(downloadContext_)->infoHash; + const unsigned char* peerId = bittorrent::getStaticPeerId(); + req->peerId.assign(peerId, peerId + PEER_ID_LENGTH); + req->downloaded = stat.getSessionDownloadLength(); + req->left = left; + req->uploaded = stat.getSessionUploadLength(); + switch(announceList_.getEvent()) { + case AnnounceTier::STARTED: + case AnnounceTier::STARTED_AFTER_COMPLETION: + req->event = UDPT_EVT_STARTED; + break; + case AnnounceTier::STOPPED: + req->event = UDPT_EVT_STOPPED; + break; + case AnnounceTier::COMPLETED: + req->event = UDPT_EVT_COMPLETED; + break; + default: + req->event = 0; + } + if(!option_->blank(PREF_BT_EXTERNAL_IP)) { + unsigned char dest[16]; + if(net::getBinAddr(dest, option_->get(PREF_BT_EXTERNAL_IP)) == 4) { + memcpy(&req->ip, dest, 4); + } else { + req->ip = 0; + } + } else { + req->ip = 0; + } + req->key = randomizer_->getRandomNumber(INT32_MAX); + int numWant = 50; + if(!btRuntime_->lessThanMinPeers() || btRuntime_->isHalt()) { + numWant = 0; + } + req->numWant = numWant; + req->port = localPort; + req->extensions = 0; + return req; +} + void DefaultBtAnnounce::announceStart() { ++trackers_; } @@ -287,6 +350,30 @@ DefaultBtAnnounce::processAnnounceResponse(const unsigned char* trackerResponse, } } +void DefaultBtAnnounce::processUDPTrackerResponse +(const SharedHandle& req) +{ + const SharedHandle& reply = req->reply; + A2_LOG_DEBUG("Now processing UDP tracker response."); + if(reply->interval > 0) { + minInterval_ = reply->interval; + A2_LOG_DEBUG(fmt("Min interval:%ld", static_cast(minInterval_))); + interval_ = minInterval_; + } + complete_ = reply->seeders; + A2_LOG_DEBUG(fmt("Complete:%d", reply->seeders)); + incomplete_ = reply->leechers; + A2_LOG_DEBUG(fmt("Incomplete:%d", reply->leechers)); + if(!btRuntime_->isHalt() && btRuntime_->lessThanMinPeers()) { + for(std::vector >::iterator i = + reply->peers.begin(), eoi = reply->peers.end(); i != eoi; + ++i) { + peerStorage_->addPeer(SharedHandle(new Peer((*i).first, + (*i).second))); + } + } +} + bool DefaultBtAnnounce::noMoreAnnounce() { return (trackers_ == 0 && btRuntime_->isHalt() && diff --git a/src/DefaultBtAnnounce.h b/src/DefaultBtAnnounce.h index a760a140..5428a42c 100644 --- a/src/DefaultBtAnnounce.h +++ b/src/DefaultBtAnnounce.h @@ -66,6 +66,8 @@ private: SharedHandle pieceStorage_; SharedHandle peerStorage_; uint16_t tcpPort_; + + bool adjustAnnounceList(); public: DefaultBtAnnounce(const SharedHandle& downloadContext, const Option* option); @@ -103,6 +105,10 @@ public: virtual std::string getAnnounceUrl(); + virtual SharedHandle + createUDPTrackerRequest(const std::string& remoteAddr, uint16_t remotePort, + uint16_t localPort); + virtual void announceStart(); virtual void announceSuccess(); @@ -116,6 +122,9 @@ public: virtual void processAnnounceResponse(const unsigned char* trackerResponse, size_t trackerResponseLength); + virtual void processUDPTrackerResponse + (const SharedHandle& req); + virtual bool noMoreAnnounce(); virtual void shuffleAnnounce(); diff --git a/src/DownloadEngine.cc b/src/DownloadEngine.cc index 9b42bf70..05478fe8 100644 --- a/src/DownloadEngine.cc +++ b/src/DownloadEngine.cc @@ -85,7 +85,7 @@ volatile sig_atomic_t globalHaltRequested = 0; DownloadEngine::DownloadEngine(const SharedHandle& eventPoll) : eventPoll_(eventPoll), - haltRequested_(false), + haltRequested_(0), noWait_(false), refreshInterval_(DEFAULT_REFRESH_INTERVAL), cookieStorage_(new CookieStorage()), @@ -239,13 +239,13 @@ void DownloadEngine::afterEachIteration() void DownloadEngine::requestHalt() { - haltRequested_ = true; + haltRequested_ = std::max(haltRequested_, 1); requestGroupMan_->halt(); } void DownloadEngine::requestForceHalt() { - haltRequested_ = true; + haltRequested_ = std::max(haltRequested_, 2); requestGroupMan_->forceHalt(); } diff --git a/src/DownloadEngine.h b/src/DownloadEngine.h index 05a8b03a..d949bff6 100644 --- a/src/DownloadEngine.h +++ b/src/DownloadEngine.h @@ -79,7 +79,7 @@ private: SharedHandle statCalc_; - bool haltRequested_; + int haltRequested_; class SocketPoolEntry { private: @@ -230,6 +230,11 @@ public: return haltRequested_; } + bool isForceHaltRequested() const + { + return haltRequested_ >= 2; + } + void requestHalt(); void requestForceHalt(); diff --git a/src/Makefile.am b/src/Makefile.am index bf5805b7..5388ba35 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -522,7 +522,10 @@ SRCS += PeerAbstractCommand.cc PeerAbstractCommand.h\ ValueBaseBencodeParser.h\ BencodeDiskWriter.h\ BencodeDiskWriterFactory.h\ - MemoryBencodePreDownloadHandler.h + MemoryBencodePreDownloadHandler.h\ + UDPTrackerClient.cc UDPTrackerClient.h\ + UDPTrackerRequest.cc UDPTrackerRequest.h\ + NameResolveCommand.cc NameResolveCommand.h endif # ENABLE_BITTORRENT if ENABLE_METALINK diff --git a/src/NameResolveCommand.cc b/src/NameResolveCommand.cc new file mode 100644 index 00000000..cf09ebb6 --- /dev/null +++ b/src/NameResolveCommand.cc @@ -0,0 +1,193 @@ +/* */ +#include "NameResolveCommand.h" +#include "DownloadEngine.h" +#include "NameResolver.h" +#include "prefs.h" +#include "message.h" +#include "util.h" +#include "Option.h" +#include "RequestGroupMan.h" +#include "Logger.h" +#include "LogFactory.h" +#include "fmt.h" +#include "UDPTrackerRequest.h" +#include "UDPTrackerClient.h" +#include "BtRegistry.h" +#ifdef ENABLE_ASYNC_DNS +#include "AsyncNameResolver.h" +#endif // ENABLE_ASYNC_DNS + +namespace aria2 { + +NameResolveCommand::NameResolveCommand +(cuid_t cuid, DownloadEngine* e, + const SharedHandle& req) + : Command(cuid), + e_(e), + req_(req) +{ + setStatus(Command::STATUS_ONESHOT_REALTIME); +} + +NameResolveCommand::~NameResolveCommand() +{ +#ifdef ENABLE_ASYNC_DNS + disableNameResolverCheck(resolver_); +#endif // ENABLE_ASYNC_DNS +} + +bool NameResolveCommand::execute() +{ + // This is UDP tracker specific, but we need to keep this command + // alive until force shutdown is + // commencing. RequestGroupMan::downloadFinished() is useless here + // at the moment. + if(e_->isForceHaltRequested()) { + onShutdown(); + return true; + } +#ifdef ENABLE_ASYNC_DNS + if(!resolver_) { + int family = AF_INET; + resolver_.reset(new AsyncNameResolver(family +#ifdef HAVE_ARES_ADDR_NODE + , e_->getAsyncDNSServers() +#endif // HAVE_ARES_ADDR_NODE + )); + } +#endif // ENABLE_ASYNC_DNS + std::string hostname = req_->remoteAddr; + std::vector res; + if(util::isNumericHost(hostname)) { + res.push_back(hostname); + } else { +#ifdef ENABLE_ASYNC_DNS + if(e_->getOption()->getAsBool(PREF_ASYNC_DNS)) { + try { + if(resolveHostname(hostname, resolver_)) { + res = resolver_->getResolvedAddresses(); + } else { + e_->addCommand(this); + return false; + } + } catch(RecoverableException& e) { + A2_LOG_ERROR_EX(EX_EXCEPTION_CAUGHT, e); + } + } else +#endif // ENABLE_ASYNC_DNS + { + NameResolver resolver; + resolver.setSocktype(SOCK_DGRAM); + try { + resolver.resolve(res, hostname); + } catch(RecoverableException& e) { + A2_LOG_ERROR_EX(EX_EXCEPTION_CAUGHT, e); + } + } + } + if(res.empty()) { + onFailure(); + } else { + onSuccess(res, e_); + } + return true; +} + +void NameResolveCommand::onShutdown() +{ + req_->state = UDPT_STA_COMPLETE; + req_->error = UDPT_ERR_SHUTDOWN; +} + +void NameResolveCommand::onFailure() +{ + req_->state = UDPT_STA_COMPLETE; + req_->error = UDPT_ERR_NETWORK; +} + +void NameResolveCommand::onSuccess +(const std::vector& addrs, DownloadEngine* e) +{ + req_->remoteAddr = addrs[0]; + e->getBtRegistry()->getUDPTrackerClient()->addRequest(req_); +} + +#ifdef ENABLE_ASYNC_DNS + +bool NameResolveCommand::resolveHostname +(const std::string& hostname, + const SharedHandle& resolver) +{ + switch(resolver->getStatus()) { + case AsyncNameResolver::STATUS_READY: + A2_LOG_INFO(fmt(MSG_RESOLVING_HOSTNAME, + getCuid(), + hostname.c_str())); + resolver->resolve(hostname); + setNameResolverCheck(resolver); + return false; + case AsyncNameResolver::STATUS_SUCCESS: + A2_LOG_INFO(fmt(MSG_NAME_RESOLUTION_COMPLETE, + getCuid(), + resolver->getHostname().c_str(), + resolver->getResolvedAddresses().front().c_str())); + return true; + break; + case AsyncNameResolver::STATUS_ERROR: + throw DL_ABORT_EX + (fmt(MSG_NAME_RESOLUTION_FAILED, + getCuid(), + hostname.c_str(), + resolver->getError().c_str())); + default: + return false; + } +} + +void NameResolveCommand::setNameResolverCheck +(const SharedHandle& resolver) +{ + e_->addNameResolverCheck(resolver, this); +} + +void NameResolveCommand::disableNameResolverCheck +(const SharedHandle& resolver) +{ + e_->deleteNameResolverCheck(resolver, this); +} +#endif // ENABLE_ASYNC_DNS + +} // namespace aria2 diff --git a/src/NameResolveCommand.h b/src/NameResolveCommand.h new file mode 100644 index 00000000..0d3aa8a8 --- /dev/null +++ b/src/NameResolveCommand.h @@ -0,0 +1,87 @@ +/* */ +#ifndef D_NAME_RESOLVE_COMMAND_H +#define D_NAME_RESOLVE_COMMAND_H + +#include "Command.h" + +#include +#include + +#include "SharedHandle.h" + +// TODO Make this class generic. +namespace aria2 { + +class DownloadEngine; +#ifdef ENABLE_ASYNC_DNS +class AsyncNameResolver; +#endif // ENABLE_ASYNC_DNS +class UDPTrackerRequest; + +class NameResolveCommand:public Command { +private: + DownloadEngine* e_; + +#ifdef ENABLE_ASYNC_DNS + SharedHandle resolver_; +#endif // ENABLE_ASYNC_DNS + +#ifdef ENABLE_ASYNC_DNS + bool resolveHostname(const std::string& hostname, + const SharedHandle& resolver); + + void setNameResolverCheck(const SharedHandle& resolver); + + void disableNameResolverCheck(const SharedHandle& resolver); +#endif // ENABLE_ASYNC_DNS + + SharedHandle req_; + void onShutdown(); + void onFailure(); + void onSuccess + (const std::vector& addrs, DownloadEngine* e); +public: + NameResolveCommand(cuid_t cuid, DownloadEngine* e, + const SharedHandle& req); + + virtual ~NameResolveCommand(); + + virtual bool execute(); +}; + +} // namespace aria2 + +#endif // D_NAME_RESOVE_COMMAND_H diff --git a/src/TrackerWatcherCommand.cc b/src/TrackerWatcherCommand.cc index 0eba057f..06d7fc3e 100644 --- a/src/TrackerWatcherCommand.cc +++ b/src/TrackerWatcherCommand.cc @@ -63,32 +63,156 @@ #include "a2functional.h" #include "util.h" #include "fmt.h" +#include "UDPTrackerRequest.h" +#include "UDPTrackerClient.h" +#include "BtRegistry.h" +#include "NameResolveCommand.h" namespace aria2 { +HTTPAnnRequest::HTTPAnnRequest(const SharedHandle& rg) + : rg_(rg) +{} + +HTTPAnnRequest::~HTTPAnnRequest() +{} + +bool HTTPAnnRequest::stopped() const +{ + return rg_->getNumCommand() == 0; +} + +bool HTTPAnnRequest::success() const +{ + return rg_->downloadFinished(); +} + +void HTTPAnnRequest::stop(DownloadEngine* e) +{ + rg_->setForceHaltRequested(true); +} + +bool HTTPAnnRequest::issue(DownloadEngine* e) +{ + try { + std::vector* commands = new std::vector(); + auto_delete_container > commandsDel(commands); + rg_->createInitialCommand(*commands, e); + e->addCommand(*commands); + e->setNoWait(true); + commands->clear(); + A2_LOG_DEBUG("added tracker request command"); + return true; + } catch(RecoverableException& ex) { + A2_LOG_ERROR_EX(EX_EXCEPTION_CAUGHT, ex); + return false; + } +} + +bool HTTPAnnRequest::processResponse +(const SharedHandle& btAnnounce) +{ + try { + std::stringstream strm; + unsigned char data[2048]; + rg_->getPieceStorage()->getDiskAdaptor()->openFile(); + while(1) { + ssize_t dataLength = rg_->getPieceStorage()-> + getDiskAdaptor()->readData(data, sizeof(data), strm.tellp()); + if(dataLength == 0) { + break; + } + strm.write(reinterpret_cast(data), dataLength); + } + std::string res = strm.str(); + btAnnounce->processAnnounceResponse + (reinterpret_cast(res.c_str()), res.size()); + return true; + } catch(RecoverableException& e) { + A2_LOG_ERROR_EX(EX_EXCEPTION_CAUGHT, e); + return false; + } +} + +UDPAnnRequest::UDPAnnRequest(const SharedHandle& req) + : req_(req) +{} + +UDPAnnRequest::~UDPAnnRequest() +{} + +bool UDPAnnRequest::stopped() const +{ + return !req_ || req_->state == UDPT_STA_COMPLETE; +} + +bool UDPAnnRequest::success() const +{ + return req_ && req_->state == UDPT_STA_COMPLETE && + req_->error == UDPT_ERR_SUCCESS; +} + +void UDPAnnRequest::stop(DownloadEngine* e) +{ + if(req_) { + req_.reset(); + } +} + +bool UDPAnnRequest::issue(DownloadEngine* e) +{ + if(req_) { + NameResolveCommand* command = new NameResolveCommand + (e->newCUID(), e, req_); + e->addCommand(command); + e->setNoWait(true); + return true; + } else { + return false; + } +} + +bool UDPAnnRequest::processResponse +(const SharedHandle& btAnnounce) +{ + if(req_) { + btAnnounce->processUDPTrackerResponse(req_); + return true; + } else { + return false; + } +} + TrackerWatcherCommand::TrackerWatcherCommand (cuid_t cuid, RequestGroup* requestGroup, DownloadEngine* e) : Command(cuid), requestGroup_(requestGroup), - e_(e) + e_(e), + udpTrackerClient_(e_->getBtRegistry()->getUDPTrackerClient()) { requestGroup_->increaseNumCommand(); + if(udpTrackerClient_) { + udpTrackerClient_->increaseWatchers(); + } } TrackerWatcherCommand::~TrackerWatcherCommand() { requestGroup_->decreaseNumCommand(); + if(udpTrackerClient_) { + udpTrackerClient_->decreaseWatchers(); + } } bool TrackerWatcherCommand::execute() { if(requestGroup_->isForceHaltRequested()) { - if(!trackerRequestGroup_) { + if(!trackerRequest_) { return true; - } else if(trackerRequestGroup_->getNumCommand() == 0 || - trackerRequestGroup_->downloadFinished()) { + } else if(trackerRequest_->stopped() || + trackerRequest_->success()) { return true; } else { - trackerRequestGroup_->setForceHaltRequested(true); + trackerRequest_->stop(e_); e_->setRefreshInterval(0); e_->addCommand(this); return false; @@ -98,44 +222,32 @@ bool TrackerWatcherCommand::execute() { A2_LOG_DEBUG("no more announce"); return true; } - if(!trackerRequestGroup_) { - trackerRequestGroup_ = createAnnounce(); - if(trackerRequestGroup_) { - try { - std::vector* commands = new std::vector(); - auto_delete_container > commandsDel(commands); - trackerRequestGroup_->createInitialCommand(*commands, e_); - e_->addCommand(*commands); - commands->clear(); - A2_LOG_DEBUG("added tracker request command"); - } catch(RecoverableException& ex) { - A2_LOG_ERROR_EX(EX_EXCEPTION_CAUGHT, ex); - } + if(!trackerRequest_) { + trackerRequest_ = createAnnounce(e_); + if(trackerRequest_) { + trackerRequest_->issue(e_); } - } else if(trackerRequestGroup_->getNumCommand() == 0) { + } else if(trackerRequest_->stopped()) { // We really want to make sure that tracker request has finished // by checking getNumCommand() == 0. Because we reset // trackerRequestGroup_, if it is still used in other Command, we // will get Segmentation fault. - if(trackerRequestGroup_->downloadFinished()) { - try { - std::string trackerResponse = getTrackerResponse(trackerRequestGroup_); - - processTrackerResponse(trackerResponse); + if(trackerRequest_->success()) { + if(trackerRequest_->processResponse(btAnnounce_)) { btAnnounce_->announceSuccess(); btAnnounce_->resetAnnounce(); - } catch(RecoverableException& ex) { - A2_LOG_ERROR_EX(EX_EXCEPTION_CAUGHT, ex); + addConnection(); + } else { btAnnounce_->announceFailure(); if(btAnnounce_->isAllAnnounceFailed()) { btAnnounce_->resetAnnounce(); } } - trackerRequestGroup_.reset(); + trackerRequest_.reset(); } else { // handle errors here btAnnounce_->announceFailure(); // inside it, trackers = 0. - trackerRequestGroup_.reset(); + trackerRequest_.reset(); if(btAnnounce_->isAllAnnounceFailed()) { btAnnounce_->resetAnnounce(); } @@ -145,30 +257,8 @@ bool TrackerWatcherCommand::execute() { return false; } -std::string TrackerWatcherCommand::getTrackerResponse -(const SharedHandle& requestGroup) +void TrackerWatcherCommand::addConnection() { - std::stringstream strm; - unsigned char data[2048]; - requestGroup->getPieceStorage()->getDiskAdaptor()->openFile(); - while(1) { - ssize_t dataLength = requestGroup->getPieceStorage()-> - getDiskAdaptor()->readData(data, sizeof(data), strm.tellp()); - if(dataLength == 0) { - break; - } - strm.write(reinterpret_cast(data), dataLength); - } - return strm.str(); -} - -// TODO we have to deal with the exception thrown By BtAnnounce -void TrackerWatcherCommand::processTrackerResponse -(const std::string& trackerResponse) -{ - btAnnounce_->processAnnounceResponse - (reinterpret_cast(trackerResponse.c_str()), - trackerResponse.size()); while(!btRuntime_->isHalt() && btRuntime_->lessThanMinPeers()) { if(!peerStorage_->isPeerAvailable()) { break; @@ -180,8 +270,8 @@ void TrackerWatcherCommand::processTrackerResponse break; } PeerInitiateConnectionCommand* command; - command = new PeerInitiateConnectionCommand(ncuid, requestGroup_, peer, e_, - btRuntime_); + command = new PeerInitiateConnectionCommand(ncuid, requestGroup_, peer, + e_, btRuntime_); command->setPeerStorage(peerStorage_); command->setPieceStorage(pieceStorage_); e_->addCommand(command); @@ -190,13 +280,48 @@ void TrackerWatcherCommand::processTrackerResponse } } -SharedHandle TrackerWatcherCommand::createAnnounce() { - SharedHandle rg; - if(btAnnounce_->isAnnounceReady()) { - rg = createRequestGroup(btAnnounce_->getAnnounceUrl()); - btAnnounce_->announceStart(); // inside it, trackers++. +SharedHandle +TrackerWatcherCommand::createAnnounce(DownloadEngine* e) +{ + SharedHandle treq; + while(!btAnnounce_->isAllAnnounceFailed() && + btAnnounce_->isAnnounceReady()) { + std::string uri = btAnnounce_->getAnnounceUrl(); + uri_split_result res; + memset(&res, 0, sizeof(res)); + if(uri_split(&res, uri.c_str()) == 0) { + // Without UDP tracker support, send it to normal tracker flow + // and make it fail. + if(udpTrackerClient_ && + uri::getFieldString(res, USR_SCHEME, uri.c_str()) == "udp") { + uint16_t localPort; + localPort = e->getBtRegistry()->getUdpPort(); + treq = createUDPAnnRequest + (uri::getFieldString(res, USR_HOST, uri.c_str()), res.port, + localPort); + } else { + treq = createHTTPAnnRequest(btAnnounce_->getAnnounceUrl()); + } + btAnnounce_->announceStart(); // inside it, trackers++. + break; + } else { + btAnnounce_->announceFailure(); + } } - return rg; + if(btAnnounce_->isAllAnnounceFailed()) { + btAnnounce_->resetAnnounce(); + } + return treq; +} + +SharedHandle +TrackerWatcherCommand::createUDPAnnRequest(const std::string& host, + uint16_t port, + uint16_t localPort) +{ + SharedHandle req = + btAnnounce_->createUDPTrackerRequest(host, port, localPort); + return SharedHandle(new UDPAnnRequest(req)); } namespace { @@ -219,8 +344,8 @@ bool backupTrackerIsAvailable } } // namespace -SharedHandle -TrackerWatcherCommand::createRequestGroup(const std::string& uri) +SharedHandle +TrackerWatcherCommand::createHTTPAnnRequest(const std::string& uri) { std::vector uris; uris.push_back(uri); @@ -261,7 +386,7 @@ TrackerWatcherCommand::createRequestGroup(const std::string& uri) dctx->setAcceptMetalink(false); A2_LOG_INFO(fmt("Creating tracker request group GID#%s", GroupId::toHex(rg->getGID()).c_str())); - return rg; + return SharedHandle(new HTTPAnnRequest(rg)); } void TrackerWatcherCommand::setBtRuntime diff --git a/src/TrackerWatcherCommand.h b/src/TrackerWatcherCommand.h index 73887f8a..99ff91b4 100644 --- a/src/TrackerWatcherCommand.h +++ b/src/TrackerWatcherCommand.h @@ -50,6 +50,50 @@ class PieceStorage; class BtRuntime; class BtAnnounce; class Option; +class UDPTrackerRequest; +class UDPTrackerClient; + +class AnnRequest { +public: + virtual ~AnnRequest() {} + // Returns true if tracker request is finished, regardless of the + // outcome. + virtual bool stopped() const = 0; + // Returns true if tracker request is successful. + virtual bool success() const = 0; + // Returns true if issuing request is successful. + virtual bool issue(DownloadEngine* e) = 0; + // Stop this request. + virtual void stop(DownloadEngine* e) = 0; + // Returns true if processing tracker response is successful. + virtual bool processResponse(const SharedHandle& btAnnounce) = 0; +}; + +class HTTPAnnRequest:public AnnRequest { +public: + HTTPAnnRequest(const SharedHandle& rg); + virtual ~HTTPAnnRequest(); + virtual bool stopped() const; + virtual bool success() const; + virtual bool issue(DownloadEngine* e); + virtual void stop(DownloadEngine* e); + virtual bool processResponse(const SharedHandle& btAnnounce); +private: + SharedHandle rg_; +}; + +class UDPAnnRequest:public AnnRequest { +public: + UDPAnnRequest(const SharedHandle& req); + virtual ~UDPAnnRequest(); + virtual bool stopped() const; + virtual bool success() const; + virtual bool issue(DownloadEngine* e); + virtual void stop(DownloadEngine* e); + virtual bool processResponse(const SharedHandle& btAnnounce); +private: + SharedHandle req_; +}; class TrackerWatcherCommand : public Command { @@ -58,6 +102,8 @@ private: DownloadEngine* e_; + SharedHandle udpTrackerClient_; + SharedHandle peerStorage_; SharedHandle pieceStorage_; @@ -66,16 +112,20 @@ private: SharedHandle btAnnounce_; - SharedHandle trackerRequestGroup_; + SharedHandle trackerRequest_; + /** * Returns a command for announce request. Returns 0 if no announce request * is needed. */ - SharedHandle createRequestGroup(const std::string& url); + SharedHandle + createHTTPAnnRequest(const std::string& uri); - std::string getTrackerResponse(const SharedHandle& requestGroup); + SharedHandle + createUDPAnnRequest(const std::string& host, uint16_t port, + uint16_t localPort); - void processTrackerResponse(const std::string& response); + void addConnection(); const SharedHandle