Use std::unique_ptr to store DHTMessages instead of std::shared_ptr

This commit is contained in:
Tatsuhiro Tsujikawa 2013-07-02 22:58:20 +09:00
parent 4f7d1c395b
commit 1a5d75e819
53 changed files with 833 additions and 872 deletions

View file

@ -46,17 +46,16 @@
namespace aria2 { namespace aria2 {
DHTAbstractMessage::DHTAbstractMessage(const std::shared_ptr<DHTNode>& localNode, DHTAbstractMessage::DHTAbstractMessage
const std::shared_ptr<DHTNode>& remoteNode, (const std::shared_ptr<DHTNode>& localNode,
const std::string& transactionID): const std::shared_ptr<DHTNode>& remoteNode,
DHTMessage(localNode, remoteNode, transactionID), const std::string& transactionID)
connection_(0), : DHTMessage{localNode, remoteNode, transactionID},
dispatcher_(0), connection_{nullptr},
factory_(0), dispatcher_{nullptr},
routingTable_(0) factory_{nullptr},
{} routingTable_{nullptr}
{}
DHTAbstractMessage::~DHTAbstractMessage() {}
std::string DHTAbstractMessage::getBencodedMessage() std::string DHTAbstractMessage::getBencodedMessage()
{ {

View file

@ -60,8 +60,6 @@ public:
const std::shared_ptr<DHTNode>& remoteNode, const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID = A2STR::NIL); const std::string& transactionID = A2STR::NIL);
virtual ~DHTAbstractMessage();
virtual bool send(); virtual bool send();
virtual const std::string& getType() const = 0; virtual const std::string& getType() const = 0;

View file

@ -67,32 +67,29 @@ class DHTAbstractNodeLookupTask:public DHTAbstractTask {
private: private:
unsigned char targetID_[DHT_ID_LENGTH]; unsigned char targetID_[DHT_ID_LENGTH];
std::deque<std::shared_ptr<DHTNodeLookupEntry> > entries_; std::deque<std::unique_ptr<DHTNodeLookupEntry>> entries_;
size_t inFlightMessage_; size_t inFlightMessage_;
template<typename Container> template<typename Container>
void toEntries void toEntries
(Container& entries, const std::vector<std::shared_ptr<DHTNode> >& nodes) const (Container& entries,
const std::vector<std::shared_ptr<DHTNode>>& nodes) const
{ {
for(std::vector<std::shared_ptr<DHTNode> >::const_iterator i = nodes.begin(), for(auto& node : nodes) {
eoi = nodes.end(); i != eoi; ++i) { entries.push_back(make_unique<DHTNodeLookupEntry>(node));
std::shared_ptr<DHTNodeLookupEntry> e(new DHTNodeLookupEntry(*i));
entries.push_back(e);
} }
} }
void sendMessage() void sendMessage()
{ {
for(std::deque<std::shared_ptr<DHTNodeLookupEntry> >::iterator i = for(auto i = std::begin(entries_), eoi = std::end(entries_);
entries_.begin(), eoi = entries_.end();
i != eoi && inFlightMessage_ < ALPHA; ++i) { i != eoi && inFlightMessage_ < ALPHA; ++i) {
if((*i)->used == false) { if((*i)->used == false) {
++inFlightMessage_; ++inFlightMessage_;
(*i)->used = true; (*i)->used = true;
std::shared_ptr<DHTMessage> m = createMessage((*i)->node); getMessageDispatcher()->addMessageToQueue
std::shared_ptr<DHTMessageCallback> callback(createCallback()); (createMessage((*i)->node), createCallback());
getMessageDispatcher()->addMessageToQueue(m, callback);
} }
} }
} }
@ -122,13 +119,13 @@ protected:
return targetID_; return targetID_;
} }
const std::deque<std::shared_ptr<DHTNodeLookupEntry> >& getEntries() const const std::deque<std::unique_ptr<DHTNodeLookupEntry>>& getEntries() const
{ {
return entries_; return entries_;
} }
virtual void getNodesFromMessage virtual void getNodesFromMessage
(std::vector<std::shared_ptr<DHTNode> >& nodes, (std::vector<std::shared_ptr<DHTNode>>& nodes,
const ResponseMessage* message) = 0; const ResponseMessage* message) = 0;
virtual void onReceivedInternal virtual void onReceivedInternal
@ -138,10 +135,10 @@ protected:
virtual void onFinish() {} virtual void onFinish() {}
virtual std::shared_ptr<DHTMessage> createMessage virtual std::unique_ptr<DHTMessage> createMessage
(const std::shared_ptr<DHTNode>& remoteNode) = 0; (const std::shared_ptr<DHTNode>& remoteNode) = 0;
virtual std::shared_ptr<DHTMessageCallback> createCallback() = 0; virtual std::unique_ptr<DHTMessageCallback> createCallback() = 0;
public: public:
DHTAbstractNodeLookupTask(const unsigned char* targetID): DHTAbstractNodeLookupTask(const unsigned char* targetID):
inFlightMessage_(0) inFlightMessage_(0)
@ -153,7 +150,7 @@ public:
virtual void startup() virtual void startup()
{ {
std::vector<std::shared_ptr<DHTNode> > nodes; std::vector<std::shared_ptr<DHTNode>> nodes;
getRoutingTable()->getClosestKNodes(nodes, targetID_); getRoutingTable()->getClosestKNodes(nodes, targetID_);
entries_.clear(); entries_.clear();
toEntries(entries_, nodes); toEntries(entries_, nodes);
@ -174,43 +171,42 @@ public:
{ {
--inFlightMessage_; --inFlightMessage_;
// Replace old Node ID with new Node ID. // Replace old Node ID with new Node ID.
for(std::deque<std::shared_ptr<DHTNodeLookupEntry> >::iterator i = for(auto& entry : entries_) {
entries_.begin(), eoi = entries_.end(); i != eoi; ++i) { if(entry->node->getIPAddress() == message->getRemoteNode()->getIPAddress()
if((*i)->node->getIPAddress() == message->getRemoteNode()->getIPAddress() && entry->node->getPort() == message->getRemoteNode()->getPort()) {
&& (*i)->node->getPort() == message->getRemoteNode()->getPort()) { entry->node = message->getRemoteNode();
(*i)->node = message->getRemoteNode();
} }
} }
onReceivedInternal(message); onReceivedInternal(message);
std::vector<std::shared_ptr<DHTNode> > nodes; std::vector<std::shared_ptr<DHTNode>> nodes;
getNodesFromMessage(nodes, message); getNodesFromMessage(nodes, message);
std::vector<std::shared_ptr<DHTNodeLookupEntry> > newEntries; std::vector<std::unique_ptr<DHTNodeLookupEntry>> newEntries;
toEntries(newEntries, nodes); toEntries(newEntries, nodes);
size_t count = 0; size_t count = 0;
for(std::vector<std::shared_ptr<DHTNodeLookupEntry> >::const_iterator i = for(auto& ne : newEntries) {
newEntries.begin(), eoi = newEntries.end(); i != eoi; ++i) { if(memcmp(getLocalNode()->getID(), ne->node->getID(),
if(memcmp(getLocalNode()->getID(), (*i)->node->getID(),
DHT_ID_LENGTH) != 0) { DHT_ID_LENGTH) != 0) {
entries_.push_front(*i);
++count;
A2_LOG_DEBUG(fmt("Received nodes: id=%s, ip=%s", A2_LOG_DEBUG(fmt("Received nodes: id=%s, ip=%s",
util::toHex((*i)->node->getID(), util::toHex(ne->node->getID(),
DHT_ID_LENGTH).c_str(), DHT_ID_LENGTH).c_str(),
(*i)->node->getIPAddress().c_str())); ne->node->getIPAddress().c_str()));
entries_.push_front(std::move(ne));
++count;
} }
} }
A2_LOG_DEBUG(fmt("%lu node lookup entries added.", A2_LOG_DEBUG(fmt("%lu node lookup entries added.",
static_cast<unsigned long>(count))); static_cast<unsigned long>(count)));
std::stable_sort(entries_.begin(), entries_.end(), DHTIDCloser(targetID_)); std::stable_sort(std::begin(entries_), std::end(entries_),
DHTIDCloser(targetID_));
entries_.erase entries_.erase
(std::unique(entries_.begin(), entries_.end(), (std::unique(std::begin(entries_), std::end(entries_),
DerefEqualTo<std::shared_ptr<DHTNodeLookupEntry> >()), DerefEqualTo<std::unique_ptr<DHTNodeLookupEntry>>{}),
entries_.end()); std::end(entries_));
A2_LOG_DEBUG(fmt("%lu node lookup entries are unique.", A2_LOG_DEBUG(fmt("%lu node lookup entries are unique.",
static_cast<unsigned long>(entries_.size()))); static_cast<unsigned long>(entries_.size())));
if(entries_.size() > DHTBucket::K) { if(entries_.size() > DHTBucket::K) {
entries_.erase(entries_.begin()+DHTBucket::K, entries_.end()); entries_.erase(std::begin(entries_)+DHTBucket::K, std::end(entries_));
} }
sendMessageAndCheckFinish(); sendMessageAndCheckFinish();
} }
@ -220,8 +216,8 @@ public:
A2_LOG_DEBUG(fmt("node lookup message timeout for node ID=%s", A2_LOG_DEBUG(fmt("node lookup message timeout for node ID=%s",
util::toHex(node->getID(), DHT_ID_LENGTH).c_str())); util::toHex(node->getID(), DHT_ID_LENGTH).c_str()));
--inFlightMessage_; --inFlightMessage_;
for(std::deque<std::shared_ptr<DHTNodeLookupEntry> >::iterator i = for(auto i = std::begin(entries_), eoi = std::end(entries_);
entries_.begin(), eoi = entries_.end(); i != eoi; ++i) { i != eoi; ++i) {
if(*(*i)->node == *node) { if(*(*i)->node == *node) {
entries_.erase(i); entries_.erase(i);
break; break;

View file

@ -44,6 +44,7 @@
#include "util.h" #include "util.h"
#include "DHTPeerAnnounceStorage.h" #include "DHTPeerAnnounceStorage.h"
#include "DHTTokenTracker.h" #include "DHTTokenTracker.h"
#include "DHTAnnouncePeerReplyMessage.h"
#include "DlAbortEx.h" #include "DlAbortEx.h"
#include "BtConstants.h" #include "BtConstants.h"
#include "fmt.h" #include "fmt.h"
@ -65,32 +66,29 @@ DHTAnnouncePeerMessage::DHTAnnouncePeerMessage
const unsigned char* infoHash, const unsigned char* infoHash,
uint16_t tcpPort, uint16_t tcpPort,
const std::string& token, const std::string& token,
const std::string& transactionID): const std::string& transactionID)
DHTQueryMessage(localNode, remoteNode, transactionID), : DHTQueryMessage{localNode, remoteNode, transactionID},
token_(token), token_{token},
tcpPort_(tcpPort), tcpPort_{tcpPort},
peerAnnounceStorage_(0), peerAnnounceStorage_{nullptr},
tokenTracker_(0) tokenTracker_{nullptr}
{ {
memcpy(infoHash_, infoHash, DHT_ID_LENGTH); memcpy(infoHash_, infoHash, DHT_ID_LENGTH);
} }
DHTAnnouncePeerMessage::~DHTAnnouncePeerMessage() {}
void DHTAnnouncePeerMessage::doReceivedAction() void DHTAnnouncePeerMessage::doReceivedAction()
{ {
peerAnnounceStorage_->addPeerAnnounce peerAnnounceStorage_->addPeerAnnounce
(infoHash_, getRemoteNode()->getIPAddress(), tcpPort_); (infoHash_, getRemoteNode()->getIPAddress(), tcpPort_);
std::shared_ptr<DHTMessage> reply = getMessageDispatcher()->addMessageToQueue
getMessageFactory()->createAnnouncePeerReplyMessage (getMessageFactory()->createAnnouncePeerReplyMessage
(getRemoteNode(), getTransactionID()); (getRemoteNode(), getTransactionID()));
getMessageDispatcher()->addMessageToQueue(reply);
} }
std::shared_ptr<Dict> DHTAnnouncePeerMessage::getArgument() std::shared_ptr<Dict> DHTAnnouncePeerMessage::getArgument()
{ {
std::shared_ptr<Dict> aDict = Dict::g(); auto aDict = Dict::g();
aDict->put(DHTMessage::ID, String::g(getLocalNode()->getID(), DHT_ID_LENGTH)); aDict->put(DHTMessage::ID, String::g(getLocalNode()->getID(), DHT_ID_LENGTH));
aDict->put(INFO_HASH, String::g(infoHash_, DHT_ID_LENGTH)); aDict->put(INFO_HASH, String::g(infoHash_, DHT_ID_LENGTH));
aDict->put(PORT, Integer::g(tcpPort_)); aDict->put(PORT, Integer::g(tcpPort_));

View file

@ -65,8 +65,6 @@ public:
const std::string& token, const std::string& token,
const std::string& transactionID = A2STR::NIL); const std::string& transactionID = A2STR::NIL);
virtual ~DHTAnnouncePeerMessage();
virtual void doReceivedAction(); virtual void doReceivedAction();
virtual std::shared_ptr<Dict> getArgument(); virtual std::shared_ptr<Dict> getArgument();

View file

@ -41,6 +41,7 @@
#include "DHTMessageFactory.h" #include "DHTMessageFactory.h"
#include "DHTMessageDispatcher.h" #include "DHTMessageDispatcher.h"
#include "DHTMessageCallback.h" #include "DHTMessageCallback.h"
#include "DHTFindNodeReplyMessage.h"
#include "util.h" #include "util.h"
namespace aria2 { namespace aria2 {
@ -49,30 +50,28 @@ const std::string DHTFindNodeMessage::FIND_NODE("find_node");
const std::string DHTFindNodeMessage::TARGET_NODE("target"); const std::string DHTFindNodeMessage::TARGET_NODE("target");
DHTFindNodeMessage::DHTFindNodeMessage(const std::shared_ptr<DHTNode>& localNode, DHTFindNodeMessage::DHTFindNodeMessage
const std::shared_ptr<DHTNode>& remoteNode, (const std::shared_ptr<DHTNode>& localNode,
const unsigned char* targetNodeID, const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID): const unsigned char* targetNodeID,
DHTQueryMessage(localNode, remoteNode, transactionID) const std::string& transactionID)
: DHTQueryMessage{localNode, remoteNode, transactionID}
{ {
memcpy(targetNodeID_, targetNodeID, DHT_ID_LENGTH); memcpy(targetNodeID_, targetNodeID, DHT_ID_LENGTH);
} }
DHTFindNodeMessage::~DHTFindNodeMessage() {}
void DHTFindNodeMessage::doReceivedAction() void DHTFindNodeMessage::doReceivedAction()
{ {
std::vector<std::shared_ptr<DHTNode> > nodes; std::vector<std::shared_ptr<DHTNode>> nodes;
getRoutingTable()->getClosestKNodes(nodes, targetNodeID_); getRoutingTable()->getClosestKNodes(nodes, targetNodeID_);
std::shared_ptr<DHTMessage> reply = getMessageDispatcher()->addMessageToQueue
getMessageFactory()->createFindNodeReplyMessage (getMessageFactory()->createFindNodeReplyMessage
(getRemoteNode(), nodes, getTransactionID()); (getRemoteNode(), std::move(nodes), getTransactionID()));
getMessageDispatcher()->addMessageToQueue(reply);
} }
std::shared_ptr<Dict> DHTFindNodeMessage::getArgument() std::shared_ptr<Dict> DHTFindNodeMessage::getArgument()
{ {
std::shared_ptr<Dict> aDict = Dict::g(); auto aDict = Dict::g();
aDict->put(DHTMessage::ID, String::g(getLocalNode()->getID(), DHT_ID_LENGTH)); aDict->put(DHTMessage::ID, String::g(getLocalNode()->getID(), DHT_ID_LENGTH));
aDict->put(TARGET_NODE, String::g(targetNodeID_, DHT_ID_LENGTH)); aDict->put(TARGET_NODE, String::g(targetNodeID_, DHT_ID_LENGTH));
return aDict; return aDict;

View file

@ -51,8 +51,6 @@ public:
const unsigned char* targetNodeID, const unsigned char* targetNodeID,
const std::string& transactionID = A2STR::NIL); const std::string& transactionID = A2STR::NIL);
virtual ~DHTFindNodeMessage();
virtual void doReceivedAction(); virtual void doReceivedAction();
virtual std::shared_ptr<Dict> getArgument(); virtual std::shared_ptr<Dict> getArgument();

View file

@ -57,25 +57,23 @@ DHTFindNodeReplyMessage::DHTFindNodeReplyMessage
(int family, (int family,
const std::shared_ptr<DHTNode>& localNode, const std::shared_ptr<DHTNode>& localNode,
const std::shared_ptr<DHTNode>& remoteNode, const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID): const std::string& transactionID)
DHTResponseMessage(localNode, remoteNode, transactionID), : DHTResponseMessage{localNode, remoteNode, transactionID},
family_(family) {} family_{family}
{}
DHTFindNodeReplyMessage::~DHTFindNodeReplyMessage() {}
void DHTFindNodeReplyMessage::doReceivedAction() void DHTFindNodeReplyMessage::doReceivedAction()
{ {
for(std::vector<std::shared_ptr<DHTNode> >::iterator i = closestKNodes_.begin(), for(auto& node : closestKNodes_) {
eoi = closestKNodes_.end(); i != eoi; ++i) { if(memcmp(node->getID(), getLocalNode()->getID(), DHT_ID_LENGTH) != 0) {
if(memcmp((*i)->getID(), getLocalNode()->getID(), DHT_ID_LENGTH) != 0) { getRoutingTable()->addNode(node);
getRoutingTable()->addNode(*i);
} }
} }
} }
std::shared_ptr<Dict> DHTFindNodeReplyMessage::getResponse() std::shared_ptr<Dict> DHTFindNodeReplyMessage::getResponse()
{ {
std::shared_ptr<Dict> aDict = Dict::g(); auto aDict = Dict::g();
aDict->put(DHTMessage::ID, String::g(getLocalNode()->getID(), DHT_ID_LENGTH)); aDict->put(DHTMessage::ID, String::g(getLocalNode()->getID(), DHT_ID_LENGTH));
unsigned char buffer[DHTBucket::K*38]; unsigned char buffer[DHTBucket::K*38];
const int clen = bittorrent::getCompactLength(family_); const int clen = bittorrent::getCompactLength(family_);
@ -83,14 +81,12 @@ std::shared_ptr<Dict> DHTFindNodeReplyMessage::getResponse()
assert(unit <= 38); assert(unit <= 38);
size_t offset = 0; size_t offset = 0;
size_t k = 0; size_t k = 0;
for(std::vector<std::shared_ptr<DHTNode> >::const_iterator i = for(auto i = std::begin(closestKNodes_), eoi = std::end(closestKNodes_);
closestKNodes_.begin(), eoi = closestKNodes_.end();
i != eoi && k < DHTBucket::K; ++i) { i != eoi && k < DHTBucket::K; ++i) {
std::shared_ptr<DHTNode> node = *i; memcpy(buffer+offset, (*i)->getID(), DHT_ID_LENGTH);
memcpy(buffer+offset, node->getID(), DHT_ID_LENGTH);
unsigned char compact[COMPACT_LEN_IPV6]; unsigned char compact[COMPACT_LEN_IPV6];
int compactlen = bittorrent::packcompact int compactlen = bittorrent::packcompact(compact, (*i)->getIPAddress(),
(compact, node->getIPAddress(), node->getPort()); (*i)->getPort());
if(compactlen == clen) { if(compactlen == clen) {
memcpy(buffer+20+offset, compact, compactlen); memcpy(buffer+20+offset, compact, compactlen);
offset += unit; offset += unit;
@ -112,9 +108,9 @@ void DHTFindNodeReplyMessage::accept(DHTMessageCallback* callback)
} }
void DHTFindNodeReplyMessage::setClosestKNodes void DHTFindNodeReplyMessage::setClosestKNodes
(const std::vector<std::shared_ptr<DHTNode> >& closestKNodes) (std::vector<std::shared_ptr<DHTNode>> closestKNodes)
{ {
closestKNodes_ = closestKNodes; closestKNodes_ = std::move(closestKNodes);
} }
std::string DHTFindNodeReplyMessage::toStringOptional() const std::string DHTFindNodeReplyMessage::toStringOptional() const

View file

@ -44,7 +44,7 @@ class DHTFindNodeReplyMessage:public DHTResponseMessage {
private: private:
int family_; int family_;
std::vector<std::shared_ptr<DHTNode> > closestKNodes_; std::vector<std::shared_ptr<DHTNode>> closestKNodes_;
protected: protected:
virtual std::string toStringOptional() const; virtual std::string toStringOptional() const;
public: public:
@ -53,8 +53,6 @@ public:
const std::shared_ptr<DHTNode>& remoteNode, const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID); const std::string& transactionID);
virtual ~DHTFindNodeReplyMessage();
virtual void doReceivedAction(); virtual void doReceivedAction();
virtual std::shared_ptr<Dict> getResponse(); virtual std::shared_ptr<Dict> getResponse();
@ -63,13 +61,12 @@ public:
virtual void accept(DHTMessageCallback* callback); virtual void accept(DHTMessageCallback* callback);
const std::vector<std::shared_ptr<DHTNode> >& getClosestKNodes() const const std::vector<std::shared_ptr<DHTNode>>& getClosestKNodes() const
{ {
return closestKNodes_; return closestKNodes_;
} }
void setClosestKNodes void setClosestKNodes(std::vector<std::shared_ptr<DHTNode>> closestKNodes);
(const std::vector<std::shared_ptr<DHTNode> >& closestKNodes);
static const std::string FIND_NODE; static const std::string FIND_NODE;

View file

@ -44,6 +44,7 @@
#include "DHTPeerAnnounceStorage.h" #include "DHTPeerAnnounceStorage.h"
#include "Peer.h" #include "Peer.h"
#include "DHTTokenTracker.h" #include "DHTTokenTracker.h"
#include "DHTGetPeersReplyMessage.h"
#include "util.h" #include "util.h"
namespace aria2 { namespace aria2 {
@ -52,37 +53,36 @@ const std::string DHTGetPeersMessage::GET_PEERS("get_peers");
const std::string DHTGetPeersMessage::INFO_HASH("info_hash"); const std::string DHTGetPeersMessage::INFO_HASH("info_hash");
DHTGetPeersMessage::DHTGetPeersMessage(const std::shared_ptr<DHTNode>& localNode, DHTGetPeersMessage::DHTGetPeersMessage
const std::shared_ptr<DHTNode>& remoteNode, (const std::shared_ptr<DHTNode>& localNode,
const unsigned char* infoHash, const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID): const unsigned char* infoHash,
DHTQueryMessage(localNode, remoteNode, transactionID), const std::string& transactionID)
peerAnnounceStorage_(0), : DHTQueryMessage{localNode, remoteNode, transactionID},
tokenTracker_(0) peerAnnounceStorage_{nullptr},
tokenTracker_{nullptr}
{ {
memcpy(infoHash_, infoHash, DHT_ID_LENGTH); memcpy(infoHash_, infoHash, DHT_ID_LENGTH);
} }
DHTGetPeersMessage::~DHTGetPeersMessage() {}
void DHTGetPeersMessage::doReceivedAction() void DHTGetPeersMessage::doReceivedAction()
{ {
std::string token = tokenTracker_->generateToken std::string token = tokenTracker_->generateToken
(infoHash_, getRemoteNode()->getIPAddress(), getRemoteNode()->getPort()); (infoHash_, getRemoteNode()->getIPAddress(), getRemoteNode()->getPort());
// Check to see localhost has the contents which has same infohash // Check to see localhost has the contents which has same infohash
std::vector<std::shared_ptr<Peer> > peers; std::vector<std::shared_ptr<Peer>> peers;
peerAnnounceStorage_->getPeers(peers, infoHash_); peerAnnounceStorage_->getPeers(peers, infoHash_);
std::vector<std::shared_ptr<DHTNode> > nodes; std::vector<std::shared_ptr<DHTNode>> nodes;
getRoutingTable()->getClosestKNodes(nodes, infoHash_); getRoutingTable()->getClosestKNodes(nodes, infoHash_);
std::shared_ptr<DHTMessage> reply = getMessageDispatcher()->addMessageToQueue
getMessageFactory()->createGetPeersReplyMessage (getMessageFactory()->createGetPeersReplyMessage
(getRemoteNode(), nodes, peers, token, getTransactionID()); (getRemoteNode(), std::move(nodes), std::move(peers), token,
getMessageDispatcher()->addMessageToQueue(reply); getTransactionID()));
} }
std::shared_ptr<Dict> DHTGetPeersMessage::getArgument() std::shared_ptr<Dict> DHTGetPeersMessage::getArgument()
{ {
std::shared_ptr<Dict> aDict = Dict::g(); auto aDict = Dict::g();
aDict->put(DHTMessage::ID, String::g(getLocalNode()->getID(), DHT_ID_LENGTH)); aDict->put(DHTMessage::ID, String::g(getLocalNode()->getID(), DHT_ID_LENGTH));
aDict->put(INFO_HASH, String::g(infoHash_, DHT_ID_LENGTH)); aDict->put(INFO_HASH, String::g(infoHash_, DHT_ID_LENGTH));
return aDict; return aDict;

View file

@ -59,8 +59,6 @@ public:
const unsigned char* infoHash, const unsigned char* infoHash,
const std::string& transactionID = A2STR::NIL); const std::string& transactionID = A2STR::NIL);
virtual ~DHTGetPeersMessage();
virtual void doReceivedAction(); virtual void doReceivedAction();
virtual std::shared_ptr<Dict> getArgument(); virtual std::shared_ptr<Dict> getArgument();

View file

@ -64,12 +64,11 @@ DHTGetPeersReplyMessage::DHTGetPeersReplyMessage
const std::shared_ptr<DHTNode>& localNode, const std::shared_ptr<DHTNode>& localNode,
const std::shared_ptr<DHTNode>& remoteNode, const std::shared_ptr<DHTNode>& remoteNode,
const std::string& token, const std::string& token,
const std::string& transactionID): const std::string& transactionID)
DHTResponseMessage(localNode, remoteNode, transactionID), : DHTResponseMessage{localNode, remoteNode, transactionID},
family_(family), family_{family},
token_(token) {} token_{token}
{}
DHTGetPeersReplyMessage::~DHTGetPeersReplyMessage() {}
void DHTGetPeersReplyMessage::doReceivedAction() void DHTGetPeersReplyMessage::doReceivedAction()
{ {
@ -78,7 +77,7 @@ void DHTGetPeersReplyMessage::doReceivedAction()
std::shared_ptr<Dict> DHTGetPeersReplyMessage::getResponse() std::shared_ptr<Dict> DHTGetPeersReplyMessage::getResponse()
{ {
std::shared_ptr<Dict> rDict = Dict::g(); auto rDict = Dict::g();
rDict->put(DHTMessage::ID, String::g(getLocalNode()->getID(), DHT_ID_LENGTH)); rDict->put(DHTMessage::ID, String::g(getLocalNode()->getID(), DHT_ID_LENGTH));
rDict->put(TOKEN, token_); rDict->put(TOKEN, token_);
// TODO want parameter // TODO want parameter
@ -88,14 +87,12 @@ std::shared_ptr<Dict> DHTGetPeersReplyMessage::getResponse()
const int unit = clen+20; const int unit = clen+20;
size_t offset = 0; size_t offset = 0;
size_t k = 0; size_t k = 0;
for(std::vector<std::shared_ptr<DHTNode> >::const_iterator i = for(auto i = std::begin(closestKNodes_), eoi = std::end(closestKNodes_);
closestKNodes_.begin(), eoi = closestKNodes_.end();
i != eoi && k < DHTBucket::K; ++i) { i != eoi && k < DHTBucket::K; ++i) {
std::shared_ptr<DHTNode> node = *i; memcpy(buffer+offset, (*i)->getID(), DHT_ID_LENGTH);
memcpy(buffer+offset, node->getID(), DHT_ID_LENGTH);
unsigned char compact[COMPACT_LEN_IPV6]; unsigned char compact[COMPACT_LEN_IPV6];
int compactlen = bittorrent::packcompact int compactlen = bittorrent::packcompact
(compact, node->getIPAddress(), node->getPort()); (compact, (*i)->getIPAddress(), (*i)->getPort());
if(compactlen == clen) { if(compactlen == clen) {
memcpy(buffer+20+offset, compact, compactlen); memcpy(buffer+20+offset, compact, compactlen);
offset += unit; offset += unit;
@ -128,15 +125,13 @@ std::shared_ptr<Dict> DHTGetPeersReplyMessage::getResponse()
// template may get bigger than 395 bytes. So we use 25 as maximum // template may get bigger than 395 bytes. So we use 25 as maximum
// number of peer info that a message can carry. // number of peer info that a message can carry.
static const size_t MAX_VALUES_SIZE = 25; static const size_t MAX_VALUES_SIZE = 25;
std::shared_ptr<List> valuesList = List::g(); auto valuesList = List::g();
for(std::vector<std::shared_ptr<Peer> >::const_iterator i = values_.begin(), for(auto i = std::begin(values_), eoi = std::end(values_);
eoi = values_.end(); i != eoi && valuesList->size() < MAX_VALUES_SIZE; i != eoi && valuesList->size() < MAX_VALUES_SIZE; ++i) {
++i) {
const std::shared_ptr<Peer>& peer = *i;
unsigned char compact[COMPACT_LEN_IPV6]; unsigned char compact[COMPACT_LEN_IPV6];
const int clen = bittorrent::getCompactLength(family_); const int clen = bittorrent::getCompactLength(family_);
int compactlen = bittorrent::packcompact int compactlen = bittorrent::packcompact
(compact, peer->getIPAddress(), peer->getPort()); (compact, (*i)->getIPAddress(), (*i)->getPort());
if(compactlen == clen) { if(compactlen == clen) {
valuesList->append(String::g(compact, compactlen)); valuesList->append(String::g(compact, compactlen));
} }
@ -164,4 +159,16 @@ std::string DHTGetPeersReplyMessage::toStringOptional() const
static_cast<unsigned long>(closestKNodes_.size())); static_cast<unsigned long>(closestKNodes_.size()));
} }
void DHTGetPeersReplyMessage::setClosestKNodes
(std::vector<std::shared_ptr<DHTNode>> closestKNodes)
{
closestKNodes_ = std::move(closestKNodes);
}
void DHTGetPeersReplyMessage::setValues
(std::vector<std::shared_ptr<Peer>> peers)
{
values_ = std::move(peers);
}
} // namespace aria2 } // namespace aria2

View file

@ -51,9 +51,9 @@ private:
std::string token_; std::string token_;
std::vector<std::shared_ptr<DHTNode> > closestKNodes_; std::vector<std::shared_ptr<DHTNode>> closestKNodes_;
std::vector<std::shared_ptr<Peer> > values_; std::vector<std::shared_ptr<Peer>> values_;
protected: protected:
virtual std::string toStringOptional() const; virtual std::string toStringOptional() const;
public: public:
@ -63,8 +63,6 @@ public:
const std::string& token, const std::string& token,
const std::string& transactionID); const std::string& transactionID);
virtual ~DHTGetPeersReplyMessage();
virtual void doReceivedAction(); virtual void doReceivedAction();
virtual std::shared_ptr<Dict> getResponse(); virtual std::shared_ptr<Dict> getResponse();
@ -73,26 +71,19 @@ public:
virtual void accept(DHTMessageCallback* callback); virtual void accept(DHTMessageCallback* callback);
const std::vector<std::shared_ptr<DHTNode> >& getClosestKNodes() const const std::vector<std::shared_ptr<DHTNode>>& getClosestKNodes() const
{ {
return closestKNodes_; return closestKNodes_;
} }
const std::vector<std::shared_ptr<Peer> >& getValues() const const std::vector<std::shared_ptr<Peer>>& getValues() const
{ {
return values_; return values_;
} }
void setClosestKNodes void setClosestKNodes(std::vector<std::shared_ptr<DHTNode>> closestKNodes);
(const std::vector<std::shared_ptr<DHTNode> >& closestKNodes)
{
closestKNodes_ = closestKNodes;
}
void setValues(const std::vector<std::shared_ptr<Peer> >& peers) void setValues(std::vector<std::shared_ptr<Peer>> peers);
{
values_ = peers;
}
const std::string& getToken() const const std::string& getToken() const
{ {

View file

@ -37,6 +37,7 @@
#include "common.h" #include "common.h"
#include "DHTNodeLookupEntry.h" #include "DHTNodeLookupEntry.h"
#include "DHTNode.h"
#include "DHTConstants.h" #include "DHTConstants.h"
#include "XORCloser.h" #include "XORCloser.h"
@ -46,10 +47,12 @@ class DHTIDCloser {
private: private:
XORCloser closer_; XORCloser closer_;
public: public:
DHTIDCloser(const unsigned char* targetID):closer_(targetID, DHT_ID_LENGTH) {} DHTIDCloser(const unsigned char* targetID)
: closer_{targetID, DHT_ID_LENGTH}
{}
bool operator()(const std::shared_ptr<DHTNodeLookupEntry>& m1, bool operator()(const std::unique_ptr<DHTNodeLookupEntry>& m1,
const std::shared_ptr<DHTNodeLookupEntry>& m2) const const std::unique_ptr<DHTNodeLookupEntry>& m2) const
{ {
return closer_(m1->node->getID(), m2->node->getID()); return closer_(m1->node->getID(), m2->node->getID());
} }

View file

@ -48,8 +48,10 @@ const std::string DHTMessage::ID("id");
DHTMessage::DHTMessage(const std::shared_ptr<DHTNode>& localNode, DHTMessage::DHTMessage(const std::shared_ptr<DHTNode>& localNode,
const std::shared_ptr<DHTNode>& remoteNode, const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID): const std::string& transactionID)
localNode_(localNode), remoteNode_(remoteNode), transactionID_(transactionID) : localNode_{localNode},
remoteNode_{remoteNode},
transactionID_{transactionID}
{ {
if(transactionID.empty()) { if(transactionID.empty()) {
generateTransactionID(); generateTransactionID();

View file

@ -53,7 +53,7 @@ class DHTMessageCallback {
public: public:
virtual ~DHTMessageCallback() {} virtual ~DHTMessageCallback() {}
void onReceived(const std::shared_ptr<DHTResponseMessage>& message) void onReceived(DHTResponseMessage* message)
{ {
message->accept(this); message->accept(this);
} }

View file

@ -51,15 +51,15 @@ public:
virtual ~DHTMessageDispatcher() {} virtual ~DHTMessageDispatcher() {}
virtual void virtual void
addMessageToQueue(const std::shared_ptr<DHTMessage>& message, addMessageToQueue(std::unique_ptr<DHTMessage> message,
time_t timeout, time_t timeout,
const std::shared_ptr<DHTMessageCallback>& callback = std::unique_ptr<DHTMessageCallback> callback =
std::shared_ptr<DHTMessageCallback>()) = 0; std::unique_ptr<DHTMessageCallback>{}) = 0;
virtual void virtual void
addMessageToQueue(const std::shared_ptr<DHTMessage>& message, addMessageToQueue(std::unique_ptr<DHTMessage> message,
const std::shared_ptr<DHTMessageCallback>& callback = std::unique_ptr<DHTMessageCallback> callback =
std::shared_ptr<DHTMessageCallback>()) = 0; std::unique_ptr<DHTMessageCallback>{}) = 0;
virtual void sendMessages() = 0; virtual void sendMessages() = 0;

View file

@ -43,44 +43,41 @@
#include "DHTConstants.h" #include "DHTConstants.h"
#include "fmt.h" #include "fmt.h"
#include "DHTNode.h" #include "DHTNode.h"
#include "a2functional.h"
namespace aria2 { namespace aria2 {
DHTMessageDispatcherImpl::DHTMessageDispatcherImpl DHTMessageDispatcherImpl::DHTMessageDispatcherImpl
(const std::shared_ptr<DHTMessageTracker>& tracker) (const std::shared_ptr<DHTMessageTracker>& tracker)
: tracker_(tracker), : tracker_{tracker},
timeout_(DHT_MESSAGE_TIMEOUT) timeout_{DHT_MESSAGE_TIMEOUT}
{} {}
DHTMessageDispatcherImpl::~DHTMessageDispatcherImpl() {}
void void
DHTMessageDispatcherImpl::addMessageToQueue DHTMessageDispatcherImpl::addMessageToQueue
(const std::shared_ptr<DHTMessage>& message, (std::unique_ptr<DHTMessage> message,
time_t timeout, time_t timeout,
const std::shared_ptr<DHTMessageCallback>& callback) std::unique_ptr<DHTMessageCallback> callback)
{ {
std::shared_ptr<DHTMessageEntry> e messageQueue_.push_back(make_unique<DHTMessageEntry>
(new DHTMessageEntry(message, timeout, callback)); (std::move(message), timeout, std::move(callback)));
messageQueue_.push_back(e);
} }
void void
DHTMessageDispatcherImpl::addMessageToQueue DHTMessageDispatcherImpl::addMessageToQueue
(const std::shared_ptr<DHTMessage>& message, (std::unique_ptr<DHTMessage> message,
const std::shared_ptr<DHTMessageCallback>& callback) std::unique_ptr<DHTMessageCallback> callback)
{ {
addMessageToQueue(message, timeout_, callback); addMessageToQueue(std::move(message), timeout_, std::move(callback));
} }
bool bool DHTMessageDispatcherImpl::sendMessage(DHTMessageEntry* entry)
DHTMessageDispatcherImpl::sendMessage
(const std::shared_ptr<DHTMessageEntry>& entry)
{ {
try { try {
if(entry->message->send()) { if(entry->message->send()) {
if(!entry->message->isReply()) { if(!entry->message->isReply()) {
tracker_->addMessage(entry->message, entry->timeout, entry->callback); tracker_->addMessage(entry->message.get(), entry->timeout,
std::move(entry->callback));
} }
A2_LOG_INFO(fmt("Message sent: %s", entry->message->toString().c_str())); A2_LOG_INFO(fmt("Message sent: %s", entry->message->toString().c_str()));
} else { } else {
@ -95,7 +92,8 @@ DHTMessageDispatcherImpl::sendMessage
// DHTTask(such as DHTAbstractNodeLookupTask) don't finish // DHTTask(such as DHTAbstractNodeLookupTask) don't finish
// forever. // forever.
if(!entry->message->isReply()) { if(!entry->message->isReply()) {
tracker_->addMessage(entry->message, 0, entry->callback); tracker_->addMessage(entry->message.get(), 0,
std::move(entry->callback));
} }
} }
return true; return true;
@ -103,13 +101,13 @@ DHTMessageDispatcherImpl::sendMessage
void DHTMessageDispatcherImpl::sendMessages() void DHTMessageDispatcherImpl::sendMessages()
{ {
auto itr = messageQueue_.begin(); auto itr = std::begin(messageQueue_);
for(; itr != messageQueue_.end(); ++itr) { for(; itr != std::end(messageQueue_); ++itr) {
if(!sendMessage(*itr)) { if(!sendMessage((*itr).get())) {
break; break;
} }
} }
messageQueue_.erase(messageQueue_.begin(), itr); messageQueue_.erase(std::begin(messageQueue_), itr);
A2_LOG_DEBUG(fmt("%lu dht messages remaining in the queue.", A2_LOG_DEBUG(fmt("%lu dht messages remaining in the queue.",
static_cast<unsigned long>(messageQueue_.size()))); static_cast<unsigned long>(messageQueue_.size())));
} }

View file

@ -47,26 +47,24 @@ class DHTMessageDispatcherImpl:public DHTMessageDispatcher {
private: private:
std::shared_ptr<DHTMessageTracker> tracker_; std::shared_ptr<DHTMessageTracker> tracker_;
std::deque<std::shared_ptr<DHTMessageEntry> > messageQueue_; std::deque<std::unique_ptr<DHTMessageEntry>> messageQueue_;
time_t timeout_; time_t timeout_;
bool sendMessage(const std::shared_ptr<DHTMessageEntry>& msg); bool sendMessage(DHTMessageEntry* msg);
public: public:
DHTMessageDispatcherImpl(const std::shared_ptr<DHTMessageTracker>& tracker); DHTMessageDispatcherImpl(const std::shared_ptr<DHTMessageTracker>& tracker);
virtual ~DHTMessageDispatcherImpl();
virtual void virtual void
addMessageToQueue(const std::shared_ptr<DHTMessage>& message, addMessageToQueue(std::unique_ptr<DHTMessage> message,
time_t timeout, time_t timeout,
const std::shared_ptr<DHTMessageCallback>& callback = std::unique_ptr<DHTMessageCallback> callback =
std::shared_ptr<DHTMessageCallback>()); std::unique_ptr<DHTMessageCallback>{});
virtual void virtual void
addMessageToQueue(const std::shared_ptr<DHTMessage>& message, addMessageToQueue(std::unique_ptr<DHTMessage> message,
const std::shared_ptr<DHTMessageCallback>& callback = std::unique_ptr<DHTMessageCallback> callback =
std::shared_ptr<DHTMessageCallback>()); std::unique_ptr<DHTMessageCallback>{});
virtual void sendMessages(); virtual void sendMessages();

View file

@ -40,13 +40,12 @@
namespace aria2 { namespace aria2 {
DHTMessageEntry::DHTMessageEntry DHTMessageEntry::DHTMessageEntry
(const std::shared_ptr<DHTMessage>& message, (std::unique_ptr<DHTMessage> message,
time_t timeout, time_t timeout,
const std::shared_ptr<DHTMessageCallback>& callback): std::unique_ptr<DHTMessageCallback> callback)
message(message), : message{std::move(message)},
timeout(timeout), timeout{timeout},
callback(callback) {} callback{std::move(callback)}
{}
DHTMessageEntry::~DHTMessageEntry() {}
} // namespace aria2 } // namespace aria2

View file

@ -47,15 +47,13 @@ class DHTMessage;
class DHTMessageCallback; class DHTMessageCallback;
struct DHTMessageEntry { struct DHTMessageEntry {
std::shared_ptr<DHTMessage> message; std::unique_ptr<DHTMessage> message;
time_t timeout; time_t timeout;
std::shared_ptr<DHTMessageCallback> callback; std::unique_ptr<DHTMessageCallback> callback;
DHTMessageEntry(const std::shared_ptr<DHTMessage>& message, DHTMessageEntry(std::unique_ptr<DHTMessage> message,
time_t timeout, time_t timeout,
const std::shared_ptr<DHTMessageCallback>& callback); std::unique_ptr<DHTMessageCallback> callback);
~DHTMessageEntry();
}; };
} // namespace aria2 } // namespace aria2

View file

@ -49,6 +49,15 @@ namespace aria2 {
class DHTMessage; class DHTMessage;
class DHTQueryMessage; class DHTQueryMessage;
class DHTResponseMessage; class DHTResponseMessage;
class DHTPingMessage;
class DHTPingReplyMessage;
class DHTFindNodeMessage;
class DHTFindNodeReplyMessage;
class DHTGetPeersMessage;
class DHTGetPeersReplyMessage;
class DHTAnnouncePeerMessage;
class DHTAnnouncePeerReplyMessage;
class DHTUnknownMessage;
class DHTNode; class DHTNode;
class Peer; class Peer;
@ -56,60 +65,60 @@ class DHTMessageFactory {
public: public:
virtual ~DHTMessageFactory() {} virtual ~DHTMessageFactory() {}
virtual std::shared_ptr<DHTQueryMessage> virtual std::unique_ptr<DHTQueryMessage>
createQueryMessage(const Dict* dict, createQueryMessage(const Dict* dict,
const std::string& ipaddr, uint16_t port) = 0; const std::string& ipaddr, uint16_t port) = 0;
virtual std::shared_ptr<DHTResponseMessage> virtual std::unique_ptr<DHTResponseMessage>
createResponseMessage(const std::string& messageType, createResponseMessage(const std::string& messageType,
const Dict* dict, const Dict* dict,
const std::string& ipaddr, uint16_t port) = 0; const std::string& ipaddr, uint16_t port) = 0;
virtual std::shared_ptr<DHTQueryMessage> virtual std::unique_ptr<DHTPingMessage>
createPingMessage(const std::shared_ptr<DHTNode>& remoteNode, createPingMessage(const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID = A2STR::NIL) = 0; const std::string& transactionID = A2STR::NIL) = 0;
virtual std::shared_ptr<DHTResponseMessage> virtual std::unique_ptr<DHTPingReplyMessage>
createPingReplyMessage(const std::shared_ptr<DHTNode>& remoteNode, createPingReplyMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* id, const unsigned char* id,
const std::string& transactionID) = 0; const std::string& transactionID) = 0;
virtual std::shared_ptr<DHTQueryMessage> virtual std::unique_ptr<DHTFindNodeMessage>
createFindNodeMessage(const std::shared_ptr<DHTNode>& remoteNode, createFindNodeMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* targetNodeID, const unsigned char* targetNodeID,
const std::string& transactionID = A2STR::NIL) = 0; const std::string& transactionID = A2STR::NIL) = 0;
virtual std::shared_ptr<DHTResponseMessage> virtual std::unique_ptr<DHTFindNodeReplyMessage>
createFindNodeReplyMessage createFindNodeReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode, (const std::shared_ptr<DHTNode>& remoteNode,
const std::vector<std::shared_ptr<DHTNode> >& closestKNodes, std::vector<std::shared_ptr<DHTNode>> closestKNodes,
const std::string& transactionID) = 0; const std::string& transactionID) = 0;
virtual std::shared_ptr<DHTQueryMessage> virtual std::unique_ptr<DHTGetPeersMessage>
createGetPeersMessage(const std::shared_ptr<DHTNode>& remoteNode, createGetPeersMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* infoHash, const unsigned char* infoHash,
const std::string& transactionID = A2STR::NIL) = 0; const std::string& transactionID = A2STR::NIL) = 0;
virtual std::shared_ptr<DHTResponseMessage> virtual std::unique_ptr<DHTGetPeersReplyMessage>
createGetPeersReplyMessage createGetPeersReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode, (const std::shared_ptr<DHTNode>& remoteNode,
const std::vector<std::shared_ptr<DHTNode> >& closestKNodes, std::vector<std::shared_ptr<DHTNode>> closestKNodes,
const std::vector<std::shared_ptr<Peer> >& peers, std::vector<std::shared_ptr<Peer>> peers,
const std::string& token, const std::string& token,
const std::string& transactionID) = 0; const std::string& transactionID) = 0;
virtual std::shared_ptr<DHTQueryMessage> virtual std::unique_ptr<DHTAnnouncePeerMessage>
createAnnouncePeerMessage(const std::shared_ptr<DHTNode>& remoteNode, createAnnouncePeerMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* infoHash, const unsigned char* infoHash,
uint16_t tcpPort, uint16_t tcpPort,
const std::string& token, const std::string& token,
const std::string& transactionID = A2STR::NIL) = 0; const std::string& transactionID = A2STR::NIL) = 0;
virtual std::shared_ptr<DHTResponseMessage> virtual std::unique_ptr<DHTAnnouncePeerReplyMessage>
createAnnouncePeerReplyMessage(const std::shared_ptr<DHTNode>& remoteNode, createAnnouncePeerReplyMessage(const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID) = 0; const std::string& transactionID) = 0;
virtual std::shared_ptr<DHTMessage> virtual std::unique_ptr<DHTUnknownMessage>
createUnknownMessage(const unsigned char* data, size_t length, createUnknownMessage(const unsigned char* data, size_t length,
const std::string& ipaddr, uint16_t port) = 0; const std::string& ipaddr, uint16_t port) = 0;
}; };

View file

@ -65,23 +65,21 @@
namespace aria2 { namespace aria2 {
DHTMessageFactoryImpl::DHTMessageFactoryImpl(int family) DHTMessageFactoryImpl::DHTMessageFactoryImpl(int family)
: family_(family), : family_{family},
connection_(0), connection_{nullptr},
dispatcher_(0), dispatcher_{nullptr},
routingTable_(0), routingTable_{nullptr},
peerAnnounceStorage_(0), peerAnnounceStorage_{nullptr},
tokenTracker_(0) tokenTracker_{nullptr}
{} {}
DHTMessageFactoryImpl::~DHTMessageFactoryImpl() {}
std::shared_ptr<DHTNode> std::shared_ptr<DHTNode>
DHTMessageFactoryImpl::getRemoteNode DHTMessageFactoryImpl::getRemoteNode
(const unsigned char* id, const std::string& ipaddr, uint16_t port) const (const unsigned char* id, const std::string& ipaddr, uint16_t port) const
{ {
std::shared_ptr<DHTNode> node = routingTable_->getNode(id, ipaddr, port); auto node = routingTable_->getNode(id, ipaddr, port);
if(!node) { if(!node) {
node.reset(new DHTNode(id)); node = std::make_shared<DHTNode>(id);
node->setIPAddress(ipaddr); node->setIPAddress(ipaddr);
node->setPort(port); node->setPort(port);
} }
@ -188,7 +186,7 @@ void DHTMessageFactoryImpl::validatePort(const Integer* port) const
} }
namespace { namespace {
void setVersion(const std::shared_ptr<DHTMessage>& msg, const Dict* dict) void setVersion(DHTMessage* msg, const Dict* dict)
{ {
const String* v = downcast<String>(dict->get(DHTMessage::V)); const String* v = downcast<String>(dict->get(DHTMessage::V));
if(v) { if(v) {
@ -199,7 +197,7 @@ void setVersion(const std::shared_ptr<DHTMessage>& msg, const Dict* dict)
} }
} // namespace } // namespace
std::shared_ptr<DHTQueryMessage> DHTMessageFactoryImpl::createQueryMessage std::unique_ptr<DHTQueryMessage> DHTMessageFactoryImpl::createQueryMessage
(const Dict* dict, const std::string& ipaddr, uint16_t port) (const Dict* dict, const std::string& ipaddr, uint16_t port)
{ {
const String* messageType = getString(dict, DHTQueryMessage::Q); const String* messageType = getString(dict, DHTQueryMessage::Q);
@ -211,8 +209,8 @@ std::shared_ptr<DHTQueryMessage> DHTMessageFactoryImpl::createQueryMessage
} }
const String* id = getString(aDict, DHTMessage::ID); const String* id = getString(aDict, DHTMessage::ID);
validateID(id); validateID(id);
std::shared_ptr<DHTNode> remoteNode = getRemoteNode(id->uc(), ipaddr, port); auto remoteNode = getRemoteNode(id->uc(), ipaddr, port);
std::shared_ptr<DHTQueryMessage> msg; auto msg = std::unique_ptr<DHTQueryMessage>{};
if(messageType->s() == DHTPingMessage::PING) { if(messageType->s() == DHTPingMessage::PING) {
msg = createPingMessage(remoteNode, transactionID->s()); msg = createPingMessage(remoteNode, transactionID->s());
} else if(messageType->s() == DHTFindNodeMessage::FIND_NODE) { } else if(messageType->s() == DHTFindNodeMessage::FIND_NODE) {
@ -238,11 +236,11 @@ std::shared_ptr<DHTQueryMessage> DHTMessageFactoryImpl::createQueryMessage
throw DL_ABORT_EX(fmt("Unsupported message type: %s", throw DL_ABORT_EX(fmt("Unsupported message type: %s",
messageType->s().c_str())); messageType->s().c_str()));
} }
setVersion(msg, dict); setVersion(msg.get(), dict);
return msg; return msg;
} }
std::shared_ptr<DHTResponseMessage> std::unique_ptr<DHTResponseMessage>
DHTMessageFactoryImpl::createResponseMessage DHTMessageFactoryImpl::createResponseMessage
(const std::string& messageType, (const std::string& messageType,
const Dict* dict, const Dict* dict,
@ -270,8 +268,8 @@ DHTMessageFactoryImpl::createResponseMessage
const Dict* rDict = getDictionary(dict, DHTResponseMessage::R); const Dict* rDict = getDictionary(dict, DHTResponseMessage::R);
const String* id = getString(rDict, DHTMessage::ID); const String* id = getString(rDict, DHTMessage::ID);
validateID(id); validateID(id);
std::shared_ptr<DHTNode> remoteNode = getRemoteNode(id->uc(), ipaddr, port); auto remoteNode = getRemoteNode(id->uc(), ipaddr, port);
std::shared_ptr<DHTResponseMessage> msg; auto msg = std::unique_ptr<DHTResponseMessage>{};
if(messageType == DHTPingReplyMessage::PING) { if(messageType == DHTPingReplyMessage::PING) {
msg = createPingReplyMessage(remoteNode, id->uc(), transactionID->s()); msg = createPingReplyMessage(remoteNode, id->uc(), transactionID->s());
} else if(messageType == DHTFindNodeReplyMessage::FIND_NODE) { } else if(messageType == DHTFindNodeReplyMessage::FIND_NODE) {
@ -284,7 +282,7 @@ DHTMessageFactoryImpl::createResponseMessage
throw DL_ABORT_EX throw DL_ABORT_EX
(fmt("Unsupported message type: %s", messageType.c_str())); (fmt("Unsupported message type: %s", messageType.c_str()));
} }
setVersion(msg, dict); setVersion(msg.get(), dict);
return msg; return msg;
} }
@ -312,51 +310,53 @@ void DHTMessageFactoryImpl::setCommonProperty(DHTAbstractMessage* m)
m->setVersion(getDefaultVersion()); m->setVersion(getDefaultVersion());
} }
std::shared_ptr<DHTQueryMessage> DHTMessageFactoryImpl::createPingMessage std::unique_ptr<DHTPingMessage> DHTMessageFactoryImpl::createPingMessage
(const std::shared_ptr<DHTNode>& remoteNode, const std::string& transactionID) (const std::shared_ptr<DHTNode>& remoteNode, const std::string& transactionID)
{ {
DHTPingMessage* m(new DHTPingMessage(localNode_, remoteNode, transactionID)); auto m = make_unique<DHTPingMessage>(localNode_, remoteNode, transactionID);
setCommonProperty(m); setCommonProperty(m.get());
return std::shared_ptr<DHTQueryMessage>(m); return m;
} }
std::shared_ptr<DHTResponseMessage> DHTMessageFactoryImpl::createPingReplyMessage std::unique_ptr<DHTPingReplyMessage>
DHTMessageFactoryImpl::createPingReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode, (const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* id, const unsigned char* id,
const std::string& transactionID) const std::string& transactionID)
{ {
DHTPingReplyMessage* m auto m = make_unique<DHTPingReplyMessage>(localNode_, remoteNode, id,
(new DHTPingReplyMessage(localNode_, remoteNode, id, transactionID)); transactionID);
setCommonProperty(m); setCommonProperty(m.get());
return std::shared_ptr<DHTResponseMessage>(m); return m;
} }
std::shared_ptr<DHTQueryMessage> DHTMessageFactoryImpl::createFindNodeMessage std::unique_ptr<DHTFindNodeMessage>
DHTMessageFactoryImpl::createFindNodeMessage
(const std::shared_ptr<DHTNode>& remoteNode, (const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* targetNodeID, const unsigned char* targetNodeID,
const std::string& transactionID) const std::string& transactionID)
{ {
DHTFindNodeMessage* m(new DHTFindNodeMessage auto m = make_unique<DHTFindNodeMessage>(localNode_, remoteNode,
(localNode_, remoteNode, targetNodeID, transactionID)); targetNodeID, transactionID);
setCommonProperty(m); setCommonProperty(m.get());
return std::shared_ptr<DHTQueryMessage>(m); return m;
} }
std::shared_ptr<DHTResponseMessage> std::unique_ptr<DHTFindNodeReplyMessage>
DHTMessageFactoryImpl::createFindNodeReplyMessage DHTMessageFactoryImpl::createFindNodeReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode, (const std::shared_ptr<DHTNode>& remoteNode,
const std::vector<std::shared_ptr<DHTNode> >& closestKNodes, std::vector<std::shared_ptr<DHTNode>> closestKNodes,
const std::string& transactionID) const std::string& transactionID)
{ {
DHTFindNodeReplyMessage* m(new DHTFindNodeReplyMessage auto m = make_unique<DHTFindNodeReplyMessage>(family_, localNode_,
(family_, localNode_, remoteNode, transactionID)); remoteNode, transactionID);
m->setClosestKNodes(closestKNodes); m->setClosestKNodes(std::move(closestKNodes));
setCommonProperty(m); setCommonProperty(m.get());
return std::shared_ptr<DHTResponseMessage>(m); return m;
} }
void DHTMessageFactoryImpl::extractNodes void DHTMessageFactoryImpl::extractNodes
(std::vector<std::shared_ptr<DHTNode> >& nodes, (std::vector<std::shared_ptr<DHTNode>>& nodes,
const unsigned char* src, size_t length) const unsigned char* src, size_t length)
{ {
int unit = bittorrent::getCompactLength(family_)+20; int unit = bittorrent::getCompactLength(family_)+20;
@ -365,19 +365,18 @@ void DHTMessageFactoryImpl::extractNodes
(fmt("Nodes length is not multiple of %d", unit)); (fmt("Nodes length is not multiple of %d", unit));
} }
for(size_t offset = 0; offset < length; offset += unit) { for(size_t offset = 0; offset < length; offset += unit) {
std::shared_ptr<DHTNode> node(new DHTNode(src+offset)); auto node = std::make_shared<DHTNode>(src+offset);
std::pair<std::string, uint16_t> addr = auto addr = bittorrent::unpackcompact(src+offset+DHT_ID_LENGTH, family_);
bittorrent::unpackcompact(src+offset+DHT_ID_LENGTH, family_);
if(addr.first.empty()) { if(addr.first.empty()) {
continue; continue;
} }
node->setIPAddress(addr.first); node->setIPAddress(addr.first);
node->setPort(addr.second); node->setPort(addr.second);
nodes.push_back(node); nodes.push_back(std::move(node));
} }
} }
std::shared_ptr<DHTResponseMessage> std::unique_ptr<DHTFindNodeReplyMessage>
DHTMessageFactoryImpl::createFindNodeReplyMessage DHTMessageFactoryImpl::createFindNodeReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode, (const std::shared_ptr<DHTNode>& remoteNode,
const Dict* dict, const Dict* dict,
@ -387,28 +386,29 @@ DHTMessageFactoryImpl::createFindNodeReplyMessage
downcast<String>(getDictionary(dict, DHTResponseMessage::R)-> downcast<String>(getDictionary(dict, DHTResponseMessage::R)->
get(family_ == AF_INET?DHTFindNodeReplyMessage::NODES: get(family_ == AF_INET?DHTFindNodeReplyMessage::NODES:
DHTFindNodeReplyMessage::NODES6)); DHTFindNodeReplyMessage::NODES6));
std::vector<std::shared_ptr<DHTNode> > nodes; std::vector<std::shared_ptr<DHTNode>> nodes;
if(nodesData) { if(nodesData) {
extractNodes(nodes, nodesData->uc(), nodesData->s().size()); extractNodes(nodes, nodesData->uc(), nodesData->s().size());
} }
return createFindNodeReplyMessage(remoteNode, nodes, transactionID); return createFindNodeReplyMessage(remoteNode, std::move(nodes),
transactionID);
} }
std::shared_ptr<DHTQueryMessage> std::unique_ptr<DHTGetPeersMessage>
DHTMessageFactoryImpl::createGetPeersMessage DHTMessageFactoryImpl::createGetPeersMessage
(const std::shared_ptr<DHTNode>& remoteNode, (const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* infoHash, const unsigned char* infoHash,
const std::string& transactionID) const std::string& transactionID)
{ {
DHTGetPeersMessage* m auto m = make_unique<DHTGetPeersMessage>(localNode_, remoteNode, infoHash,
(new DHTGetPeersMessage(localNode_, remoteNode, infoHash, transactionID)); transactionID);
m->setPeerAnnounceStorage(peerAnnounceStorage_); m->setPeerAnnounceStorage(peerAnnounceStorage_);
m->setTokenTracker(tokenTracker_); m->setTokenTracker(tokenTracker_);
setCommonProperty(m); setCommonProperty(m.get());
return std::shared_ptr<DHTQueryMessage>(m); return m;
} }
std::shared_ptr<DHTResponseMessage> std::unique_ptr<DHTGetPeersReplyMessage>
DHTMessageFactoryImpl::createGetPeersReplyMessage DHTMessageFactoryImpl::createGetPeersReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode, (const std::shared_ptr<DHTNode>& remoteNode,
const Dict* dict, const Dict* dict,
@ -416,54 +416,53 @@ DHTMessageFactoryImpl::createGetPeersReplyMessage
{ {
const Dict* rDict = getDictionary(dict, DHTResponseMessage::R); const Dict* rDict = getDictionary(dict, DHTResponseMessage::R);
const String* nodesData = const String* nodesData =
downcast<String>(rDict->get(family_ == AF_INET?DHTGetPeersReplyMessage::NODES: downcast<String>(rDict->get(family_ == AF_INET ?
DHTGetPeersReplyMessage::NODES6)); DHTGetPeersReplyMessage::NODES :
std::vector<std::shared_ptr<DHTNode> > nodes; DHTGetPeersReplyMessage::NODES6));
std::vector<std::shared_ptr<DHTNode>> nodes;
if(nodesData) { if(nodesData) {
extractNodes(nodes, nodesData->uc(), nodesData->s().size()); extractNodes(nodes, nodesData->uc(), nodesData->s().size());
} }
const List* valuesList = const List* valuesList =
downcast<List>(rDict->get(DHTGetPeersReplyMessage::VALUES)); downcast<List>(rDict->get(DHTGetPeersReplyMessage::VALUES));
std::vector<std::shared_ptr<Peer> > peers; std::vector<std::shared_ptr<Peer>> peers;
size_t clen = bittorrent::getCompactLength(family_); size_t clen = bittorrent::getCompactLength(family_);
if(valuesList) { if(valuesList) {
for(List::ValueType::const_iterator i = valuesList->begin(), for(auto i = valuesList->begin(), eoi = valuesList->end(); i != eoi; ++i) {
eoi = valuesList->end(); i != eoi; ++i) {
const String* data = downcast<String>(*i); const String* data = downcast<String>(*i);
if(data && data->s().size() == clen) { if(data && data->s().size() == clen) {
std::pair<std::string, uint16_t> addr = auto addr = bittorrent::unpackcompact(data->uc(), family_);
bittorrent::unpackcompact(data->uc(), family_);
if(addr.first.empty()) { if(addr.first.empty()) {
continue; continue;
} }
std::shared_ptr<Peer> peer(new Peer(addr.first, addr.second)); peers.push_back(std::make_shared<Peer>(addr.first, addr.second));
peers.push_back(peer);
} }
} }
} }
const String* token = getString(rDict, DHTGetPeersReplyMessage::TOKEN); const String* token = getString(rDict, DHTGetPeersReplyMessage::TOKEN);
return createGetPeersReplyMessage return createGetPeersReplyMessage
(remoteNode, nodes, peers, token->s(), transactionID); (remoteNode, std::move(nodes), std::move(peers), token->s(),
transactionID);
} }
std::shared_ptr<DHTResponseMessage> std::unique_ptr<DHTGetPeersReplyMessage>
DHTMessageFactoryImpl::createGetPeersReplyMessage DHTMessageFactoryImpl::createGetPeersReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode, (const std::shared_ptr<DHTNode>& remoteNode,
const std::vector<std::shared_ptr<DHTNode> >& closestKNodes, std::vector<std::shared_ptr<DHTNode>> closestKNodes,
const std::vector<std::shared_ptr<Peer> >& values, std::vector<std::shared_ptr<Peer>> values,
const std::string& token, const std::string& token,
const std::string& transactionID) const std::string& transactionID)
{ {
DHTGetPeersReplyMessage* m(new DHTGetPeersReplyMessage auto m = make_unique<DHTGetPeersReplyMessage>(family_, localNode_,
(family_, localNode_, remoteNode, token, remoteNode, token,
transactionID)); transactionID);
m->setClosestKNodes(closestKNodes); m->setClosestKNodes(std::move(closestKNodes));
m->setValues(values); m->setValues(std::move(values));
setCommonProperty(m); setCommonProperty(m.get());
return std::shared_ptr<DHTResponseMessage>(m); return m;
} }
std::shared_ptr<DHTQueryMessage> std::unique_ptr<DHTAnnouncePeerMessage>
DHTMessageFactoryImpl::createAnnouncePeerMessage DHTMessageFactoryImpl::createAnnouncePeerMessage
(const std::shared_ptr<DHTNode>& remoteNode, (const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* infoHash, const unsigned char* infoHash,
@ -471,34 +470,33 @@ DHTMessageFactoryImpl::createAnnouncePeerMessage
const std::string& token, const std::string& token,
const std::string& transactionID) const std::string& transactionID)
{ {
DHTAnnouncePeerMessage* m(new DHTAnnouncePeerMessage auto m = make_unique<DHTAnnouncePeerMessage>(localNode_, remoteNode,
(localNode_, remoteNode, infoHash, tcpPort, token, infoHash, tcpPort, token,
transactionID)); transactionID);
m->setPeerAnnounceStorage(peerAnnounceStorage_); m->setPeerAnnounceStorage(peerAnnounceStorage_);
m->setTokenTracker(tokenTracker_); m->setTokenTracker(tokenTracker_);
setCommonProperty(m); setCommonProperty(m.get());
return std::shared_ptr<DHTQueryMessage>(m); return m;
} }
std::shared_ptr<DHTResponseMessage> std::unique_ptr<DHTAnnouncePeerReplyMessage>
DHTMessageFactoryImpl::createAnnouncePeerReplyMessage DHTMessageFactoryImpl::createAnnouncePeerReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode, const std::string& transactionID) (const std::shared_ptr<DHTNode>& remoteNode, const std::string& transactionID)
{ {
DHTAnnouncePeerReplyMessage* m auto m = make_unique<DHTAnnouncePeerReplyMessage>(localNode_, remoteNode,
(new DHTAnnouncePeerReplyMessage(localNode_, remoteNode, transactionID)); transactionID);
setCommonProperty(m); setCommonProperty(m.get());
return std::shared_ptr<DHTResponseMessage>(m); return m;
} }
std::shared_ptr<DHTMessage> std::unique_ptr<DHTUnknownMessage>
DHTMessageFactoryImpl::createUnknownMessage DHTMessageFactoryImpl::createUnknownMessage
(const unsigned char* data, size_t length, (const unsigned char* data, size_t length,
const std::string& ipaddr, uint16_t port) const std::string& ipaddr, uint16_t port)
{ {
DHTUnknownMessage* m return make_unique<DHTUnknownMessage>(localNode_, data, length,
(new DHTUnknownMessage(localNode_, data, length, ipaddr, port)); ipaddr, port);
return std::shared_ptr<DHTMessage>(m);
} }
void DHTMessageFactoryImpl::setRoutingTable(DHTRoutingTable* routingTable) void DHTMessageFactoryImpl::setRoutingTable(DHTRoutingTable* routingTable)

View file

@ -81,74 +81,72 @@ private:
public: public:
DHTMessageFactoryImpl(int family); DHTMessageFactoryImpl(int family);
virtual ~DHTMessageFactoryImpl(); virtual std::unique_ptr<DHTQueryMessage>
virtual std::shared_ptr<DHTQueryMessage>
createQueryMessage(const Dict* dict, createQueryMessage(const Dict* dict,
const std::string& ipaddr, uint16_t port); const std::string& ipaddr, uint16_t port);
virtual std::shared_ptr<DHTResponseMessage> virtual std::unique_ptr<DHTResponseMessage>
createResponseMessage(const std::string& messageType, createResponseMessage(const std::string& messageType,
const Dict* dict, const Dict* dict,
const std::string& ipaddr, uint16_t port); const std::string& ipaddr, uint16_t port);
virtual std::shared_ptr<DHTQueryMessage> virtual std::unique_ptr<DHTPingMessage>
createPingMessage(const std::shared_ptr<DHTNode>& remoteNode, createPingMessage(const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID = A2STR::NIL); const std::string& transactionID = A2STR::NIL);
virtual std::shared_ptr<DHTResponseMessage> virtual std::unique_ptr<DHTPingReplyMessage>
createPingReplyMessage(const std::shared_ptr<DHTNode>& remoteNode, createPingReplyMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* id, const unsigned char* id,
const std::string& transactionID); const std::string& transactionID);
virtual std::shared_ptr<DHTQueryMessage> virtual std::unique_ptr<DHTFindNodeMessage>
createFindNodeMessage(const std::shared_ptr<DHTNode>& remoteNode, createFindNodeMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* targetNodeID, const unsigned char* targetNodeID,
const std::string& transactionID = A2STR::NIL); const std::string& transactionID = A2STR::NIL);
std::shared_ptr<DHTResponseMessage> std::unique_ptr<DHTFindNodeReplyMessage>
createFindNodeReplyMessage(const std::shared_ptr<DHTNode>& remoteNode, createFindNodeReplyMessage(const std::shared_ptr<DHTNode>& remoteNode,
const Dict* dict, const Dict* dict,
const std::string& transactionID); const std::string& transactionID);
virtual std::shared_ptr<DHTResponseMessage> virtual std::unique_ptr<DHTFindNodeReplyMessage>
createFindNodeReplyMessage createFindNodeReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode, (const std::shared_ptr<DHTNode>& remoteNode,
const std::vector<std::shared_ptr<DHTNode> >& closestKNodes, std::vector<std::shared_ptr<DHTNode>> closestKNodes,
const std::string& transactionID); const std::string& transactionID);
virtual std::shared_ptr<DHTQueryMessage> virtual std::unique_ptr<DHTGetPeersMessage>
createGetPeersMessage(const std::shared_ptr<DHTNode>& remoteNode, createGetPeersMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* infoHash, const unsigned char* infoHash,
const std::string& transactionID = A2STR::NIL); const std::string& transactionID = A2STR::NIL);
virtual std::shared_ptr<DHTResponseMessage> virtual std::unique_ptr<DHTGetPeersReplyMessage>
createGetPeersReplyMessage createGetPeersReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode, (const std::shared_ptr<DHTNode>& remoteNode,
const std::vector<std::shared_ptr<DHTNode> >& closestKNodes, std::vector<std::shared_ptr<DHTNode>> closestKNodes,
const std::vector<std::shared_ptr<Peer> >& peers, std::vector<std::shared_ptr<Peer>> peers,
const std::string& token, const std::string& token,
const std::string& transactionID); const std::string& transactionID);
std::shared_ptr<DHTResponseMessage> std::unique_ptr<DHTGetPeersReplyMessage>
createGetPeersReplyMessage createGetPeersReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode, (const std::shared_ptr<DHTNode>& remoteNode,
const Dict* dict, const Dict* dict,
const std::string& transactionID); const std::string& transactionID);
virtual std::shared_ptr<DHTQueryMessage> virtual std::unique_ptr<DHTAnnouncePeerMessage>
createAnnouncePeerMessage(const std::shared_ptr<DHTNode>& remoteNode, createAnnouncePeerMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* infoHash, const unsigned char* infoHash,
uint16_t tcpPort, uint16_t tcpPort,
const std::string& token, const std::string& token,
const std::string& transactionID = A2STR::NIL); const std::string& transactionID = A2STR::NIL);
virtual std::shared_ptr<DHTResponseMessage> virtual std::unique_ptr<DHTAnnouncePeerReplyMessage>
createAnnouncePeerReplyMessage(const std::shared_ptr<DHTNode>& remoteNode, createAnnouncePeerReplyMessage(const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID); const std::string& transactionID);
virtual std::shared_ptr<DHTMessage> virtual std::unique_ptr<DHTUnknownMessage>
createUnknownMessage(const unsigned char* data, size_t length, createUnknownMessage(const unsigned char* data, size_t length,
const std::string& ipaddr, uint16_t port); const std::string& ipaddr, uint16_t port);

View file

@ -58,18 +58,16 @@ namespace aria2 {
DHTMessageReceiver::DHTMessageReceiver DHTMessageReceiver::DHTMessageReceiver
(const std::shared_ptr<DHTMessageTracker>& tracker) (const std::shared_ptr<DHTMessageTracker>& tracker)
: tracker_(tracker) : tracker_{tracker}
{} {}
DHTMessageReceiver::~DHTMessageReceiver() {} std::unique_ptr<DHTMessage> DHTMessageReceiver::receiveMessage
std::shared_ptr<DHTMessage> DHTMessageReceiver::receiveMessage
(const std::string& remoteAddr, uint16_t remotePort, unsigned char *data, (const std::string& remoteAddr, uint16_t remotePort, unsigned char *data,
size_t length) size_t length)
{ {
try { try {
bool isReply = false; bool isReply = false;
std::shared_ptr<ValueBase> decoded = bencode2::decode(data, length); auto decoded = bencode2::decode(data, length);
const Dict* dict = downcast<Dict>(decoded); const Dict* dict = downcast<Dict>(decoded);
if(dict) { if(dict) {
const String* y = downcast<String>(dict->get(DHTMessage::Y)); const String* y = downcast<String>(dict->get(DHTMessage::Y));
@ -89,28 +87,26 @@ std::shared_ptr<DHTMessage> DHTMessageReceiver::receiveMessage
return handleUnknownMessage(data, length, remoteAddr, remotePort); return handleUnknownMessage(data, length, remoteAddr, remotePort);
} }
if(isReply) { if(isReply) {
std::pair<std::shared_ptr<DHTResponseMessage>, auto p = tracker_->messageArrived(dict, remoteAddr, remotePort);
std::shared_ptr<DHTMessageCallback> > p =
tracker_->messageArrived(dict, remoteAddr, remotePort);
if(!p.first) { if(!p.first) {
// timeout or malicious? message // timeout or malicious? message
return handleUnknownMessage(data, length, remoteAddr, remotePort); return handleUnknownMessage(data, length, remoteAddr, remotePort);
} }
onMessageReceived(p.first); onMessageReceived(p.first.get());
if(p.second) { if(p.second) {
p.second->onReceived(p.first); p.second->onReceived(p.first.get());
} }
return p.first; return std::move(p.first);
} else { } else {
std::shared_ptr<DHTQueryMessage> message = auto message =
factory_->createQueryMessage(dict, remoteAddr, remotePort); factory_->createQueryMessage(dict, remoteAddr, remotePort);
if(*message->getLocalNode() == *message->getRemoteNode()) { if(*message->getLocalNode() == *message->getRemoteNode()) {
// drop message from localnode // drop message from localnode
A2_LOG_INFO("Received DHT message from localnode."); A2_LOG_INFO("Received DHT message from localnode.");
return handleUnknownMessage(data, length, remoteAddr, remotePort); return handleUnknownMessage(data, length, remoteAddr, remotePort);
} }
onMessageReceived(message); onMessageReceived(message.get());
return message; return std::move(message);
} }
} catch(RecoverableException& e) { } catch(RecoverableException& e) {
A2_LOG_INFO_EX("Exception thrown while receiving DHT message.", e); A2_LOG_INFO_EX("Exception thrown while receiving DHT message.", e);
@ -118,8 +114,7 @@ std::shared_ptr<DHTMessage> DHTMessageReceiver::receiveMessage
} }
} }
void DHTMessageReceiver::onMessageReceived void DHTMessageReceiver::onMessageReceived(DHTMessage* message)
(const std::shared_ptr<DHTMessage>& message)
{ {
A2_LOG_INFO(fmt("Message received: %s", message->toString().c_str())); A2_LOG_INFO(fmt("Message received: %s", message->toString().c_str()));
message->validate(); message->validate();
@ -134,13 +129,13 @@ void DHTMessageReceiver::handleTimeout()
tracker_->handleTimeout(); tracker_->handleTimeout();
} }
std::shared_ptr<DHTMessage> std::unique_ptr<DHTUnknownMessage>
DHTMessageReceiver::handleUnknownMessage(const unsigned char* data, DHTMessageReceiver::handleUnknownMessage(const unsigned char* data,
size_t length, size_t length,
const std::string& remoteAddr, const std::string& remoteAddr,
uint16_t remotePort) uint16_t remotePort)
{ {
std::shared_ptr<DHTMessage> m = auto m =
factory_->createUnknownMessage(data, length, remoteAddr, remotePort); factory_->createUnknownMessage(data, length, remoteAddr, remotePort);
A2_LOG_INFO(fmt("Message received: %s", m->toString().c_str())); A2_LOG_INFO(fmt("Message received: %s", m->toString().c_str()));
return m; return m;

View file

@ -47,6 +47,7 @@ class DHTMessage;
class DHTConnection; class DHTConnection;
class DHTMessageFactory; class DHTMessageFactory;
class DHTRoutingTable; class DHTRoutingTable;
class DHTUnknownMessage;
class DHTMessageReceiver { class DHTMessageReceiver {
private: private:
@ -58,17 +59,15 @@ private:
std::shared_ptr<DHTRoutingTable> routingTable_; std::shared_ptr<DHTRoutingTable> routingTable_;
std::shared_ptr<DHTMessage> std::unique_ptr<DHTUnknownMessage>
handleUnknownMessage(const unsigned char* data, size_t length, handleUnknownMessage(const unsigned char* data, size_t length,
const std::string& remoteAddr, uint16_t remotePort); const std::string& remoteAddr, uint16_t remotePort);
void onMessageReceived(const std::shared_ptr<DHTMessage>& message); void onMessageReceived(DHTMessage* message);
public: public:
DHTMessageReceiver(const std::shared_ptr<DHTMessageTracker>& tracker); DHTMessageReceiver(const std::shared_ptr<DHTMessageTracker>& tracker);
~DHTMessageReceiver(); std::unique_ptr<DHTMessage> receiveMessage
std::shared_ptr<DHTMessage> receiveMessage
(const std::string& remoteAddr, uint16_t remotePort, unsigned char *data, (const std::string& remoteAddr, uint16_t remotePort, unsigned char *data,
size_t length); size_t length);

View file

@ -51,17 +51,24 @@
namespace aria2 { namespace aria2 {
DHTMessageTracker::DHTMessageTracker() {} DHTMessageTracker::DHTMessageTracker()
: factory_{nullptr}
{}
DHTMessageTracker::~DHTMessageTracker() {} void DHTMessageTracker::addMessage
(DHTMessage* message,
void DHTMessageTracker::addMessage(const std::shared_ptr<DHTMessage>& message, time_t timeout, const std::shared_ptr<DHTMessageCallback>& callback) time_t timeout,
std::unique_ptr<DHTMessageCallback> callback)
{ {
std::shared_ptr<DHTMessageTrackerEntry> e(new DHTMessageTrackerEntry(message, timeout, callback)); entries_.push_back(make_unique<DHTMessageTrackerEntry>
entries_.push_back(e); (message->getRemoteNode(),
message->getTransactionID(),
message->getMessageType(),
timeout, std::move(callback)));
} }
std::pair<std::shared_ptr<DHTResponseMessage>, std::shared_ptr<DHTMessageCallback> > std::pair<std::unique_ptr<DHTResponseMessage>,
std::unique_ptr<DHTMessageCallback>>
DHTMessageTracker::messageArrived DHTMessageTracker::messageArrived
(const Dict* dict, const std::string& ipaddr, uint16_t port) (const Dict* dict, const std::string& ipaddr, uint16_t port)
{ {
@ -73,15 +80,14 @@ DHTMessageTracker::messageArrived
A2_LOG_DEBUG(fmt("Searching tracker entry for TransactionID=%s, Remote=%s:%u", A2_LOG_DEBUG(fmt("Searching tracker entry for TransactionID=%s, Remote=%s:%u",
util::toHex(tid->s()).c_str(), util::toHex(tid->s()).c_str(),
ipaddr.c_str(), port)); ipaddr.c_str(), port));
for(std::deque<std::shared_ptr<DHTMessageTrackerEntry> >::iterator i = for(auto i = std::begin(entries_), eoi = std::end(entries_); i != eoi; ++i) {
entries_.begin(), eoi = entries_.end(); i != eoi; ++i) {
if((*i)->match(tid->s(), ipaddr, port)) { if((*i)->match(tid->s(), ipaddr, port)) {
std::shared_ptr<DHTMessageTrackerEntry> entry = *i; auto entry = std::move(*i);
entries_.erase(i); entries_.erase(i);
A2_LOG_DEBUG("Tracker entry found."); A2_LOG_DEBUG("Tracker entry found.");
std::shared_ptr<DHTNode> targetNode = entry->getTargetNode(); auto& targetNode = entry->getTargetNode();
try { try {
std::shared_ptr<DHTResponseMessage> message = auto message =
factory_->createResponseMessage(entry->getMessageType(), dict, factory_->createResponseMessage(entry->getMessageType(), dict,
targetNode->getIPAddress(), targetNode->getIPAddress(),
targetNode->getPort()); targetNode->getPort());
@ -89,8 +95,7 @@ DHTMessageTracker::messageArrived
int64_t rtt = entry->getElapsedMillis(); int64_t rtt = entry->getElapsedMillis();
A2_LOG_DEBUG(fmt("RTT is %" PRId64 "", rtt)); A2_LOG_DEBUG(fmt("RTT is %" PRId64 "", rtt));
message->getRemoteNode()->updateRTT(rtt); message->getRemoteNode()->updateRTT(rtt);
std::shared_ptr<DHTMessageCallback> callback = entry->getCallback(); if(*targetNode != *message->getRemoteNode()) {
if(!(*targetNode == *message->getRemoteNode())) {
// Node ID has changed. Drop previous node ID from // Node ID has changed. Drop previous node ID from
// DHTRoutingTable // DHTRoutingTable
A2_LOG_DEBUG A2_LOG_DEBUG
@ -100,23 +105,22 @@ DHTMessageTracker::messageArrived
DHT_ID_LENGTH).c_str())); DHT_ID_LENGTH).c_str()));
routingTable_->dropNode(targetNode); routingTable_->dropNode(targetNode);
} }
return std::make_pair(message, callback); return std::make_pair(std::move(message), entry->popCallback());
} catch(RecoverableException& e) { } catch(RecoverableException& e) {
handleTimeoutEntry(entry); handleTimeoutEntry(entry.get());
throw; throw;
} }
} }
} }
A2_LOG_DEBUG("Tracker entry not found."); A2_LOG_DEBUG("Tracker entry not found.");
return std::pair<std::shared_ptr<DHTResponseMessage>, return std::pair<std::unique_ptr<DHTResponseMessage>,
std::shared_ptr<DHTMessageCallback> >(); std::unique_ptr<DHTMessageCallback>>{};
} }
void DHTMessageTracker::handleTimeoutEntry void DHTMessageTracker::handleTimeoutEntry(DHTMessageTrackerEntry* entry)
(const std::shared_ptr<DHTMessageTrackerEntry>& entry)
{ {
try { try {
std::shared_ptr<DHTNode> node = entry->getTargetNode(); auto& node = entry->getTargetNode();
A2_LOG_DEBUG(fmt("Message timeout: To:%s:%u", A2_LOG_DEBUG(fmt("Message timeout: To:%s:%u",
node->getIPAddress().c_str(), node->getPort())); node->getIPAddress().c_str(), node->getPort()));
node->updateRTT(entry->getElapsedMillis()); node->updateRTT(entry->getElapsedMillis());
@ -126,7 +130,7 @@ void DHTMessageTracker::handleTimeoutEntry
node->getIPAddress().c_str(), node->getPort())); node->getIPAddress().c_str(), node->getPort()));
routingTable_->dropNode(node); routingTable_->dropNode(node);
} }
std::shared_ptr<DHTMessageCallback> callback = entry->getCallback(); auto& callback = entry->getCallback();
if(callback) { if(callback) {
callback->onTimeout(node); callback->onTimeout(node);
} }
@ -135,43 +139,33 @@ void DHTMessageTracker::handleTimeoutEntry
} }
} }
namespace {
struct HandleTimeout {
HandleTimeout(DHTMessageTracker* tracker)
: tracker(tracker)
{}
bool operator()(const std::shared_ptr<DHTMessageTrackerEntry>& ent) const
{
if(ent->isTimeout()) {
tracker->handleTimeoutEntry(ent);
return true;
} else {
return false;
}
}
DHTMessageTracker* tracker;
};
} // namespace
void DHTMessageTracker::handleTimeout() void DHTMessageTracker::handleTimeout()
{ {
entries_.erase(std::remove_if(entries_.begin(), entries_.end(), entries_.erase
HandleTimeout(this)), (std::remove_if(std::begin(entries_), std::end(entries_),
entries_.end()); [&](const std::unique_ptr<DHTMessageTrackerEntry>& ent)
{
if(ent->isTimeout()) {
handleTimeoutEntry(ent.get());
return true;
} else {
return false;
}
}),
std::end(entries_));
} }
std::shared_ptr<DHTMessageTrackerEntry> const DHTMessageTrackerEntry*
DHTMessageTracker::getEntryFor(const std::shared_ptr<DHTMessage>& message) const DHTMessageTracker::getEntryFor(const DHTMessage* message) const
{ {
for(std::deque<std::shared_ptr<DHTMessageTrackerEntry> >::const_iterator i = for(auto& ent : entries_) {
entries_.begin(), eoi = entries_.end(); i != eoi; ++i) { if(ent->match(message->getTransactionID(),
if((*i)->match(message->getTransactionID(), message->getRemoteNode()->getIPAddress(),
message->getRemoteNode()->getIPAddress(), message->getRemoteNode()->getPort())) {
message->getRemoteNode()->getPort())) { return ent.get();
return *i;
} }
} }
return std::shared_ptr<DHTMessageTrackerEntry>(); return nullptr;
} }
size_t DHTMessageTracker::countEntry() const size_t DHTMessageTracker::countEntry() const
@ -185,8 +179,7 @@ void DHTMessageTracker::setRoutingTable
routingTable_ = routingTable; routingTable_ = routingTable;
} }
void DHTMessageTracker::setMessageFactory void DHTMessageTracker::setMessageFactory(DHTMessageFactory* factory)
(const std::shared_ptr<DHTMessageFactory>& factory)
{ {
factory_ = factory; factory_ = factory;
} }

View file

@ -55,38 +55,37 @@ class DHTMessageTrackerEntry;
class DHTMessageTracker { class DHTMessageTracker {
private: private:
std::deque<std::shared_ptr<DHTMessageTrackerEntry> > entries_; std::deque<std::unique_ptr<DHTMessageTrackerEntry>> entries_;
std::shared_ptr<DHTRoutingTable> routingTable_; std::shared_ptr<DHTRoutingTable> routingTable_;
std::shared_ptr<DHTMessageFactory> factory_; DHTMessageFactory* factory_;
public: public:
DHTMessageTracker(); DHTMessageTracker();
~DHTMessageTracker(); void addMessage(DHTMessage* message,
void addMessage(const std::shared_ptr<DHTMessage>& message,
time_t timeout, time_t timeout,
const std::shared_ptr<DHTMessageCallback>& callback = std::unique_ptr<DHTMessageCallback> callback =
std::shared_ptr<DHTMessageCallback>()); std::unique_ptr<DHTMessageCallback>{});
std::pair<std::shared_ptr<DHTResponseMessage>, std::shared_ptr<DHTMessageCallback> > std::pair<std::unique_ptr<DHTResponseMessage>,
std::unique_ptr<DHTMessageCallback>>
messageArrived(const Dict* dict, messageArrived(const Dict* dict,
const std::string& ipaddr, uint16_t port); const std::string& ipaddr, uint16_t port);
void handleTimeout(); void handleTimeout();
// Made public so that unnamed functor can access this // Made public so that unnamed functor can access this
void handleTimeoutEntry(const std::shared_ptr<DHTMessageTrackerEntry>& entry); void handleTimeoutEntry(DHTMessageTrackerEntry* entry);
std::shared_ptr<DHTMessageTrackerEntry> getEntryFor // // For unittest only
(const std::shared_ptr<DHTMessage>& message) const; const DHTMessageTrackerEntry* getEntryFor(const DHTMessage* message) const;
size_t countEntry() const; size_t countEntry() const;
void setRoutingTable(const std::shared_ptr<DHTRoutingTable>& routingTable); void setRoutingTable(const std::shared_ptr<DHTRoutingTable>& routingTable);
void setMessageFactory(const std::shared_ptr<DHTMessageFactory>& factory); void setMessageFactory(DHTMessageFactory* factory);
}; };
} // namespace aria2 } // namespace aria2

View file

@ -42,19 +42,20 @@
namespace aria2 { namespace aria2 {
DHTMessageTrackerEntry::DHTMessageTrackerEntry(const std::shared_ptr<DHTMessage>& sentMessage, DHTMessageTrackerEntry::DHTMessageTrackerEntry
time_t timeout, (std::shared_ptr<DHTNode> targetNode,
const std::shared_ptr<DHTMessageCallback>& callback): std::string transactionID,
targetNode_(sentMessage->getRemoteNode()), std::string messageType,
transactionID_(sentMessage->getTransactionID()), time_t timeout,
messageType_(sentMessage->getMessageType()), std::unique_ptr<DHTMessageCallback> callback)
callback_(callback), : targetNode_{std::move(targetNode)},
dispatchedTime_(global::wallclock()), transactionID_{std::move(transactionID)},
timeout_(timeout) messageType_{std::move(messageType)},
callback_{std::move(callback)},
dispatchedTime_{global::wallclock()},
timeout_{timeout}
{} {}
DHTMessageTrackerEntry::~DHTMessageTrackerEntry() {}
bool DHTMessageTrackerEntry::isTimeout() const bool DHTMessageTrackerEntry::isTimeout() const
{ {
return dispatchedTime_.difference(global::wallclock()) >= timeout_; return dispatchedTime_.difference(global::wallclock()) >= timeout_;
@ -84,4 +85,25 @@ int64_t DHTMessageTrackerEntry::getElapsedMillis() const
return dispatchedTime_.differenceInMillis(global::wallclock()); return dispatchedTime_.differenceInMillis(global::wallclock());
} }
const std::shared_ptr<DHTNode>& DHTMessageTrackerEntry::getTargetNode() const
{
return targetNode_;
}
const std::string& DHTMessageTrackerEntry::getMessageType() const
{
return messageType_;
}
const std::unique_ptr<DHTMessageCallback>&
DHTMessageTrackerEntry::getCallback() const
{
return callback_;
}
std::unique_ptr<DHTMessageCallback> DHTMessageTrackerEntry::popCallback()
{
return std::move(callback_);
}
} // namespace aria2 } // namespace aria2

View file

@ -57,18 +57,18 @@ private:
std::string messageType_; std::string messageType_;
std::shared_ptr<DHTMessageCallback> callback_; std::unique_ptr<DHTMessageCallback> callback_;
Timer dispatchedTime_; Timer dispatchedTime_;
time_t timeout_; time_t timeout_;
public: public:
DHTMessageTrackerEntry(const std::shared_ptr<DHTMessage>& sentMessage, DHTMessageTrackerEntry(std::shared_ptr<DHTNode> targetNode,
std::string transactionID,
std::string messageType,
time_t timeout, time_t timeout,
const std::shared_ptr<DHTMessageCallback>& callback = std::unique_ptr<DHTMessageCallback> callback =
std::shared_ptr<DHTMessageCallback>()); std::unique_ptr<DHTMessageCallback>{});
~DHTMessageTrackerEntry();
bool isTimeout() const; bool isTimeout() const;
@ -76,21 +76,10 @@ public:
bool match(const std::string& transactionID, const std::string& ipaddr, uint16_t port) const; bool match(const std::string& transactionID, const std::string& ipaddr, uint16_t port) const;
const std::shared_ptr<DHTNode>& getTargetNode() const const std::shared_ptr<DHTNode>& getTargetNode() const;
{ const std::string& getMessageType() const;
return targetNode_; const std::unique_ptr<DHTMessageCallback>& getCallback() const;
} std::unique_ptr<DHTMessageCallback> popCallback();
const std::string& getMessageType() const
{
return messageType_;
}
const std::shared_ptr<DHTMessageCallback>& getCallback() const
{
return callback_;
}
int64_t getElapsedMillis() const; int64_t getElapsedMillis() const;
}; };

View file

@ -65,6 +65,11 @@ bool DHTNode::operator==(const DHTNode& node) const
return memcmp(id_, node.id_, DHT_ID_LENGTH) == 0; return memcmp(id_, node.id_, DHT_ID_LENGTH) == 0;
} }
bool DHTNode::operator!=(const DHTNode& node) const
{
return !(*this == node);
}
bool DHTNode::operator<(const DHTNode& node) const bool DHTNode::operator<(const DHTNode& node) const
{ {
for(size_t i = 0; i < DHT_ID_LENGTH; ++i) { for(size_t i = 0; i < DHT_ID_LENGTH; ++i) {

View file

@ -115,6 +115,8 @@ public:
bool operator==(const DHTNode& node) const; bool operator==(const DHTNode& node) const;
bool operator!=(const DHTNode& node) const;
bool operator<(const DHTNode& node) const; bool operator<(const DHTNode& node) const;
std::string toString() const; std::string toString() const;

View file

@ -41,6 +41,7 @@
#include "util.h" #include "util.h"
#include "DHTNodeLookupTaskCallback.h" #include "DHTNodeLookupTaskCallback.h"
#include "DHTQueryMessage.h" #include "DHTQueryMessage.h"
#include "DHTFindNodeMessage.h"
namespace aria2 { namespace aria2 {
@ -53,21 +54,19 @@ DHTNodeLookupTask::getNodesFromMessage
(std::vector<std::shared_ptr<DHTNode> >& nodes, (std::vector<std::shared_ptr<DHTNode> >& nodes,
const DHTFindNodeReplyMessage* message) const DHTFindNodeReplyMessage* message)
{ {
const std::vector<std::shared_ptr<DHTNode> >& knodes = auto& knodes = message->getClosestKNodes();
message->getClosestKNodes(); nodes.insert(std::end(nodes), std::begin(knodes), std::end(knodes));
nodes.insert(nodes.end(), knodes.begin(), knodes.end());
} }
std::shared_ptr<DHTMessage> std::unique_ptr<DHTMessage>
DHTNodeLookupTask::createMessage(const std::shared_ptr<DHTNode>& remoteNode) DHTNodeLookupTask::createMessage(const std::shared_ptr<DHTNode>& remoteNode)
{ {
return getMessageFactory()->createFindNodeMessage(remoteNode, getTargetID()); return getMessageFactory()->createFindNodeMessage(remoteNode, getTargetID());
} }
std::shared_ptr<DHTMessageCallback> DHTNodeLookupTask::createCallback() std::unique_ptr<DHTMessageCallback> DHTNodeLookupTask::createCallback()
{ {
return std::shared_ptr<DHTNodeLookupTaskCallback> return make_unique<DHTNodeLookupTaskCallback>(this);
(new DHTNodeLookupTaskCallback(this));
} }
} // namespace aria2 } // namespace aria2

View file

@ -50,10 +50,10 @@ public:
(std::vector<std::shared_ptr<DHTNode> >& nodes, (std::vector<std::shared_ptr<DHTNode> >& nodes,
const DHTFindNodeReplyMessage* message); const DHTFindNodeReplyMessage* message);
virtual std::shared_ptr<DHTMessage> createMessage virtual std::unique_ptr<DHTMessage> createMessage
(const std::shared_ptr<DHTNode>& remoteNode); (const std::shared_ptr<DHTNode>& remoteNode);
virtual std::shared_ptr<DHTMessageCallback> createCallback(); virtual std::unique_ptr<DHTMessageCallback> createCallback();
}; };
} // namespace aria2 } // namespace aria2

View file

@ -48,6 +48,9 @@
#include "bittorrent_helper.h" #include "bittorrent_helper.h"
#include "DHTPeerLookupTaskCallback.h" #include "DHTPeerLookupTaskCallback.h"
#include "DHTQueryMessage.h" #include "DHTQueryMessage.h"
#include "DHTGetPeersMessage.h"
#include "DHTAnnouncePeerMessage.h"
#include "fmt.h" #include "fmt.h"
namespace aria2 { namespace aria2 {
@ -62,12 +65,11 @@ DHTPeerLookupTask::DHTPeerLookupTask
void void
DHTPeerLookupTask::getNodesFromMessage DHTPeerLookupTask::getNodesFromMessage
(std::vector<std::shared_ptr<DHTNode> >& nodes, (std::vector<std::shared_ptr<DHTNode>>& nodes,
const DHTGetPeersReplyMessage* message) const DHTGetPeersReplyMessage* message)
{ {
const std::vector<std::shared_ptr<DHTNode> >& knodes = auto& knodes = message->getClosestKNodes();
message->getClosestKNodes(); nodes.insert(std::end(nodes), std::begin(knodes), std::end(knodes));
nodes.insert(nodes.end(), knodes.begin(), knodes.end());
} }
void DHTPeerLookupTask::onReceivedInternal void DHTPeerLookupTask::onReceivedInternal
@ -81,16 +83,15 @@ void DHTPeerLookupTask::onReceivedInternal
static_cast<unsigned long>(message->getValues().size()))); static_cast<unsigned long>(message->getValues().size())));
} }
std::shared_ptr<DHTMessage> DHTPeerLookupTask::createMessage std::unique_ptr<DHTMessage> DHTPeerLookupTask::createMessage
(const std::shared_ptr<DHTNode>& remoteNode) (const std::shared_ptr<DHTNode>& remoteNode)
{ {
return getMessageFactory()->createGetPeersMessage(remoteNode, getTargetID()); return getMessageFactory()->createGetPeersMessage(remoteNode, getTargetID());
} }
std::shared_ptr<DHTMessageCallback> DHTPeerLookupTask::createCallback() std::unique_ptr<DHTMessageCallback> DHTPeerLookupTask::createCallback()
{ {
return std::shared_ptr<DHTPeerLookupTaskCallback> return make_unique<DHTPeerLookupTaskCallback>(this);
(new DHTPeerLookupTaskCallback(this));
} }
void DHTPeerLookupTask::onFinish() void DHTPeerLookupTask::onFinish()
@ -99,26 +100,24 @@ void DHTPeerLookupTask::onFinish()
util::toHex(getTargetID(), DHT_ID_LENGTH).c_str())); util::toHex(getTargetID(), DHT_ID_LENGTH).c_str()));
// send announce_peer message to K closest nodes // send announce_peer message to K closest nodes
size_t num = DHTBucket::K; size_t num = DHTBucket::K;
for(std::deque<std::shared_ptr<DHTNodeLookupEntry> >::const_iterator i = for(auto i = std::begin(getEntries()), eoi = std::end(getEntries());
getEntries().begin(), eoi = getEntries().end();
i != eoi && num > 0; ++i) { i != eoi && num > 0; ++i) {
if(!(*i)->used) { if(!(*i)->used) {
continue; continue;
} }
const std::shared_ptr<DHTNode>& node = (*i)->node; auto& node = (*i)->node;
std::string idHex = util::toHex(node->getID(), DHT_ID_LENGTH); std::string idHex = util::toHex(node->getID(), DHT_ID_LENGTH);
std::string token = tokenStorage_[idHex]; std::string token = tokenStorage_[idHex];
if(token.empty()) { if(token.empty()) {
A2_LOG_DEBUG(fmt("Token is empty for ID:%s", idHex.c_str())); A2_LOG_DEBUG(fmt("Token is empty for ID:%s", idHex.c_str()));
continue; continue;
} }
std::shared_ptr<DHTMessage> m = getMessageDispatcher()->addMessageToQueue
getMessageFactory()->createAnnouncePeerMessage (getMessageFactory()->createAnnouncePeerMessage
(node, (node,
getTargetID(), // this is infoHash getTargetID(), // this is infoHash
tcpPort_, tcpPort_,
token); token));
getMessageDispatcher()->addMessageToQueue(m);
--num; --num;
} }
} }

View file

@ -58,15 +58,15 @@ public:
uint16_t tcpPort); uint16_t tcpPort);
virtual void getNodesFromMessage virtual void getNodesFromMessage
(std::vector<std::shared_ptr<DHTNode> >& nodes, (std::vector<std::shared_ptr<DHTNode>>& nodes,
const DHTGetPeersReplyMessage* message); const DHTGetPeersReplyMessage* message);
virtual void onReceivedInternal(const DHTGetPeersReplyMessage* message); virtual void onReceivedInternal(const DHTGetPeersReplyMessage* message);
virtual std::shared_ptr<DHTMessage> createMessage virtual std::unique_ptr<DHTMessage> createMessage
(const std::shared_ptr<DHTNode>& remoteNode); (const std::shared_ptr<DHTNode>& remoteNode);
virtual std::shared_ptr<DHTMessageCallback> createCallback(); virtual std::unique_ptr<DHTMessageCallback> createCallback();
virtual void onFinish(); virtual void onFinish();

View file

@ -38,6 +38,7 @@
#include "DHTMessageDispatcher.h" #include "DHTMessageDispatcher.h"
#include "DHTMessageFactory.h" #include "DHTMessageFactory.h"
#include "DHTMessageCallback.h" #include "DHTMessageCallback.h"
#include "DHTPingReplyMessage.h"
namespace aria2 { namespace aria2 {
@ -48,20 +49,17 @@ DHTPingMessage::DHTPingMessage(const std::shared_ptr<DHTNode>& localNode,
const std::string& transactionID): const std::string& transactionID):
DHTQueryMessage(localNode, remoteNode, transactionID) {} DHTQueryMessage(localNode, remoteNode, transactionID) {}
DHTPingMessage::~DHTPingMessage() {}
void DHTPingMessage::doReceivedAction() void DHTPingMessage::doReceivedAction()
{ {
// send back ping reply // send back ping reply
std::shared_ptr<DHTMessage> reply = getMessageDispatcher()->addMessageToQueue
getMessageFactory()->createPingReplyMessage (getMessageFactory()->createPingReplyMessage
(getRemoteNode(), getLocalNode()->getID(), getTransactionID()); (getRemoteNode(), getLocalNode()->getID(), getTransactionID()));
getMessageDispatcher()->addMessageToQueue(reply);
} }
std::shared_ptr<Dict> DHTPingMessage::getArgument() std::shared_ptr<Dict> DHTPingMessage::getArgument()
{ {
std::shared_ptr<Dict> aDict = Dict::g(); auto aDict = Dict::g();
aDict->put(DHTMessage::ID, String::g(getLocalNode()->getID(), DHT_ID_LENGTH)); aDict->put(DHTMessage::ID, String::g(getLocalNode()->getID(), DHT_ID_LENGTH));
return aDict; return aDict;
} }

View file

@ -46,8 +46,6 @@ public:
const std::shared_ptr<DHTNode>& remoteNode, const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID = A2STR::NIL); const std::string& transactionID = A2STR::NIL);
virtual ~DHTPingMessage();
virtual void doReceivedAction(); virtual void doReceivedAction();
virtual std::shared_ptr<Dict> getArgument(); virtual std::shared_ptr<Dict> getArgument();

View file

@ -53,13 +53,11 @@ DHTPingReplyMessage::DHTPingReplyMessage
memcpy(id_, id, DHT_ID_LENGTH); memcpy(id_, id, DHT_ID_LENGTH);
} }
DHTPingReplyMessage::~DHTPingReplyMessage() {}
void DHTPingReplyMessage::doReceivedAction() {} void DHTPingReplyMessage::doReceivedAction() {}
std::shared_ptr<Dict> DHTPingReplyMessage::getResponse() std::shared_ptr<Dict> DHTPingReplyMessage::getResponse()
{ {
std::shared_ptr<Dict> rDict = Dict::g(); auto rDict = Dict::g();
rDict->put(DHTMessage::ID, String::g(id_, DHT_ID_LENGTH)); rDict->put(DHTMessage::ID, String::g(id_, DHT_ID_LENGTH));
return rDict; return rDict;
} }

View file

@ -49,8 +49,6 @@ public:
const unsigned char* id, const unsigned char* id,
const std::string& transactionID); const std::string& transactionID);
virtual ~DHTPingReplyMessage();
virtual void doReceivedAction(); virtual void doReceivedAction();
virtual std::shared_ptr<Dict> getResponse(); virtual std::shared_ptr<Dict> getResponse();

View file

@ -40,6 +40,7 @@
#include "DHTConstants.h" #include "DHTConstants.h"
#include "DHTPingReplyMessageCallback.h" #include "DHTPingReplyMessageCallback.h"
#include "DHTQueryMessage.h" #include "DHTQueryMessage.h"
#include "DHTPingMessage.h"
namespace aria2 { namespace aria2 {
@ -56,11 +57,10 @@ DHTPingTask::~DHTPingTask() {}
void DHTPingTask::addMessage() void DHTPingTask::addMessage()
{ {
std::shared_ptr<DHTMessage> m = getMessageDispatcher()->addMessageToQueue
getMessageFactory()->createPingMessage(remoteNode_); (getMessageFactory()->createPingMessage(remoteNode_),
std::shared_ptr<DHTMessageCallback> callback timeout_,
(new DHTPingReplyMessageCallback<DHTPingTask>(this)); make_unique<DHTPingReplyMessageCallback<DHTPingTask>>(this));
getMessageDispatcher()->addMessageToQueue(m, timeout_, callback);
} }
void DHTPingTask::startup() void DHTPingTask::startup()

View file

@ -42,6 +42,7 @@
#include "LogFactory.h" #include "LogFactory.h"
#include "DHTPingReplyMessageCallback.h" #include "DHTPingReplyMessageCallback.h"
#include "DHTQueryMessage.h" #include "DHTQueryMessage.h"
#include "DHTPingMessage.h"
#include "fmt.h" #include "fmt.h"
namespace aria2 { namespace aria2 {
@ -67,11 +68,10 @@ void DHTReplaceNodeTask::sendMessage()
if(!questionableNode) { if(!questionableNode) {
setFinished(true); setFinished(true);
} else { } else {
std::shared_ptr<DHTMessage> m = getMessageDispatcher()->addMessageToQueue
getMessageFactory()->createPingMessage(questionableNode); (getMessageFactory()->createPingMessage(questionableNode),
std::shared_ptr<DHTMessageCallback> callback timeout_,
(new DHTPingReplyMessageCallback<DHTReplaceNodeTask>(this)); make_unique<DHTPingReplyMessageCallback<DHTReplaceNodeTask>>(this));
getMessageDispatcher()->addMessageToQueue(m, timeout_, callback);
} }
} }

View file

@ -61,6 +61,8 @@
#include "DHTRegistry.h" #include "DHTRegistry.h"
#include "DHTBucketRefreshTask.h" #include "DHTBucketRefreshTask.h"
#include "DHTMessageCallback.h" #include "DHTMessageCallback.h"
#include "DHTMessageTrackerEntry.h"
#include "DHTMessageEntry.h"
#include "UDPTrackerClient.h" #include "UDPTrackerClient.h"
#include "BtRegistry.h" #include "BtRegistry.h"
#include "prefs.h" #include "prefs.h"
@ -137,27 +139,19 @@ std::vector<std::unique_ptr<Command>> DHTSetup::setup
util::toHex(localNode->getID(), DHT_ID_LENGTH).c_str())); util::toHex(localNode->getID(), DHT_ID_LENGTH).c_str()));
std::shared_ptr<DHTRoutingTable> routingTable(new DHTRoutingTable(localNode)); std::shared_ptr<DHTRoutingTable> routingTable(new DHTRoutingTable(localNode));
std::shared_ptr<DHTMessageFactoryImpl> factory auto factory = std::make_shared<DHTMessageFactoryImpl>(family);
(new DHTMessageFactoryImpl(family)); auto tracker = std::make_shared<DHTMessageTracker>();
auto dispatcher = std::make_shared<DHTMessageDispatcherImpl>(tracker);
std::shared_ptr<DHTMessageTracker> tracker(new DHTMessageTracker()); auto receiver = std::make_shared<DHTMessageReceiver>(tracker);
auto taskQueue = std::make_shared<DHTTaskQueueImpl>();
std::shared_ptr<DHTMessageDispatcherImpl> dispatcher(new DHTMessageDispatcherImpl(tracker)); auto taskFactory = std::make_shared<DHTTaskFactoryImpl>();
auto peerAnnounceStorage = std::make_shared<DHTPeerAnnounceStorage>();
std::shared_ptr<DHTMessageReceiver> receiver(new DHTMessageReceiver(tracker)); auto tokenTracker = std::make_shared<DHTTokenTracker>();
const time_t messageTimeout =
std::shared_ptr<DHTTaskQueue> taskQueue(new DHTTaskQueueImpl()); e->getOption()->getAsInt(PREF_DHT_MESSAGE_TIMEOUT);
std::shared_ptr<DHTTaskFactoryImpl> taskFactory(new DHTTaskFactoryImpl());
std::shared_ptr<DHTPeerAnnounceStorage> peerAnnounceStorage(new DHTPeerAnnounceStorage());
std::shared_ptr<DHTTokenTracker> tokenTracker(new DHTTokenTracker());
const time_t messageTimeout = e->getOption()->getAsInt(PREF_DHT_MESSAGE_TIMEOUT);
// wiring up // wiring up
tracker->setRoutingTable(routingTable); tracker->setRoutingTable(routingTable);
tracker->setMessageFactory(factory); tracker->setMessageFactory(factory.get());
dispatcher->setTimeout(messageTimeout); dispatcher->setTimeout(messageTimeout);
@ -186,7 +180,7 @@ std::vector<std::unique_ptr<Command>> DHTSetup::setup
factory->setLocalNode(localNode); factory->setLocalNode(localNode);
// For now, UDPTrackerClient was enabled along with DHT // For now, UDPTrackerClient was enabled along with DHT
std::shared_ptr<UDPTrackerClient> udpTrackerClient(new UDPTrackerClient()); auto udpTrackerClient = std::make_shared<UDPTrackerClient>();
// assign them into DHTRegistry // assign them into DHTRegistry
if(family == AF_INET) { if(family == AF_INET) {
DHTRegistry::getMutableData().localNode = localNode; DHTRegistry::getMutableData().localNode = localNode;
@ -211,11 +205,9 @@ std::vector<std::unique_ptr<Command>> DHTSetup::setup
DHTRegistry::getMutableData6().messageFactory = factory; DHTRegistry::getMutableData6().messageFactory = factory;
} }
// add deserialized nodes to routing table // add deserialized nodes to routing table
const std::vector<std::shared_ptr<DHTNode> >& desnodes = auto& desnodes = deserializer.getNodes();
deserializer.getNodes(); for(auto& node : desnodes) {
for(std::vector<std::shared_ptr<DHTNode> >::const_iterator i = routingTable->addNode(node);
desnodes.begin(), eoi = desnodes.end(); i != eoi; ++i) {
routingTable->addNode(*i);
} }
if(!desnodes.empty()) { if(!desnodes.empty()) {
auto task = std::static_pointer_cast<DHTBucketRefreshTask> auto task = std::static_pointer_cast<DHTBucketRefreshTask>
@ -234,7 +226,7 @@ std::vector<std::unique_ptr<Command>> DHTSetup::setup
std::pair<std::string, uint16_t> addr std::pair<std::string, uint16_t> addr
(e->getOption()->get(prefEntryPointHost), (e->getOption()->get(prefEntryPointHost),
e->getOption()->getAsInt(prefEntryPointPort)); e->getOption()->getAsInt(prefEntryPointPort));
std::vector<std::pair<std::string, uint16_t> > entryPoints; std::vector<std::pair<std::string, uint16_t>> entryPoints;
entryPoints.push_back(addr); entryPoints.push_back(addr);
auto command = make_unique<DHTEntryPointNameResolveCommand> auto command = make_unique<DHTEntryPointNameResolveCommand>
(e->newCUID(), e, entryPoints); (e->newCUID(), e, entryPoints);
@ -302,7 +294,7 @@ std::vector<std::unique_ptr<Command>> DHTSetup::setup
if(family == AF_INET) { if(family == AF_INET) {
DHTRegistry::clearData(); DHTRegistry::clearData();
e->getBtRegistry()->setUDPTrackerClient e->getBtRegistry()->setUDPTrackerClient
(std::shared_ptr<UDPTrackerClient>()); (std::shared_ptr<UDPTrackerClient>{});
} else { } else {
DHTRegistry::clearData6(); DHTRegistry::clearData6();
} }

View file

@ -20,7 +20,14 @@ class DHTAnnouncePeerMessageTest:public CppUnit::TestFixture {
CPPUNIT_TEST(testDoReceivedAction); CPPUNIT_TEST(testDoReceivedAction);
CPPUNIT_TEST_SUITE_END(); CPPUNIT_TEST_SUITE_END();
public: public:
void setUp() {} std::shared_ptr<DHTNode> localNode_;
std::shared_ptr<DHTNode> remoteNode_;
void setUp()
{
localNode_ = std::make_shared<DHTNode>();
remoteNode_ = std::make_shared<DHTNode>();
}
void tearDown() {} void tearDown() {}
@ -28,13 +35,12 @@ public:
void testDoReceivedAction(); void testDoReceivedAction();
class MockDHTMessageFactory2:public MockDHTMessageFactory { class MockDHTMessageFactory2:public MockDHTMessageFactory {
virtual std::shared_ptr<DHTResponseMessage> virtual std::unique_ptr<DHTAnnouncePeerReplyMessage>
createAnnouncePeerReplyMessage(const std::shared_ptr<DHTNode>& remoteNode, createAnnouncePeerReplyMessage(const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID) const std::string& transactionID) override
{ {
return std::shared_ptr<DHTResponseMessage> return make_unique<DHTAnnouncePeerReplyMessage>(localNode_, remoteNode,
(new MockDHTResponseMessage transactionID);
(localNode_, remoteNode, "announce_peer", transactionID));
} }
}; };
}; };
@ -44,9 +50,6 @@ CPPUNIT_TEST_SUITE_REGISTRATION(DHTAnnouncePeerMessageTest);
void DHTAnnouncePeerMessageTest::testGetBencodedMessage() void DHTAnnouncePeerMessageTest::testGetBencodedMessage()
{ {
std::shared_ptr<DHTNode> localNode(new DHTNode());
std::shared_ptr<DHTNode> remoteNode(new DHTNode());
unsigned char tid[DHT_TRANSACTION_ID_LENGTH]; unsigned char tid[DHT_TRANSACTION_ID_LENGTH];
util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH); util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH);
std::string transactionID(&tid[0], &tid[DHT_TRANSACTION_ID_LENGTH]); std::string transactionID(&tid[0], &tid[DHT_TRANSACTION_ID_LENGTH]);
@ -57,7 +60,8 @@ void DHTAnnouncePeerMessageTest::testGetBencodedMessage()
std::string token = "token"; std::string token = "token";
uint16_t port = 6881; uint16_t port = 6881;
DHTAnnouncePeerMessage msg(localNode, remoteNode, infoHash, port, token, transactionID); DHTAnnouncePeerMessage msg(localNode_, remoteNode_, infoHash, port, token,
transactionID);
msg.setVersion("A200"); msg.setVersion("A200");
std::string msgbody = msg.getBencodedMessage(); std::string msgbody = msg.getBencodedMessage();
@ -66,8 +70,8 @@ void DHTAnnouncePeerMessageTest::testGetBencodedMessage()
dict.put("v", "A200"); dict.put("v", "A200");
dict.put("y", "q"); dict.put("y", "q");
dict.put("q", "announce_peer"); dict.put("q", "announce_peer");
std::shared_ptr<Dict> aDict = Dict::g(); auto aDict = Dict::g();
aDict->put("id", String::g(localNode->getID(), DHT_ID_LENGTH)); aDict->put("id", String::g(localNode_->getID(), DHT_ID_LENGTH));
aDict->put("info_hash", String::g(infoHash, DHT_ID_LENGTH)); aDict->put("info_hash", String::g(infoHash, DHT_ID_LENGTH));
aDict->put("port", Integer::g(port)); aDict->put("port", Integer::g(port));
aDict->put("token", token); aDict->put("token", token);
@ -79,10 +83,8 @@ void DHTAnnouncePeerMessageTest::testGetBencodedMessage()
void DHTAnnouncePeerMessageTest::testDoReceivedAction() void DHTAnnouncePeerMessageTest::testDoReceivedAction()
{ {
std::shared_ptr<DHTNode> localNode(new DHTNode()); remoteNode_->setIPAddress("192.168.0.1");
std::shared_ptr<DHTNode> remoteNode(new DHTNode()); remoteNode_->setPort(6881);
remoteNode->setIPAddress("192.168.0.1");
remoteNode->setPort(6881);
unsigned char tid[DHT_TRANSACTION_ID_LENGTH]; unsigned char tid[DHT_TRANSACTION_ID_LENGTH];
util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH); util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH);
@ -96,10 +98,11 @@ void DHTAnnouncePeerMessageTest::testDoReceivedAction()
DHTPeerAnnounceStorage peerAnnounceStorage; DHTPeerAnnounceStorage peerAnnounceStorage;
MockDHTMessageFactory2 factory; MockDHTMessageFactory2 factory;
factory.setLocalNode(localNode); factory.setLocalNode(localNode_);
MockDHTMessageDispatcher dispatcher; MockDHTMessageDispatcher dispatcher;
DHTAnnouncePeerMessage msg(localNode, remoteNode, infoHash, port, token, transactionID); DHTAnnouncePeerMessage msg(localNode_, remoteNode_, infoHash, port, token,
transactionID);
msg.setPeerAnnounceStorage(&peerAnnounceStorage); msg.setPeerAnnounceStorage(&peerAnnounceStorage);
msg.setMessageFactory(&factory); msg.setMessageFactory(&factory);
msg.setMessageDispatcher(&dispatcher); msg.setMessageDispatcher(&dispatcher);
@ -107,10 +110,10 @@ void DHTAnnouncePeerMessageTest::testDoReceivedAction()
msg.doReceivedAction(); msg.doReceivedAction();
CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher.messageQueue_.size()); CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher.messageQueue_.size());
auto m = std::dynamic_pointer_cast<MockDHTResponseMessage> auto m = dynamic_cast<DHTAnnouncePeerReplyMessage*>
(dispatcher.messageQueue_[0].message_); (dispatcher.messageQueue_[0].message_.get());
CPPUNIT_ASSERT(*localNode == *m->getLocalNode()); CPPUNIT_ASSERT(*localNode_ == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode()); CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(std::string("announce_peer"), m->getMessageType()); CPPUNIT_ASSERT_EQUAL(std::string("announce_peer"), m->getMessageType());
CPPUNIT_ASSERT_EQUAL(transactionID, m->getTransactionID()); CPPUNIT_ASSERT_EQUAL(transactionID, m->getTransactionID());
std::vector<std::shared_ptr<Peer> > peers; std::vector<std::shared_ptr<Peer> > peers;

View file

@ -20,7 +20,14 @@ class DHTFindNodeMessageTest:public CppUnit::TestFixture {
CPPUNIT_TEST(testDoReceivedAction); CPPUNIT_TEST(testDoReceivedAction);
CPPUNIT_TEST_SUITE_END(); CPPUNIT_TEST_SUITE_END();
public: public:
void setUp() {} std::shared_ptr<DHTNode> localNode_;
std::shared_ptr<DHTNode> remoteNode_;
void setUp()
{
localNode_ = std::make_shared<DHTNode>();
remoteNode_ = std::make_shared<DHTNode>();
}
void tearDown() {} void tearDown() {}
@ -29,16 +36,15 @@ public:
class MockDHTMessageFactory2:public MockDHTMessageFactory { class MockDHTMessageFactory2:public MockDHTMessageFactory {
public: public:
virtual std::shared_ptr<DHTResponseMessage> virtual std::unique_ptr<DHTFindNodeReplyMessage>
createFindNodeReplyMessage createFindNodeReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode, (const std::shared_ptr<DHTNode>& remoteNode,
const std::vector<std::shared_ptr<DHTNode> >& closestKNodes, std::vector<std::shared_ptr<DHTNode>> closestKNodes,
const std::string& transactionID) const std::string& transactionID) override
{ {
std::shared_ptr<MockDHTResponseMessage> m auto m = make_unique<DHTFindNodeReplyMessage>
(new MockDHTResponseMessage (AF_INET, localNode_, remoteNode, transactionID);
(localNode_, remoteNode, "find_node", transactionID)); m->setClosestKNodes(std::move(closestKNodes));
m->nodes_ = closestKNodes;
return m; return m;
} }
}; };
@ -49,16 +55,14 @@ CPPUNIT_TEST_SUITE_REGISTRATION(DHTFindNodeMessageTest);
void DHTFindNodeMessageTest::testGetBencodedMessage() void DHTFindNodeMessageTest::testGetBencodedMessage()
{ {
std::shared_ptr<DHTNode> localNode(new DHTNode());
std::shared_ptr<DHTNode> remoteNode(new DHTNode());
unsigned char tid[DHT_TRANSACTION_ID_LENGTH]; unsigned char tid[DHT_TRANSACTION_ID_LENGTH];
util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH); util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH);
std::string transactionID(&tid[0], &tid[DHT_TRANSACTION_ID_LENGTH]); std::string transactionID(&tid[0], &tid[DHT_TRANSACTION_ID_LENGTH]);
std::shared_ptr<DHTNode> targetNode(new DHTNode()); auto targetNode = std::make_shared<DHTNode>();
DHTFindNodeMessage msg(localNode, remoteNode, targetNode->getID(), transactionID); DHTFindNodeMessage msg(localNode_, remoteNode_, targetNode->getID(),
transactionID);
msg.setVersion("A200"); msg.setVersion("A200");
std::string msgbody = msg.getBencodedMessage(); std::string msgbody = msg.getBencodedMessage();
@ -67,8 +71,8 @@ void DHTFindNodeMessageTest::testGetBencodedMessage()
dict.put("v", "A200"); dict.put("v", "A200");
dict.put("y", "q"); dict.put("y", "q");
dict.put("q", "find_node"); dict.put("q", "find_node");
std::shared_ptr<Dict> aDict = Dict::g(); auto aDict = Dict::g();
aDict->put("id", String::g(localNode->getID(), DHT_ID_LENGTH)); aDict->put("id", String::g(localNode_->getID(), DHT_ID_LENGTH));
aDict->put("target", String::g(targetNode->getID(), DHT_ID_LENGTH)); aDict->put("target", String::g(targetNode->getID(), DHT_ID_LENGTH));
dict.put("a", aDict); dict.put("a", aDict);
@ -77,22 +81,20 @@ void DHTFindNodeMessageTest::testGetBencodedMessage()
void DHTFindNodeMessageTest::testDoReceivedAction() void DHTFindNodeMessageTest::testDoReceivedAction()
{ {
std::shared_ptr<DHTNode> localNode(new DHTNode());
std::shared_ptr<DHTNode> remoteNode(new DHTNode());
unsigned char tid[DHT_TRANSACTION_ID_LENGTH]; unsigned char tid[DHT_TRANSACTION_ID_LENGTH];
util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH); util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH);
std::string transactionID(&tid[0], &tid[DHT_TRANSACTION_ID_LENGTH]); std::string transactionID(&tid[0], &tid[DHT_TRANSACTION_ID_LENGTH]);
std::shared_ptr<DHTNode> targetNode(new DHTNode()); auto targetNode = std::make_shared<DHTNode>();
MockDHTMessageDispatcher dispatcher; MockDHTMessageDispatcher dispatcher;
MockDHTMessageFactory2 factory; MockDHTMessageFactory2 factory;
factory.setLocalNode(localNode); factory.setLocalNode(localNode_);
DHTRoutingTable routingTable(localNode); DHTRoutingTable routingTable(localNode_);
routingTable.addNode(targetNode); routingTable.addNode(targetNode);
DHTFindNodeMessage msg(localNode, remoteNode, targetNode->getID(), transactionID); DHTFindNodeMessage msg(localNode_, remoteNode_, targetNode->getID(),
transactionID);
msg.setMessageDispatcher(&dispatcher); msg.setMessageDispatcher(&dispatcher);
msg.setMessageFactory(&factory); msg.setMessageFactory(&factory);
msg.setRoutingTable(&routingTable); msg.setRoutingTable(&routingTable);
@ -100,13 +102,13 @@ void DHTFindNodeMessageTest::testDoReceivedAction()
msg.doReceivedAction(); msg.doReceivedAction();
CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher.messageQueue_.size()); CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher.messageQueue_.size());
auto m = std::dynamic_pointer_cast<MockDHTResponseMessage> auto m = dynamic_cast<DHTFindNodeReplyMessage*>
(dispatcher.messageQueue_[0].message_); (dispatcher.messageQueue_[0].message_.get());
CPPUNIT_ASSERT(*localNode == *m->getLocalNode()); CPPUNIT_ASSERT(*localNode_ == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode()); CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(std::string("find_node"), m->getMessageType()); CPPUNIT_ASSERT_EQUAL(std::string("find_node"), m->getMessageType());
CPPUNIT_ASSERT_EQUAL(msg.getTransactionID(), m->getTransactionID()); CPPUNIT_ASSERT_EQUAL(msg.getTransactionID(), m->getTransactionID());
CPPUNIT_ASSERT_EQUAL((size_t)1, m->nodes_.size()); CPPUNIT_ASSERT_EQUAL((size_t)1, m->getClosestKNodes().size());
} }
} // namespace aria2 } // namespace aria2

View file

@ -22,7 +22,14 @@ class DHTGetPeersMessageTest:public CppUnit::TestFixture {
CPPUNIT_TEST(testDoReceivedAction); CPPUNIT_TEST(testDoReceivedAction);
CPPUNIT_TEST_SUITE_END(); CPPUNIT_TEST_SUITE_END();
public: public:
void setUp() {} std::shared_ptr<DHTNode> localNode_;
std::shared_ptr<DHTNode> remoteNode_;
void setUp()
{
localNode_ = std::make_shared<DHTNode>();
remoteNode_ = std::make_shared<DHTNode>();
}
void tearDown() {} void tearDown() {}
@ -31,20 +38,19 @@ public:
class MockDHTMessageFactory2:public MockDHTMessageFactory { class MockDHTMessageFactory2:public MockDHTMessageFactory {
public: public:
virtual std::shared_ptr<DHTResponseMessage> virtual std::unique_ptr<DHTGetPeersReplyMessage>
createGetPeersReplyMessage createGetPeersReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode, (const std::shared_ptr<DHTNode>& remoteNode,
const std::vector<std::shared_ptr<DHTNode> >& closestKNodes, std::vector<std::shared_ptr<DHTNode>> closestKNodes,
const std::vector<std::shared_ptr<Peer> >& peers, std::vector<std::shared_ptr<Peer>> peers,
const std::string& token, const std::string& token,
const std::string& transactionID) const std::string& transactionID) override
{ {
std::shared_ptr<MockDHTResponseMessage> m auto m = make_unique<DHTGetPeersReplyMessage>(AF_INET, localNode_,
(new MockDHTResponseMessage remoteNode, token,
(localNode_, remoteNode, "get_peers", transactionID)); transactionID);
m->nodes_ = closestKNodes; m->setClosestKNodes(closestKNodes);
m->peers_ = peers; m->setValues(peers);
m->token_ = token;
return m; return m;
} }
}; };
@ -55,9 +61,6 @@ CPPUNIT_TEST_SUITE_REGISTRATION(DHTGetPeersMessageTest);
void DHTGetPeersMessageTest::testGetBencodedMessage() void DHTGetPeersMessageTest::testGetBencodedMessage()
{ {
std::shared_ptr<DHTNode> localNode(new DHTNode());
std::shared_ptr<DHTNode> remoteNode(new DHTNode());
unsigned char tid[DHT_TRANSACTION_ID_LENGTH]; unsigned char tid[DHT_TRANSACTION_ID_LENGTH];
util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH); util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH);
std::string transactionID(&tid[0], &tid[DHT_TRANSACTION_ID_LENGTH]); std::string transactionID(&tid[0], &tid[DHT_TRANSACTION_ID_LENGTH]);
@ -65,7 +68,7 @@ void DHTGetPeersMessageTest::testGetBencodedMessage()
unsigned char infoHash[DHT_ID_LENGTH]; unsigned char infoHash[DHT_ID_LENGTH];
util::generateRandomData(infoHash, DHT_ID_LENGTH); util::generateRandomData(infoHash, DHT_ID_LENGTH);
DHTGetPeersMessage msg(localNode, remoteNode, infoHash, transactionID); DHTGetPeersMessage msg(localNode_, remoteNode_, infoHash, transactionID);
msg.setVersion("A200"); msg.setVersion("A200");
std::string msgbody = msg.getBencodedMessage(); std::string msgbody = msg.getBencodedMessage();
@ -75,8 +78,8 @@ void DHTGetPeersMessageTest::testGetBencodedMessage()
dict.put("v", "A200"); dict.put("v", "A200");
dict.put("y", "q"); dict.put("y", "q");
dict.put("q", "get_peers"); dict.put("q", "get_peers");
std::shared_ptr<Dict> aDict = Dict::g(); auto aDict = Dict::g();
aDict->put("id", String::g(localNode->getID(), DHT_ID_LENGTH)); aDict->put("id", String::g(localNode_->getID(), DHT_ID_LENGTH));
aDict->put("info_hash", String::g(infoHash, DHT_ID_LENGTH)); aDict->put("info_hash", String::g(infoHash, DHT_ID_LENGTH));
dict.put("a", aDict); dict.put("a", aDict);
@ -86,10 +89,8 @@ void DHTGetPeersMessageTest::testGetBencodedMessage()
void DHTGetPeersMessageTest::testDoReceivedAction() void DHTGetPeersMessageTest::testDoReceivedAction()
{ {
std::shared_ptr<DHTNode> localNode(new DHTNode()); remoteNode_->setIPAddress("192.168.0.1");
std::shared_ptr<DHTNode> remoteNode(new DHTNode()); remoteNode_->setPort(6881);
remoteNode->setIPAddress("192.168.0.1");
remoteNode->setPort(6881);
unsigned char tid[DHT_TRANSACTION_ID_LENGTH]; unsigned char tid[DHT_TRANSACTION_ID_LENGTH];
util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH); util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH);
@ -101,10 +102,10 @@ void DHTGetPeersMessageTest::testDoReceivedAction()
DHTTokenTracker tokenTracker; DHTTokenTracker tokenTracker;
MockDHTMessageDispatcher dispatcher; MockDHTMessageDispatcher dispatcher;
MockDHTMessageFactory2 factory; MockDHTMessageFactory2 factory;
factory.setLocalNode(localNode); factory.setLocalNode(localNode_);
DHTRoutingTable routingTable(localNode); DHTRoutingTable routingTable(localNode_);
DHTGetPeersMessage msg(localNode, remoteNode, infoHash, transactionID); DHTGetPeersMessage msg(localNode_, remoteNode_, infoHash, transactionID);
msg.setRoutingTable(&routingTable); msg.setRoutingTable(&routingTable);
msg.setTokenTracker(&tokenTracker); msg.setTokenTracker(&tokenTracker);
msg.setMessageDispatcher(&dispatcher); msg.setMessageDispatcher(&dispatcher);
@ -120,22 +121,25 @@ void DHTGetPeersMessageTest::testDoReceivedAction()
msg.doReceivedAction(); msg.doReceivedAction();
CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher.messageQueue_.size()); CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher.messageQueue_.size());
auto m = std::dynamic_pointer_cast<MockDHTResponseMessage> auto m = dynamic_cast<DHTGetPeersReplyMessage*>
(dispatcher.messageQueue_[0].message_); (dispatcher.messageQueue_[0].message_.get());
CPPUNIT_ASSERT(*localNode == *m->getLocalNode()); CPPUNIT_ASSERT(*localNode_ == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode()); CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(std::string("get_peers"), m->getMessageType()); CPPUNIT_ASSERT_EQUAL(std::string("get_peers"), m->getMessageType());
CPPUNIT_ASSERT_EQUAL(msg.getTransactionID(), m->getTransactionID()); CPPUNIT_ASSERT_EQUAL(msg.getTransactionID(), m->getTransactionID());
CPPUNIT_ASSERT_EQUAL(tokenTracker.generateToken(infoHash, remoteNode->getIPAddress(), remoteNode->getPort()), m->token_); CPPUNIT_ASSERT_EQUAL(tokenTracker.generateToken
CPPUNIT_ASSERT_EQUAL((size_t)0, m->nodes_.size()); (infoHash, remoteNode_->getIPAddress(),
CPPUNIT_ASSERT_EQUAL((size_t)2, m->peers_.size()); remoteNode_->getPort()),
m->getToken());
CPPUNIT_ASSERT_EQUAL((size_t)0, m->getClosestKNodes().size());
CPPUNIT_ASSERT_EQUAL((size_t)2, m->getValues().size());
{ {
std::shared_ptr<Peer> peer = m->peers_[0]; auto peer = m->getValues()[0];
CPPUNIT_ASSERT_EQUAL(std::string("192.168.0.100"), peer->getIPAddress()); CPPUNIT_ASSERT_EQUAL(std::string("192.168.0.100"), peer->getIPAddress());
CPPUNIT_ASSERT_EQUAL((uint16_t)6888, peer->getPort()); CPPUNIT_ASSERT_EQUAL((uint16_t)6888, peer->getPort());
} }
{ {
std::shared_ptr<Peer> peer = m->peers_[1]; auto peer = m->getValues()[1];
CPPUNIT_ASSERT_EQUAL(std::string("192.168.0.101"), peer->getIPAddress()); CPPUNIT_ASSERT_EQUAL(std::string("192.168.0.101"), peer->getIPAddress());
CPPUNIT_ASSERT_EQUAL((uint16_t)6889, peer->getPort()); CPPUNIT_ASSERT_EQUAL((uint16_t)6889, peer->getPort());
} }
@ -144,7 +148,7 @@ void DHTGetPeersMessageTest::testDoReceivedAction()
{ {
// localhost doesn't have peer contact information for that infohash. // localhost doesn't have peer contact information for that infohash.
DHTPeerAnnounceStorage peerAnnounceStorage; DHTPeerAnnounceStorage peerAnnounceStorage;
DHTRoutingTable routingTable(localNode); DHTRoutingTable routingTable(localNode_);
std::shared_ptr<DHTNode> returnNode1(new DHTNode()); std::shared_ptr<DHTNode> returnNode1(new DHTNode());
routingTable.addNode(returnNode1); routingTable.addNode(returnNode1);
@ -154,16 +158,19 @@ void DHTGetPeersMessageTest::testDoReceivedAction()
msg.doReceivedAction(); msg.doReceivedAction();
CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher.messageQueue_.size()); CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher.messageQueue_.size());
auto m = std::dynamic_pointer_cast<MockDHTResponseMessage> auto m = dynamic_cast<DHTGetPeersReplyMessage*>
(dispatcher.messageQueue_[0].message_); (dispatcher.messageQueue_[0].message_.get());
CPPUNIT_ASSERT(*localNode == *m->getLocalNode()); CPPUNIT_ASSERT(*localNode_ == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode()); CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(std::string("get_peers"), m->getMessageType()); CPPUNIT_ASSERT_EQUAL(std::string("get_peers"), m->getMessageType());
CPPUNIT_ASSERT_EQUAL(msg.getTransactionID(), m->getTransactionID()); CPPUNIT_ASSERT_EQUAL(msg.getTransactionID(), m->getTransactionID());
CPPUNIT_ASSERT_EQUAL(tokenTracker.generateToken(infoHash, remoteNode->getIPAddress(), remoteNode->getPort()), m->token_); CPPUNIT_ASSERT_EQUAL(tokenTracker.generateToken
CPPUNIT_ASSERT_EQUAL((size_t)1, m->nodes_.size()); (infoHash, remoteNode_->getIPAddress(),
CPPUNIT_ASSERT(*returnNode1 == *m->nodes_[0]); remoteNode_->getPort()),
CPPUNIT_ASSERT_EQUAL((size_t)0, m->peers_.size()); m->getToken());
CPPUNIT_ASSERT_EQUAL((size_t)1, m->getClosestKNodes().size());
CPPUNIT_ASSERT(*returnNode1 == *m->getClosestKNodes()[0]);
CPPUNIT_ASSERT_EQUAL((size_t)0, m->getValues().size());
} }
} }

View file

@ -1,12 +1,15 @@
#include "DHTNode.h"
#include "DHTNodeLookupEntry.h"
#include "DHTIDCloser.h" #include "DHTIDCloser.h"
#include "Exception.h"
#include "util.h"
#include <cstring> #include <cstring>
#include <algorithm> #include <algorithm>
#include <cppunit/extensions/HelperMacros.h> #include <cppunit/extensions/HelperMacros.h>
#include "DHTNode.h"
#include "DHTNodeLookupEntry.h"
#include "Exception.h"
#include "util.h"
namespace aria2 { namespace aria2 {
class DHTIDCloserTest:public CppUnit::TestFixture { class DHTIDCloserTest:public CppUnit::TestFixture {
@ -30,39 +33,40 @@ void DHTIDCloserTest::testOperator()
unsigned char id[DHT_ID_LENGTH]; unsigned char id[DHT_ID_LENGTH];
memset(id, 0xf0, DHT_ID_LENGTH); memset(id, 0xf0, DHT_ID_LENGTH);
std::shared_ptr<DHTNodeLookupEntry> e1 auto e1 = make_unique<DHTNodeLookupEntry>(std::make_shared<DHTNode>(id));
(new DHTNodeLookupEntry(std::shared_ptr<DHTNode>(new DHTNode(id)))); auto ep1 = e1.get();
id[0] = 0xb0; id[0] = 0xb0;
std::shared_ptr<DHTNodeLookupEntry> e2 auto e2 = make_unique<DHTNodeLookupEntry>(std::make_shared<DHTNode>(id));
(new DHTNodeLookupEntry(std::shared_ptr<DHTNode>(new DHTNode(id)))); auto ep2 = e2.get();
id[0] = 0xa0; id[0] = 0xa0;
std::shared_ptr<DHTNodeLookupEntry> e3 auto e3 = make_unique<DHTNodeLookupEntry>(std::make_shared<DHTNode>(id));
(new DHTNodeLookupEntry(std::shared_ptr<DHTNode>(new DHTNode(id)))); auto ep3 = e3.get();
id[0] = 0x80; id[0] = 0x80;
std::shared_ptr<DHTNodeLookupEntry> e4 auto e4 = make_unique<DHTNodeLookupEntry>(std::make_shared<DHTNode>(id));
(new DHTNodeLookupEntry(std::shared_ptr<DHTNode>(new DHTNode(id)))); auto ep4 = e4.get();
id[0] = 0x00; id[0] = 0x00;
std::shared_ptr<DHTNodeLookupEntry> e5 auto e5 = make_unique<DHTNodeLookupEntry>(std::make_shared<DHTNode>(id));
(new DHTNodeLookupEntry(std::shared_ptr<DHTNode>(new DHTNode(id)))); auto ep5 = e5.get();
std::deque<std::shared_ptr<DHTNodeLookupEntry> > entries; auto entries = std::vector<std::unique_ptr<DHTNodeLookupEntry>>{};
entries.push_back(e1); entries.push_back(std::move(e1));
entries.push_back(e2); entries.push_back(std::move(e2));
entries.push_back(e3); entries.push_back(std::move(e3));
entries.push_back(e4); entries.push_back(std::move(e4));
entries.push_back(e5); entries.push_back(std::move(e5));
std::sort(entries.begin(), entries.end(), DHTIDCloser(e3->node->getID())); std::sort(std::begin(entries), std::end(entries),
DHTIDCloser(ep3->node->getID()));
CPPUNIT_ASSERT(*e3 == *entries[0]); CPPUNIT_ASSERT(*ep3 == *entries[0]);
CPPUNIT_ASSERT(*e2 == *entries[1]); CPPUNIT_ASSERT(*ep2 == *entries[1]);
CPPUNIT_ASSERT(*e4 == *entries[2]); CPPUNIT_ASSERT(*ep4 == *entries[2]);
CPPUNIT_ASSERT(*e1 == *entries[3]); CPPUNIT_ASSERT(*ep1 == *entries[3]);
CPPUNIT_ASSERT(*e5 == *entries[4]); CPPUNIT_ASSERT(*ep5 == *entries[4]);
} }
} // namespace aria2 } // namespace aria2

View file

@ -40,25 +40,36 @@ class DHTMessageFactoryImplTest:public CppUnit::TestFixture {
CPPUNIT_TEST(testReceivedErrorMessage); CPPUNIT_TEST(testReceivedErrorMessage);
CPPUNIT_TEST_SUITE_END(); CPPUNIT_TEST_SUITE_END();
public: public:
std::shared_ptr<DHTMessageFactoryImpl> factory; std::unique_ptr<DHTMessageFactoryImpl> factory;
std::shared_ptr<DHTRoutingTable> routingTable; std::unique_ptr<DHTRoutingTable> routingTable;
std::shared_ptr<DHTNode> localNode; std::shared_ptr<DHTNode> localNode;
std::unique_ptr<DHTNode> remoteNode_;
std::unique_ptr<DHTNode> remoteNode6_;
unsigned char transactionID[DHT_TRANSACTION_ID_LENGTH]; unsigned char transactionID[DHT_TRANSACTION_ID_LENGTH];
unsigned char remoteNodeID[DHT_ID_LENGTH]; unsigned char remoteNodeID[DHT_ID_LENGTH];
void setUp() void setUp()
{ {
localNode.reset(new DHTNode()); localNode = std::make_shared<DHTNode>();
factory.reset(new DHTMessageFactoryImpl(AF_INET)); factory = make_unique<DHTMessageFactoryImpl>(AF_INET);
factory->setLocalNode(localNode); factory->setLocalNode(localNode);
memset(transactionID, 0xff, DHT_TRANSACTION_ID_LENGTH); memset(transactionID, 0xff, DHT_TRANSACTION_ID_LENGTH);
memset(remoteNodeID, 0x0f, DHT_ID_LENGTH); memset(remoteNodeID, 0x0f, DHT_ID_LENGTH);
routingTable.reset(new DHTRoutingTable(localNode)); routingTable = make_unique<DHTRoutingTable>(localNode);
factory->setRoutingTable(routingTable.get()); factory->setRoutingTable(routingTable.get());
remoteNode_ = make_unique<DHTNode>(remoteNodeID);
remoteNode_->setIPAddress("192.168.0.1");
remoteNode_->setPort(6881);
remoteNode6_ = make_unique<DHTNode>(remoteNodeID);
remoteNode6_->setIPAddress("2001::2001");
remoteNode6_->setPort(6881);
} }
void tearDown() {} void tearDown() {}
@ -85,18 +96,15 @@ void DHTMessageFactoryImplTest::testCreatePingMessage()
dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH)); dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH));
dict.put("y", "q"); dict.put("y", "q");
dict.put("q", "ping"); dict.put("q", "ping");
std::shared_ptr<Dict> aDict = Dict::g(); auto aDict = Dict::g();
aDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH)); aDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH));
dict.put("a", aDict); dict.put("a", aDict);
auto m = std::dynamic_pointer_cast<DHTPingMessage> auto r = factory->createQueryMessage(&dict, "192.168.0.1", 6881);
(factory->createQueryMessage(&dict, "192.168.0.1", 6881)); auto m = dynamic_cast<DHTPingMessage*>(r.get());
std::shared_ptr<DHTNode> remoteNode(new DHTNode(remoteNodeID));
remoteNode->setIPAddress("192.168.0.1");
remoteNode->setPort(6881);
CPPUNIT_ASSERT(*localNode == *m->getLocalNode()); CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode()); CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(util::toHex(transactionID, DHT_TRANSACTION_ID_LENGTH), CPPUNIT_ASSERT_EQUAL(util::toHex(transactionID, DHT_TRANSACTION_ID_LENGTH),
util::toHex(m->getTransactionID())); util::toHex(m->getTransactionID()));
} }
@ -106,21 +114,17 @@ void DHTMessageFactoryImplTest::testCreatePingReplyMessage()
Dict dict; Dict dict;
dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH)); dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH));
dict.put("y", "r"); dict.put("y", "r");
std::shared_ptr<Dict> rDict = Dict::g(); auto rDict = Dict::g();
rDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH)); rDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH));
dict.put("r", rDict); dict.put("r", rDict);
std::shared_ptr<DHTNode> remoteNode(new DHTNode(remoteNodeID)); auto r = factory->createResponseMessage("ping", &dict,
remoteNode->setIPAddress("192.168.0.1"); remoteNode_->getIPAddress(),
remoteNode->setPort(6881); remoteNode_->getPort());
auto m = dynamic_cast<DHTPingReplyMessage*>(r.get());
auto m = std::dynamic_pointer_cast<DHTPingReplyMessage>
(factory->createResponseMessage("ping", &dict,
remoteNode->getIPAddress(),
remoteNode->getPort()));
CPPUNIT_ASSERT(*localNode == *m->getLocalNode()); CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode()); CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(util::toHex(transactionID, DHT_TRANSACTION_ID_LENGTH), CPPUNIT_ASSERT_EQUAL(util::toHex(transactionID, DHT_TRANSACTION_ID_LENGTH),
util::toHex(m->getTransactionID())); util::toHex(m->getTransactionID()));
} }
@ -131,21 +135,18 @@ void DHTMessageFactoryImplTest::testCreateFindNodeMessage()
dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH)); dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH));
dict.put("y", "q"); dict.put("y", "q");
dict.put("q", "find_node"); dict.put("q", "find_node");
std::shared_ptr<Dict> aDict = Dict::g(); auto aDict = Dict::g();
aDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH)); aDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH));
unsigned char targetNodeID[DHT_ID_LENGTH]; unsigned char targetNodeID[DHT_ID_LENGTH];
memset(targetNodeID, 0x11, DHT_ID_LENGTH); memset(targetNodeID, 0x11, DHT_ID_LENGTH);
aDict->put("target", String::g(targetNodeID, DHT_ID_LENGTH)); aDict->put("target", String::g(targetNodeID, DHT_ID_LENGTH));
dict.put("a", aDict); dict.put("a", aDict);
auto m = std::dynamic_pointer_cast<DHTFindNodeMessage> auto r = factory->createQueryMessage(&dict, "192.168.0.1", 6881);
(factory->createQueryMessage(&dict, "192.168.0.1", 6881)); auto m = dynamic_cast<DHTFindNodeMessage*>(r.get());
std::shared_ptr<DHTNode> remoteNode(new DHTNode(remoteNodeID));
remoteNode->setIPAddress("192.168.0.1");
remoteNode->setPort(6881);
CPPUNIT_ASSERT(*localNode == *m->getLocalNode()); CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode()); CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(util::toHex(transactionID, DHT_TRANSACTION_ID_LENGTH), CPPUNIT_ASSERT_EQUAL(util::toHex(transactionID, DHT_TRANSACTION_ID_LENGTH),
util::toHex(m->getTransactionID())); util::toHex(m->getTransactionID()));
CPPUNIT_ASSERT_EQUAL(util::toHex(targetNodeID, DHT_ID_LENGTH), CPPUNIT_ASSERT_EQUAL(util::toHex(targetNodeID, DHT_ID_LENGTH),
@ -158,12 +159,12 @@ void DHTMessageFactoryImplTest::testCreateFindNodeReplyMessage()
Dict dict; Dict dict;
dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH)); dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH));
dict.put("y", "r"); dict.put("y", "r");
std::shared_ptr<Dict> rDict = Dict::g(); auto rDict = Dict::g();
rDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH)); rDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH));
std::string compactNodeInfo; std::string compactNodeInfo;
std::shared_ptr<DHTNode> nodes[8]; std::shared_ptr<DHTNode> nodes[8];
for(size_t i = 0; i < DHTBucket::K; ++i) { for(size_t i = 0; i < DHTBucket::K; ++i) {
nodes[i].reset(new DHTNode()); nodes[i] = std::make_shared<DHTNode>();
nodes[i]->setIPAddress("192.168.0."+util::uitos(i+1)); nodes[i]->setIPAddress("192.168.0."+util::uitos(i+1));
nodes[i]->setPort(6881+i); nodes[i]->setPort(6881+i);
@ -179,17 +180,13 @@ void DHTMessageFactoryImplTest::testCreateFindNodeReplyMessage()
rDict->put("nodes", compactNodeInfo); rDict->put("nodes", compactNodeInfo);
dict.put("r", rDict); dict.put("r", rDict);
std::shared_ptr<DHTNode> remoteNode(new DHTNode(remoteNodeID)); auto r = factory->createResponseMessage("find_node", &dict,
remoteNode->setIPAddress("192.168.0.1"); remoteNode_->getIPAddress(),
remoteNode->setPort(6881); remoteNode_->getPort());
auto m = dynamic_cast<DHTFindNodeReplyMessage*>(r.get());
auto m = std::dynamic_pointer_cast<DHTFindNodeReplyMessage>
(factory->createResponseMessage("find_node", &dict,
remoteNode->getIPAddress(),
remoteNode->getPort()));
CPPUNIT_ASSERT(*localNode == *m->getLocalNode()); CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode()); CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL((size_t)DHTBucket::K, m->getClosestKNodes().size()); CPPUNIT_ASSERT_EQUAL((size_t)DHTBucket::K, m->getClosestKNodes().size());
CPPUNIT_ASSERT(*nodes[0] == *m->getClosestKNodes()[0]); CPPUNIT_ASSERT(*nodes[0] == *m->getClosestKNodes()[0]);
CPPUNIT_ASSERT(*nodes[7] == *m->getClosestKNodes()[7]); CPPUNIT_ASSERT(*nodes[7] == *m->getClosestKNodes()[7]);
@ -202,19 +199,19 @@ void DHTMessageFactoryImplTest::testCreateFindNodeReplyMessage()
void DHTMessageFactoryImplTest::testCreateFindNodeReplyMessage6() void DHTMessageFactoryImplTest::testCreateFindNodeReplyMessage6()
{ {
factory.reset(new DHTMessageFactoryImpl(AF_INET6)); factory = make_unique<DHTMessageFactoryImpl>(AF_INET6);
factory->setLocalNode(localNode); factory->setLocalNode(localNode);
factory->setRoutingTable(routingTable.get()); factory->setRoutingTable(routingTable.get());
try { try {
Dict dict; Dict dict;
dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH)); dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH));
dict.put("y", "r"); dict.put("y", "r");
std::shared_ptr<Dict> rDict = Dict::g(); auto rDict = Dict::g();
rDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH)); rDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH));
std::string compactNodeInfo; std::string compactNodeInfo;
std::shared_ptr<DHTNode> nodes[8]; std::shared_ptr<DHTNode> nodes[8];
for(size_t i = 0; i < DHTBucket::K; ++i) { for(size_t i = 0; i < DHTBucket::K; ++i) {
nodes[i].reset(new DHTNode()); nodes[i] = std::make_shared<DHTNode>();
nodes[i]->setIPAddress("2001::000"+util::uitos(i+1)); nodes[i]->setIPAddress("2001::000"+util::uitos(i+1));
nodes[i]->setPort(6881+i); nodes[i]->setPort(6881+i);
@ -230,17 +227,13 @@ void DHTMessageFactoryImplTest::testCreateFindNodeReplyMessage6()
rDict->put("nodes6", compactNodeInfo); rDict->put("nodes6", compactNodeInfo);
dict.put("r", rDict); dict.put("r", rDict);
std::shared_ptr<DHTNode> remoteNode(new DHTNode(remoteNodeID)); auto r = factory->createResponseMessage("find_node", &dict,
remoteNode->setIPAddress("2001::2001"); remoteNode_->getIPAddress(),
remoteNode->setPort(6881); remoteNode_->getPort());
auto m = dynamic_cast<DHTFindNodeReplyMessage*>(r.get());
auto m = std::dynamic_pointer_cast<DHTFindNodeReplyMessage>
(factory->createResponseMessage("find_node", &dict,
remoteNode->getIPAddress(),
remoteNode->getPort()));
CPPUNIT_ASSERT(*localNode == *m->getLocalNode()); CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode()); CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL((size_t)DHTBucket::K, m->getClosestKNodes().size()); CPPUNIT_ASSERT_EQUAL((size_t)DHTBucket::K, m->getClosestKNodes().size());
CPPUNIT_ASSERT(*nodes[0] == *m->getClosestKNodes()[0]); CPPUNIT_ASSERT(*nodes[0] == *m->getClosestKNodes()[0]);
CPPUNIT_ASSERT(*nodes[7] == *m->getClosestKNodes()[7]); CPPUNIT_ASSERT(*nodes[7] == *m->getClosestKNodes()[7]);
@ -257,21 +250,18 @@ void DHTMessageFactoryImplTest::testCreateGetPeersMessage()
dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH)); dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH));
dict.put("y", "q"); dict.put("y", "q");
dict.put("q", "get_peers"); dict.put("q", "get_peers");
std::shared_ptr<Dict> aDict = Dict::g(); auto aDict = Dict::g();
aDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH)); aDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH));
unsigned char infoHash[DHT_ID_LENGTH]; unsigned char infoHash[DHT_ID_LENGTH];
memset(infoHash, 0x11, DHT_ID_LENGTH); memset(infoHash, 0x11, DHT_ID_LENGTH);
aDict->put("info_hash", String::g(infoHash, DHT_ID_LENGTH)); aDict->put("info_hash", String::g(infoHash, DHT_ID_LENGTH));
dict.put("a", aDict); dict.put("a", aDict);
auto m = std::dynamic_pointer_cast<DHTGetPeersMessage> auto r = factory->createQueryMessage(&dict, "192.168.0.1", 6881);
(factory->createQueryMessage(&dict, "192.168.0.1", 6881)); auto m = dynamic_cast<DHTGetPeersMessage*>(r.get());
std::shared_ptr<DHTNode> remoteNode(new DHTNode(remoteNodeID));
remoteNode->setIPAddress("192.168.0.1");
remoteNode->setPort(6881);
CPPUNIT_ASSERT(*localNode == *m->getLocalNode()); CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode()); CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(util::toHex(transactionID, DHT_TRANSACTION_ID_LENGTH), CPPUNIT_ASSERT_EQUAL(util::toHex(transactionID, DHT_TRANSACTION_ID_LENGTH),
util::toHex(m->getTransactionID())); util::toHex(m->getTransactionID()));
CPPUNIT_ASSERT_EQUAL(util::toHex(infoHash, DHT_ID_LENGTH), CPPUNIT_ASSERT_EQUAL(util::toHex(infoHash, DHT_ID_LENGTH),
@ -284,12 +274,12 @@ void DHTMessageFactoryImplTest::testCreateGetPeersReplyMessage()
Dict dict; Dict dict;
dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH)); dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH));
dict.put("y", "r"); dict.put("y", "r");
std::shared_ptr<Dict> rDict = Dict::g(); auto rDict = Dict::g();
rDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH)); rDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH));
std::string compactNodeInfo; std::string compactNodeInfo;
std::shared_ptr<DHTNode> nodes[8]; std::shared_ptr<DHTNode> nodes[8];
for(size_t i = 0; i < DHTBucket::K; ++i) { for(size_t i = 0; i < DHTBucket::K; ++i) {
nodes[i].reset(new DHTNode()); nodes[i] = std::make_shared<DHTNode>();
nodes[i]->setIPAddress("192.168.0."+util::uitos(i+1)); nodes[i]->setIPAddress("192.168.0."+util::uitos(i+1));
nodes[i]->setPort(6881+i); nodes[i]->setPort(6881+i);
@ -307,7 +297,8 @@ void DHTMessageFactoryImplTest::testCreateGetPeersReplyMessage()
std::deque<std::shared_ptr<Peer> > peers; std::deque<std::shared_ptr<Peer> > peers;
std::shared_ptr<List> valuesList = List::g(); std::shared_ptr<List> valuesList = List::g();
for(size_t i = 0; i < 4; ++i) { for(size_t i = 0; i < 4; ++i) {
std::shared_ptr<Peer> peer(new Peer("192.168.0."+util::uitos(i+1), 6881+i)); auto peer = std::make_shared<Peer>("192.168.0."+util::uitos(i+1),
6881+i);
unsigned char buffer[COMPACT_LEN_IPV6]; unsigned char buffer[COMPACT_LEN_IPV6];
CPPUNIT_ASSERT_EQUAL CPPUNIT_ASSERT_EQUAL
(COMPACT_LEN_IPV4, (COMPACT_LEN_IPV4,
@ -321,17 +312,13 @@ void DHTMessageFactoryImplTest::testCreateGetPeersReplyMessage()
rDict->put("token", "token"); rDict->put("token", "token");
dict.put("r", rDict); dict.put("r", rDict);
std::shared_ptr<DHTNode> remoteNode(new DHTNode(remoteNodeID)); auto r = factory->createResponseMessage("get_peers", &dict,
remoteNode->setIPAddress("192.168.0.1"); remoteNode_->getIPAddress(),
remoteNode->setPort(6881); remoteNode_->getPort());
auto m = dynamic_cast<DHTGetPeersReplyMessage*>(r.get());
auto m = std::dynamic_pointer_cast<DHTGetPeersReplyMessage>
(factory->createResponseMessage("get_peers", &dict,
remoteNode->getIPAddress(),
remoteNode->getPort()));
CPPUNIT_ASSERT(*localNode == *m->getLocalNode()); CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode()); CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(std::string("token"), m->getToken()); CPPUNIT_ASSERT_EQUAL(std::string("token"), m->getToken());
CPPUNIT_ASSERT_EQUAL((size_t)DHTBucket::K, m->getClosestKNodes().size()); CPPUNIT_ASSERT_EQUAL((size_t)DHTBucket::K, m->getClosestKNodes().size());
CPPUNIT_ASSERT(*nodes[0] == *m->getClosestKNodes()[0]); CPPUNIT_ASSERT(*nodes[0] == *m->getClosestKNodes()[0]);
@ -351,19 +338,19 @@ void DHTMessageFactoryImplTest::testCreateGetPeersReplyMessage()
void DHTMessageFactoryImplTest::testCreateGetPeersReplyMessage6() void DHTMessageFactoryImplTest::testCreateGetPeersReplyMessage6()
{ {
factory.reset(new DHTMessageFactoryImpl(AF_INET6)); factory = make_unique<DHTMessageFactoryImpl>(AF_INET6);
factory->setLocalNode(localNode); factory->setLocalNode(localNode);
factory->setRoutingTable(routingTable.get()); factory->setRoutingTable(routingTable.get());
try { try {
Dict dict; Dict dict;
dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH)); dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH));
dict.put("y", "r"); dict.put("y", "r");
std::shared_ptr<Dict> rDict = Dict::g(); auto rDict = Dict::g();
rDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH)); rDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH));
std::string compactNodeInfo; std::string compactNodeInfo;
std::shared_ptr<DHTNode> nodes[8]; std::shared_ptr<DHTNode> nodes[8];
for(size_t i = 0; i < DHTBucket::K; ++i) { for(size_t i = 0; i < DHTBucket::K; ++i) {
nodes[i].reset(new DHTNode()); nodes[i] = std::make_shared<DHTNode>();
nodes[i]->setIPAddress("2001::000"+util::uitos(i+1)); nodes[i]->setIPAddress("2001::000"+util::uitos(i+1));
nodes[i]->setPort(6881+i); nodes[i]->setPort(6881+i);
@ -378,10 +365,10 @@ void DHTMessageFactoryImplTest::testCreateGetPeersReplyMessage6()
} }
rDict->put("nodes6", compactNodeInfo); rDict->put("nodes6", compactNodeInfo);
std::deque<std::shared_ptr<Peer> > peers; std::deque<std::shared_ptr<Peer>> peers;
std::shared_ptr<List> valuesList = List::g(); auto valuesList = List::g();
for(size_t i = 0; i < 4; ++i) { for(size_t i = 0; i < 4; ++i) {
std::shared_ptr<Peer> peer(new Peer("2001::100"+util::uitos(i+1), 6881+i)); auto peer = std::make_shared<Peer>("2001::100"+util::uitos(i+1), 6881+i);
unsigned char buffer[COMPACT_LEN_IPV6]; unsigned char buffer[COMPACT_LEN_IPV6];
CPPUNIT_ASSERT_EQUAL CPPUNIT_ASSERT_EQUAL
(COMPACT_LEN_IPV6, (COMPACT_LEN_IPV6,
@ -395,17 +382,13 @@ void DHTMessageFactoryImplTest::testCreateGetPeersReplyMessage6()
rDict->put("token", "token"); rDict->put("token", "token");
dict.put("r", rDict); dict.put("r", rDict);
std::shared_ptr<DHTNode> remoteNode(new DHTNode(remoteNodeID)); auto r = factory->createResponseMessage("get_peers", &dict,
remoteNode->setIPAddress("2001::2001"); remoteNode_->getIPAddress(),
remoteNode->setPort(6881); remoteNode_->getPort());
auto m = dynamic_cast<DHTGetPeersReplyMessage*>(r.get());
auto m = std::dynamic_pointer_cast<DHTGetPeersReplyMessage>
(factory->createResponseMessage("get_peers", &dict,
remoteNode->getIPAddress(),
remoteNode->getPort()));
CPPUNIT_ASSERT(*localNode == *m->getLocalNode()); CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode()); CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(std::string("token"), m->getToken()); CPPUNIT_ASSERT_EQUAL(std::string("token"), m->getToken());
CPPUNIT_ASSERT_EQUAL((size_t)DHTBucket::K, m->getClosestKNodes().size()); CPPUNIT_ASSERT_EQUAL((size_t)DHTBucket::K, m->getClosestKNodes().size());
CPPUNIT_ASSERT(*nodes[0] == *m->getClosestKNodes()[0]); CPPUNIT_ASSERT(*nodes[0] == *m->getClosestKNodes()[0]);
@ -430,7 +413,7 @@ void DHTMessageFactoryImplTest::testCreateAnnouncePeerMessage()
dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH)); dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH));
dict.put("y", "q"); dict.put("y", "q");
dict.put("q", "announce_peer"); dict.put("q", "announce_peer");
std::shared_ptr<Dict> aDict = Dict::g(); auto aDict = Dict::g();
aDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH)); aDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH));
unsigned char infoHash[DHT_ID_LENGTH]; unsigned char infoHash[DHT_ID_LENGTH];
memset(infoHash, 0x11, DHT_ID_LENGTH); memset(infoHash, 0x11, DHT_ID_LENGTH);
@ -441,14 +424,13 @@ void DHTMessageFactoryImplTest::testCreateAnnouncePeerMessage()
aDict->put("token", token); aDict->put("token", token);
dict.put("a", aDict); dict.put("a", aDict);
auto m = std::dynamic_pointer_cast<DHTAnnouncePeerMessage> remoteNode_->setPort(6882);
(factory->createQueryMessage(&dict, "192.168.0.1", 6882));
std::shared_ptr<DHTNode> remoteNode(new DHTNode(remoteNodeID)); auto r = factory->createQueryMessage(&dict, "192.168.0.1", 6882);
remoteNode->setIPAddress("192.168.0.1"); auto m = dynamic_cast<DHTAnnouncePeerMessage*>(r.get());
remoteNode->setPort(6882);
CPPUNIT_ASSERT(*localNode == *m->getLocalNode()); CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode()); CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(token, m->getToken()); CPPUNIT_ASSERT_EQUAL(token, m->getToken());
CPPUNIT_ASSERT_EQUAL(util::toHex(transactionID, DHT_TRANSACTION_ID_LENGTH), CPPUNIT_ASSERT_EQUAL(util::toHex(transactionID, DHT_TRANSACTION_ID_LENGTH),
util::toHex(m->getTransactionID())); util::toHex(m->getTransactionID()));
@ -465,21 +447,17 @@ void DHTMessageFactoryImplTest::testCreateAnnouncePeerReplyMessage()
Dict dict; Dict dict;
dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH)); dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH));
dict.put("y", "r"); dict.put("y", "r");
std::shared_ptr<Dict> rDict = Dict::g(); auto rDict = Dict::g();
rDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH)); rDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH));
dict.put("r", rDict); dict.put("r", rDict);
std::shared_ptr<DHTNode> remoteNode(new DHTNode(remoteNodeID)); auto r = factory->createResponseMessage("announce_peer", &dict,
remoteNode->setIPAddress("192.168.0.1"); remoteNode_->getIPAddress(),
remoteNode->setPort(6881); remoteNode_->getPort());
auto m = dynamic_cast<DHTAnnouncePeerReplyMessage*>(r.get());
auto m = std::dynamic_pointer_cast<DHTAnnouncePeerReplyMessage>
(factory->createResponseMessage("announce_peer", &dict,
remoteNode->getIPAddress(),
remoteNode->getPort()));
CPPUNIT_ASSERT(*localNode == *m->getLocalNode()); CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode()); CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(util::toHex(transactionID, DHT_TRANSACTION_ID_LENGTH), CPPUNIT_ASSERT_EQUAL(util::toHex(transactionID, DHT_TRANSACTION_ID_LENGTH),
util::toHex(m->getTransactionID())); util::toHex(m->getTransactionID()));
} }
@ -489,19 +467,15 @@ void DHTMessageFactoryImplTest::testReceivedErrorMessage()
Dict dict; Dict dict;
dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH)); dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH));
dict.put("y", "e"); dict.put("y", "e");
std::shared_ptr<List> list = List::g(); auto list = List::g();
list->append(Integer::g(404)); list->append(Integer::g(404));
list->append("Not found"); list->append("Not found");
dict.put("e", list); dict.put("e", list);
std::shared_ptr<DHTNode> remoteNode(new DHTNode(remoteNodeID));
remoteNode->setIPAddress("192.168.0.1");
remoteNode->setPort(6881);
try { try {
factory->createResponseMessage("announce_peer", &dict, factory->createResponseMessage("announce_peer", &dict,
remoteNode->getIPAddress(), remoteNode_->getIPAddress(),
remoteNode->getPort()); remoteNode_->getPort());
CPPUNIT_FAIL("exception must be thrown."); CPPUNIT_FAIL("exception must be thrown.");
} catch(RecoverableException& e) { } catch(RecoverableException& e) {
std::cerr << e.stackTrace() << std::endl; std::cerr << e.stackTrace() << std::endl;

View file

@ -33,14 +33,17 @@ CPPUNIT_TEST_SUITE_REGISTRATION(DHTMessageTrackerEntryTest);
void DHTMessageTrackerEntryTest::testMatch() void DHTMessageTrackerEntryTest::testMatch()
{ {
std::shared_ptr<DHTNode> localNode(new DHTNode()); auto localNode = std::make_shared<DHTNode>();
try { try {
std::shared_ptr<DHTNode> node1(new DHTNode()); auto node1 = std::make_shared<DHTNode>();
std::shared_ptr<MockDHTMessage> msg1(new MockDHTMessage(localNode, node1)); auto msg1 = make_unique<MockDHTMessage>(localNode, node1);
std::shared_ptr<DHTNode> node2(new DHTNode()); auto node2 = std::make_shared<DHTNode>();
std::shared_ptr<MockDHTMessage> msg2(new MockDHTMessage(localNode, node2)); auto msg2 = make_unique<MockDHTMessage>(localNode, node2);
DHTMessageTrackerEntry entry(msg1, 30); DHTMessageTrackerEntry entry(msg1->getRemoteNode(),
msg1->getTransactionID(),
msg1->getMessageType(),
30);
CPPUNIT_ASSERT(entry.match(msg1->getTransactionID(), CPPUNIT_ASSERT(entry.match(msg1->getTransactionID(),
msg1->getRemoteNode()->getIPAddress(), msg1->getRemoteNode()->getIPAddress(),

View file

@ -34,65 +34,62 @@ CPPUNIT_TEST_SUITE_REGISTRATION(DHTMessageTrackerTest);
void DHTMessageTrackerTest::testMessageArrived() void DHTMessageTrackerTest::testMessageArrived()
{ {
std::shared_ptr<DHTNode> localNode(new DHTNode()); auto localNode = std::make_shared<DHTNode>();
std::shared_ptr<DHTRoutingTable> routingTable(new DHTRoutingTable(localNode)); auto routingTable = std::make_shared<DHTRoutingTable>(localNode);
std::shared_ptr<MockDHTMessageFactory> factory(new MockDHTMessageFactory()); auto factory = std::make_shared<MockDHTMessageFactory>();
factory->setLocalNode(localNode); factory->setLocalNode(localNode);
std::shared_ptr<MockDHTMessage> m1(new MockDHTMessage(localNode, auto r1 = std::make_shared<DHTNode>();
std::shared_ptr<DHTNode>(new DHTNode()))); r1->setIPAddress("192.168.0.1");
std::shared_ptr<MockDHTMessage> m2(new MockDHTMessage(localNode, r1->setPort(6881);
std::shared_ptr<DHTNode>(new DHTNode()))); auto r2 = std::make_shared<DHTNode>();
std::shared_ptr<MockDHTMessage> m3(new MockDHTMessage(localNode, r2->setIPAddress("192.168.0.2");
std::shared_ptr<DHTNode>(new DHTNode()))); r2->setPort(6882);
auto r3 = std::make_shared<DHTNode>();
r3->setIPAddress("192.168.0.3");
r3->setPort(6883);
m1->getRemoteNode()->setIPAddress("192.168.0.1"); auto m1 = make_unique<MockDHTMessage>(localNode, r1);
m1->getRemoteNode()->setPort(6881); auto m2 = make_unique<MockDHTMessage>(localNode, r2);
m2->getRemoteNode()->setIPAddress("192.168.0.2"); auto m3 = make_unique<MockDHTMessage>(localNode, r3);
m2->getRemoteNode()->setPort(6882);
m3->getRemoteNode()->setIPAddress("192.168.0.3");
m3->getRemoteNode()->setPort(6883);
DHTMessageTracker tracker; DHTMessageTracker tracker;
tracker.setRoutingTable(routingTable); tracker.setRoutingTable(routingTable);
tracker.setMessageFactory(factory); tracker.setMessageFactory(factory.get());
tracker.addMessage(m1, DHT_MESSAGE_TIMEOUT); tracker.addMessage(m1.get(), DHT_MESSAGE_TIMEOUT);
tracker.addMessage(m2, DHT_MESSAGE_TIMEOUT); tracker.addMessage(m2.get(), DHT_MESSAGE_TIMEOUT);
tracker.addMessage(m3, DHT_MESSAGE_TIMEOUT); tracker.addMessage(m3.get(), DHT_MESSAGE_TIMEOUT);
{ {
Dict resDict; Dict resDict;
resDict.put("t", m2->getTransactionID()); resDict.put("t", m2->getTransactionID());
std::pair<std::shared_ptr<DHTMessage>, std::shared_ptr<DHTMessageCallback> > p = auto p =
tracker.messageArrived(&resDict, m2->getRemoteNode()->getIPAddress(), tracker.messageArrived(&resDict, r2->getIPAddress(), r2->getPort());
m2->getRemoteNode()->getPort()); auto& reply = p.first;
std::shared_ptr<DHTMessage> reply = p.first;
CPPUNIT_ASSERT(reply); CPPUNIT_ASSERT(reply);
CPPUNIT_ASSERT(!tracker.getEntryFor(m2)); CPPUNIT_ASSERT(!tracker.getEntryFor(m2.get()));
CPPUNIT_ASSERT_EQUAL((size_t)2, tracker.countEntry()); CPPUNIT_ASSERT_EQUAL((size_t)2, tracker.countEntry());
} }
{ {
Dict resDict; Dict resDict;
resDict.put("t", m3->getTransactionID()); resDict.put("t", m3->getTransactionID());
std::pair<std::shared_ptr<DHTMessage>, std::shared_ptr<DHTMessageCallback> > p = auto p =
tracker.messageArrived(&resDict, m3->getRemoteNode()->getIPAddress(), tracker.messageArrived(&resDict, r3->getIPAddress(), r3->getPort());
m3->getRemoteNode()->getPort()); auto& reply = p.first;
std::shared_ptr<DHTMessage> reply = p.first;
CPPUNIT_ASSERT(reply); CPPUNIT_ASSERT(reply);
CPPUNIT_ASSERT(!tracker.getEntryFor(m3)); CPPUNIT_ASSERT(!tracker.getEntryFor(m3.get()));
CPPUNIT_ASSERT_EQUAL((size_t)1, tracker.countEntry()); CPPUNIT_ASSERT_EQUAL((size_t)1, tracker.countEntry());
} }
{ {
Dict resDict; Dict resDict;
resDict.put("t", m1->getTransactionID()); resDict.put("t", m1->getTransactionID());
std::pair<std::shared_ptr<DHTMessage>, std::shared_ptr<DHTMessageCallback> > p = auto p = tracker.messageArrived(&resDict, "192.168.1.100", 6889);
tracker.messageArrived(&resDict, "192.168.1.100", 6889); auto& reply = p.first;
std::shared_ptr<DHTMessage> reply = p.first;
CPPUNIT_ASSERT(!reply); CPPUNIT_ASSERT(!reply);
} }

View file

@ -19,7 +19,14 @@ class DHTPingMessageTest:public CppUnit::TestFixture {
CPPUNIT_TEST(testDoReceivedAction); CPPUNIT_TEST(testDoReceivedAction);
CPPUNIT_TEST_SUITE_END(); CPPUNIT_TEST_SUITE_END();
public: public:
void setUp() {} std::shared_ptr<DHTNode> localNode_;
std::shared_ptr<DHTNode> remoteNode_;
void setUp()
{
localNode_ = std::make_shared<DHTNode>();
remoteNode_ = std::make_shared<DHTNode>();
}
void tearDown() {} void tearDown() {}
@ -28,14 +35,15 @@ public:
class MockDHTMessageFactory2:public MockDHTMessageFactory { class MockDHTMessageFactory2:public MockDHTMessageFactory {
public: public:
virtual std::shared_ptr<DHTResponseMessage> virtual std::unique_ptr<DHTPingReplyMessage>
createPingReplyMessage(const std::shared_ptr<DHTNode>& remoteNode, createPingReplyMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* remoteNodeID, const unsigned char* remoteNodeID,
const std::string& transactionID) const std::string& transactionID) override
{ {
return std::shared_ptr<MockDHTResponseMessage> unsigned char id[DHT_ID_LENGTH];
(new MockDHTResponseMessage(localNode_, remoteNode, "ping_reply", std::fill(std::begin(id), std::end(id), '0');
transactionID)); return make_unique<DHTPingReplyMessage>
(localNode_, remoteNode, id, transactionID);
} }
}; };
}; };
@ -45,14 +53,11 @@ CPPUNIT_TEST_SUITE_REGISTRATION(DHTPingMessageTest);
void DHTPingMessageTest::testGetBencodedMessage() void DHTPingMessageTest::testGetBencodedMessage()
{ {
std::shared_ptr<DHTNode> localNode(new DHTNode());
std::shared_ptr<DHTNode> remoteNode(new DHTNode());
unsigned char tid[DHT_TRANSACTION_ID_LENGTH]; unsigned char tid[DHT_TRANSACTION_ID_LENGTH];
util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH); util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH);
std::string transactionID(&tid[0], &tid[DHT_TRANSACTION_ID_LENGTH]); std::string transactionID(&tid[0], &tid[DHT_TRANSACTION_ID_LENGTH]);
DHTPingMessage msg(localNode, remoteNode, transactionID); DHTPingMessage msg(localNode_, remoteNode_, transactionID);
msg.setVersion("A200"); msg.setVersion("A200");
std::string msgbody = msg.getBencodedMessage(); std::string msgbody = msg.getBencodedMessage();
@ -62,8 +67,8 @@ void DHTPingMessageTest::testGetBencodedMessage()
dict.put("v", "A200"); dict.put("v", "A200");
dict.put("y", "q"); dict.put("y", "q");
dict.put("q", "ping"); dict.put("q", "ping");
std::shared_ptr<Dict> aDict = Dict::g(); auto aDict = Dict::g();
aDict->put("id", String::g(localNode->getID(), DHT_ID_LENGTH)); aDict->put("id", String::g(localNode_->getID(), DHT_ID_LENGTH));
dict.put("a", aDict); dict.put("a", aDict);
CPPUNIT_ASSERT_EQUAL(bencode2::encode(&dict), msgbody); CPPUNIT_ASSERT_EQUAL(bencode2::encode(&dict), msgbody);
@ -71,29 +76,26 @@ void DHTPingMessageTest::testGetBencodedMessage()
void DHTPingMessageTest::testDoReceivedAction() void DHTPingMessageTest::testDoReceivedAction()
{ {
std::shared_ptr<DHTNode> localNode(new DHTNode());
std::shared_ptr<DHTNode> remoteNode(new DHTNode());
unsigned char tid[DHT_TRANSACTION_ID_LENGTH]; unsigned char tid[DHT_TRANSACTION_ID_LENGTH];
util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH); util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH);
std::string transactionID(&tid[0], &tid[DHT_TRANSACTION_ID_LENGTH]); std::string transactionID(&tid[0], &tid[DHT_TRANSACTION_ID_LENGTH]);
MockDHTMessageDispatcher dispatcher; MockDHTMessageDispatcher dispatcher;
MockDHTMessageFactory2 factory; MockDHTMessageFactory2 factory;
factory.setLocalNode(localNode); factory.setLocalNode(localNode_);
DHTPingMessage msg(localNode, remoteNode, transactionID); DHTPingMessage msg(localNode_, remoteNode_, transactionID);
msg.setMessageDispatcher(&dispatcher); msg.setMessageDispatcher(&dispatcher);
msg.setMessageFactory(&factory); msg.setMessageFactory(&factory);
msg.doReceivedAction(); msg.doReceivedAction();
CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher.messageQueue_.size()); CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher.messageQueue_.size());
auto m = std::dynamic_pointer_cast<MockDHTResponseMessage> auto m = dynamic_cast<DHTPingReplyMessage*>
(dispatcher.messageQueue_[0].message_); (dispatcher.messageQueue_[0].message_.get());
CPPUNIT_ASSERT(*localNode == *m->getLocalNode()); CPPUNIT_ASSERT(*localNode_ == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode()); CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(std::string("ping_reply"), m->getMessageType()); CPPUNIT_ASSERT_EQUAL(std::string("ping"), m->getMessageType());
CPPUNIT_ASSERT_EQUAL(msg.getTransactionID(), m->getTransactionID()); CPPUNIT_ASSERT_EQUAL(msg.getTransactionID(), m->getTransactionID());
} }

View file

@ -10,17 +10,17 @@ namespace aria2 {
class MockDHTMessageDispatcher:public DHTMessageDispatcher { class MockDHTMessageDispatcher:public DHTMessageDispatcher {
public: public:
class Entry { struct Entry {
public: std::unique_ptr<DHTMessage> message_;
std::shared_ptr<DHTMessage> message_;
time_t timeout_; time_t timeout_;
std::shared_ptr<DHTMessageCallback> callback_; std::unique_ptr<DHTMessageCallback> callback_;
Entry(const std::shared_ptr<DHTMessage>& message, time_t timeout, Entry(std::unique_ptr<DHTMessage> message, time_t timeout,
const std::shared_ptr<DHTMessageCallback>& callback): std::unique_ptr<DHTMessageCallback> callback)
message_(message), : message_{std::move(message)},
timeout_(timeout), timeout_{timeout},
callback_(callback) {} callback_{std::move(callback)}
{}
}; };
std::deque<Entry> messageQueue_; std::deque<Entry> messageQueue_;
@ -28,23 +28,23 @@ public:
public: public:
MockDHTMessageDispatcher() {} MockDHTMessageDispatcher() {}
virtual ~MockDHTMessageDispatcher() {}
virtual void virtual void
addMessageToQueue(const std::shared_ptr<DHTMessage>& message, addMessageToQueue(std::unique_ptr<DHTMessage> message,
time_t timeout, time_t timeout,
const std::shared_ptr<DHTMessageCallback>& callback = std::unique_ptr<DHTMessageCallback> callback =
std::shared_ptr<DHTMessageCallback>()) std::unique_ptr<DHTMessageCallback>{})
{ {
messageQueue_.push_back(Entry(message, timeout, callback)); messageQueue_.push_back(Entry(std::move(message), timeout,
std::move(callback)));
} }
virtual void virtual void
addMessageToQueue(const std::shared_ptr<DHTMessage>& message, addMessageToQueue(std::unique_ptr<DHTMessage> message,
const std::shared_ptr<DHTMessageCallback>& callback = std::unique_ptr<DHTMessageCallback> callback =
std::shared_ptr<DHTMessageCallback>()) std::unique_ptr<DHTMessageCallback>{})
{ {
messageQueue_.push_back(Entry(message, DHT_MESSAGE_TIMEOUT, callback)); messageQueue_.push_back(Entry(std::move(message), DHT_MESSAGE_TIMEOUT,
std::move(callback)));
} }
virtual void sendMessages() {} virtual void sendMessages() {}

View file

@ -4,6 +4,15 @@
#include "DHTMessageFactory.h" #include "DHTMessageFactory.h"
#include "DHTNode.h" #include "DHTNode.h"
#include "MockDHTMessage.h" #include "MockDHTMessage.h"
#include "DHTPingMessage.h"
#include "DHTPingReplyMessage.h"
#include "DHTFindNodeMessage.h"
#include "DHTFindNodeReplyMessage.h"
#include "DHTGetPeersMessage.h"
#include "DHTGetPeersReplyMessage.h"
#include "DHTAnnouncePeerMessage.h"
#include "DHTAnnouncePeerReplyMessage.h"
#include "DHTUnknownMessage.h"
namespace aria2 { namespace aria2 {
@ -13,103 +22,99 @@ protected:
public: public:
MockDHTMessageFactory() {} MockDHTMessageFactory() {}
virtual ~MockDHTMessageFactory() {} virtual std::unique_ptr<DHTQueryMessage>
virtual std::shared_ptr<DHTQueryMessage>
createQueryMessage(const Dict* dict, createQueryMessage(const Dict* dict,
const std::string& ipaddr, uint16_t port) const std::string& ipaddr, uint16_t port)
{ {
return std::shared_ptr<DHTQueryMessage>(); return std::unique_ptr<DHTQueryMessage>{};
} }
virtual std::shared_ptr<DHTResponseMessage> virtual std::unique_ptr<DHTResponseMessage>
createResponseMessage(const std::string& messageType, createResponseMessage(const std::string& messageType,
const Dict* dict, const Dict* dict,
const std::string& ipaddr, uint16_t port) const std::string& ipaddr, uint16_t port)
{ {
std::shared_ptr<DHTNode> remoteNode(new DHTNode()); auto remoteNode = std::make_shared<DHTNode>();
// TODO At this point, removeNode's ID is random. // TODO At this point, removeNode's ID is random.
remoteNode->setIPAddress(ipaddr); remoteNode->setIPAddress(ipaddr);
remoteNode->setPort(port); remoteNode->setPort(port);
std::shared_ptr<MockDHTResponseMessage> m return make_unique<MockDHTResponseMessage>
(new MockDHTResponseMessage(localNode_, remoteNode, (localNode_, remoteNode, downcast<String>(dict->get("t"))->s());
downcast<String>(dict->get("t"))->s()));
return m;
} }
virtual std::shared_ptr<DHTQueryMessage> virtual std::unique_ptr<DHTPingMessage>
createPingMessage(const std::shared_ptr<DHTNode>& remoteNode, createPingMessage(const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID = "") const std::string& transactionID = "")
{ {
return std::shared_ptr<DHTQueryMessage>(); return std::unique_ptr<DHTPingMessage>{};
} }
virtual std::shared_ptr<DHTResponseMessage> virtual std::unique_ptr<DHTPingReplyMessage>
createPingReplyMessage(const std::shared_ptr<DHTNode>& remoteNode, createPingReplyMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* remoteNodeID, const unsigned char* remoteNodeID,
const std::string& transactionID) const std::string& transactionID)
{ {
return std::shared_ptr<DHTResponseMessage>(); return std::unique_ptr<DHTPingReplyMessage>{};
} }
virtual std::shared_ptr<DHTQueryMessage> virtual std::unique_ptr<DHTFindNodeMessage>
createFindNodeMessage(const std::shared_ptr<DHTNode>& remoteNode, createFindNodeMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* targetNodeID, const unsigned char* targetNodeID,
const std::string& transactionID = "") const std::string& transactionID = "")
{ {
return std::shared_ptr<DHTQueryMessage>(); return std::unique_ptr<DHTFindNodeMessage>{};
} }
virtual std::shared_ptr<DHTResponseMessage> virtual std::unique_ptr<DHTFindNodeReplyMessage>
createFindNodeReplyMessage createFindNodeReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode, (const std::shared_ptr<DHTNode>& remoteNode,
const std::vector<std::shared_ptr<DHTNode> >& closestKNodes, std::vector<std::shared_ptr<DHTNode>> closestKNodes,
const std::string& transactionID) const std::string& transactionID)
{ {
return std::shared_ptr<DHTResponseMessage>(); return std::unique_ptr<DHTFindNodeReplyMessage>{};
} }
virtual std::shared_ptr<DHTQueryMessage> virtual std::unique_ptr<DHTGetPeersMessage>
createGetPeersMessage(const std::shared_ptr<DHTNode>& remoteNode, createGetPeersMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* infoHash, const unsigned char* infoHash,
const std::string& transactionID) const std::string& transactionID)
{ {
return std::shared_ptr<DHTQueryMessage>(); return std::unique_ptr<DHTGetPeersMessage>{};
} }
virtual std::shared_ptr<DHTResponseMessage> virtual std::unique_ptr<DHTGetPeersReplyMessage>
createGetPeersReplyMessage createGetPeersReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode, (const std::shared_ptr<DHTNode>& remoteNode,
const std::vector<std::shared_ptr<DHTNode> >& closestKNodes, std::vector<std::shared_ptr<DHTNode>> closestKNodes,
const std::vector<std::shared_ptr<Peer> >& peers, std::vector<std::shared_ptr<Peer>> peers,
const std::string& token, const std::string& token,
const std::string& transactionID) const std::string& transactionID)
{ {
return std::shared_ptr<DHTResponseMessage>(); return std::unique_ptr<DHTGetPeersReplyMessage>{};
} }
virtual std::shared_ptr<DHTQueryMessage> virtual std::unique_ptr<DHTAnnouncePeerMessage>
createAnnouncePeerMessage(const std::shared_ptr<DHTNode>& remoteNode, createAnnouncePeerMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* infoHash, const unsigned char* infoHash,
uint16_t tcpPort, uint16_t tcpPort,
const std::string& token, const std::string& token,
const std::string& transactionID = "") const std::string& transactionID = "")
{ {
return std::shared_ptr<DHTQueryMessage>(); return std::unique_ptr<DHTAnnouncePeerMessage>{};
} }
virtual std::shared_ptr<DHTResponseMessage> virtual std::unique_ptr<DHTAnnouncePeerReplyMessage>
createAnnouncePeerReplyMessage(const std::shared_ptr<DHTNode>& remoteNode, createAnnouncePeerReplyMessage(const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID) const std::string& transactionID)
{ {
return std::shared_ptr<DHTResponseMessage>(); return std::unique_ptr<DHTAnnouncePeerReplyMessage>{};
} }
virtual std::shared_ptr<DHTMessage> virtual std::unique_ptr<DHTUnknownMessage>
createUnknownMessage(const unsigned char* data, size_t length, createUnknownMessage(const unsigned char* data, size_t length,
const std::string& ipaddr, uint16_t port) const std::string& ipaddr, uint16_t port)
{ {
return std::shared_ptr<DHTMessage>(); return std::unique_ptr<DHTUnknownMessage>{};
} }
void setLocalNode(const std::shared_ptr<DHTNode>& node) void setLocalNode(const std::shared_ptr<DHTNode>& node)