diff --git a/src/DHTGetPeersMessage.cc b/src/DHTGetPeersMessage.cc index a2857e1d..ff5c8ba2 100644 --- a/src/DHTGetPeersMessage.cc +++ b/src/DHTGetPeersMessage.cc @@ -46,6 +46,10 @@ #include "DHTTokenTracker.h" #include "DHTGetPeersReplyMessage.h" #include "util.h" +#include "BtRegistry.h" +#include "DownloadContext.h" +#include "Option.h" +#include "SocketCore.h" namespace aria2 { @@ -59,18 +63,61 @@ DHTGetPeersMessage::DHTGetPeersMessage( const std::string& transactionID) : DHTQueryMessage{localNode, remoteNode, transactionID}, peerAnnounceStorage_{nullptr}, - tokenTracker_{nullptr} + tokenTracker_{nullptr}, + btRegistry_{nullptr}, + family_{AF_INET} { memcpy(infoHash_, infoHash, DHT_ID_LENGTH); } +void DHTGetPeersMessage::addLocalPeer(std::vector>& peers) +{ + if (!btRegistry_) { + return; + } + + auto& dctx = btRegistry_->getDownloadContext( + std::string(infoHash_, infoHash_ + DHT_ID_LENGTH)); + + if (!dctx) { + return; + } + + auto group = dctx->getOwnerRequestGroup(); + auto& option = group->getOption(); + auto& externalIP = option->get(PREF_BT_EXTERNAL_IP); + + if (externalIP.empty()) { + return; + } + + std::array dst; + if (inetPton(family_, externalIP.c_str(), dst.data()) == -1) { + return; + } + + auto tcpPort = btRegistry_->getTcpPort(); + if (std::find_if(std::begin(peers), std::end(peers), + [&externalIP, tcpPort](const std::shared_ptr& peer) { + return peer->getIPAddress() == externalIP && + peer->getPort() == tcpPort; + }) != std::end(peers)) { + return; + } + + peers.push_back(std::make_shared(externalIP, tcpPort)); +} + void DHTGetPeersMessage::doReceivedAction() { std::string token = tokenTracker_->generateToken( infoHash_, getRemoteNode()->getIPAddress(), getRemoteNode()->getPort()); - // Check to see localhost has the contents which has same infohash std::vector> peers; peerAnnounceStorage_->getPeers(peers, infoHash_); + + // Check to see localhost has the contents which has same infohash + addLocalPeer(peers); + std::vector> nodes; getRoutingTable()->getClosestKNodes(nodes, infoHash_); getMessageDispatcher()->addMessageToQueue( @@ -102,6 +149,13 @@ void DHTGetPeersMessage::setTokenTracker(DHTTokenTracker* tokenTracker) tokenTracker_ = tokenTracker; } +void DHTGetPeersMessage::setBtRegistry(BtRegistry* btRegistry) +{ + btRegistry_ = btRegistry; +} + +void DHTGetPeersMessage::setFamily(int family) { family_ = family; } + std::string DHTGetPeersMessage::toStringOptional() const { return "info_hash=" + util::toHex(infoHash_, INFO_HASH_LENGTH); diff --git a/src/DHTGetPeersMessage.h b/src/DHTGetPeersMessage.h index 529e8098..d4da04e8 100644 --- a/src/DHTGetPeersMessage.h +++ b/src/DHTGetPeersMessage.h @@ -36,6 +36,9 @@ #define D_DHT_GET_PEERS_MESSAGE_H #include "DHTQueryMessage.h" + +#include + #include "DHTConstants.h" #include "A2STR.h" @@ -43,6 +46,8 @@ namespace aria2 { class DHTPeerAnnounceStorage; class DHTTokenTracker; +class BtRegistry; +class Peer; class DHTGetPeersMessage : public DHTQueryMessage { private: @@ -52,6 +57,12 @@ private: DHTTokenTracker* tokenTracker_; + BtRegistry* btRegistry_; + + int family_; + + void addLocalPeer(std::vector>& peers); + protected: virtual std::string toStringOptional() const CXX11_OVERRIDE; @@ -73,6 +84,10 @@ public: void setTokenTracker(DHTTokenTracker* tokenTracker); + void setBtRegistry(BtRegistry* btRegistry); + + void setFamily(int family); + static const std::string GET_PEERS; static const std::string INFO_HASH; diff --git a/src/DHTMessageFactoryImpl.cc b/src/DHTMessageFactoryImpl.cc index 9e50fdd6..33994133 100644 --- a/src/DHTMessageFactoryImpl.cc +++ b/src/DHTMessageFactoryImpl.cc @@ -70,7 +70,8 @@ DHTMessageFactoryImpl::DHTMessageFactoryImpl(int family) dispatcher_{nullptr}, routingTable_{nullptr}, peerAnnounceStorage_{nullptr}, - tokenTracker_{nullptr} + tokenTracker_{nullptr}, + btRegistry_{nullptr} { } @@ -409,6 +410,8 @@ DHTMessageFactoryImpl::createGetPeersMessage( transactionID); m->setPeerAnnounceStorage(peerAnnounceStorage_); m->setTokenTracker(tokenTracker_); + m->setBtRegistry(btRegistry_); + m->setFamily(family_); setCommonProperty(m.get()); return m; } @@ -529,4 +532,9 @@ void DHTMessageFactoryImpl::setLocalNode( localNode_ = localNode; } +void DHTMessageFactoryImpl::setBtRegistry(BtRegistry* btRegistry) +{ + btRegistry_ = btRegistry; +} + } // namespace aria2 diff --git a/src/DHTMessageFactoryImpl.h b/src/DHTMessageFactoryImpl.h index 0035d7cc..b6ac69f4 100644 --- a/src/DHTMessageFactoryImpl.h +++ b/src/DHTMessageFactoryImpl.h @@ -47,6 +47,7 @@ class DHTPeerAnnounceStorage; class DHTTokenTracker; class DHTMessage; class DHTAbstractMessage; +class BtRegistry; class DHTMessageFactoryImpl : public DHTMessageFactory { private: @@ -64,6 +65,8 @@ private: DHTTokenTracker* tokenTracker_; + BtRegistry* btRegistry_; + // search node in routingTable. If it is not found, create new one. std::shared_ptr getRemoteNode(const unsigned char* id, const std::string& ipaddr, @@ -154,6 +157,8 @@ public: void setTokenTracker(DHTTokenTracker* tokenTracker); void setLocalNode(const std::shared_ptr& localNode); + + void setBtRegistry(BtRegistry* btRegistry); }; } // namespace aria2 diff --git a/src/DHTSetup.cc b/src/DHTSetup.cc index 45ec95a0..7ac7fc91 100644 --- a/src/DHTSetup.cc +++ b/src/DHTSetup.cc @@ -180,6 +180,7 @@ DHTSetup::setup(DownloadEngine* e, int family) factory->setPeerAnnounceStorage(peerAnnounceStorage.get()); factory->setTokenTracker(tokenTracker.get()); factory->setLocalNode(localNode); + factory->setBtRegistry(e->getBtRegistry().get()); PrefPtr prefEntryPointHost = family == AF_INET ? PREF_DHT_ENTRY_POINT_HOST : PREF_DHT_ENTRY_POINT_HOST6; diff --git a/src/usage_text.h b/src/usage_text.h index e976eea1..65367346 100644 --- a/src/usage_text.h +++ b/src/usage_text.h @@ -538,9 +538,13 @@ #define TEXT_EVENT_POLL \ _(" --event-poll=POLL Specify the method for polling events.") #define TEXT_BT_EXTERNAL_IP \ - _(" --bt-external-ip=IPADDRESS Specify the external IP address to report to a\n" \ - " BitTorrent tracker. Although this function is\n" \ - " named 'external', it can accept any kind of IP\n" \ + _(" --bt-external-ip=IPADDRESS Specify the external IP address to use in\n" \ + " BitTorrent download and DHT. It may be sent to\n" \ + " BitTorrent tracker. For DHT, this option should\n" \ + " be set to report that local node is downloading\n" \ + " a particular torrent. This is critical to use\n" \ + " DHT in a private network. Although this function\n" \ + " is named 'external', it can accept any kind of IP\n" \ " addresses.") #define TEXT_HTTP_AUTH_CHALLENGE \ _(" --http-auth-challenge[=true|false] Send HTTP authorization header only when it\n" \ diff --git a/test/DHTGetPeersMessageTest.cc b/test/DHTGetPeersMessageTest.cc index 07aecc08..cac07ca4 100644 --- a/test/DHTGetPeersMessageTest.cc +++ b/test/DHTGetPeersMessageTest.cc @@ -12,6 +12,12 @@ #include "DHTPeerAnnounceStorage.h" #include "DHTRoutingTable.h" #include "bencode2.h" +#include "GroupId.h" +#include "DownloadContext.h" +#include "Option.h" +#include "RequestGroup.h" +#include "BtRegistry.h" +#include "TorrentAttribute.h" namespace aria2 { @@ -102,11 +108,32 @@ void DHTGetPeersMessageTest::testDoReceivedAction() factory.setLocalNode(localNode_); DHTRoutingTable routingTable(localNode_); + auto torrentAttrs = std::make_shared(); + torrentAttrs->infoHash = std::string(infoHash, infoHash + DHT_ID_LENGTH); + + auto dctx = std::make_shared(); + dctx->setAttribute(CTX_ATTR_BT, torrentAttrs); + + auto option = std::make_shared