Rewritten ExtensionMessageRegistry

This commit is contained in:
Tatsuhiro Tsujikawa 2012-09-26 22:02:48 +09:00
parent f0000a8754
commit c13dc166de
20 changed files with 321 additions and 143 deletions

View file

@ -36,10 +36,7 @@
#define D_BT_CONSTANTS_H #define D_BT_CONSTANTS_H
#include "common.h" #include "common.h"
#include <map> #include <vector>
#include <string>
typedef std::map<std::string, uint8_t> Extensions;
#define INFO_HASH_LENGTH 20 #define INFO_HASH_LENGTH 20

View file

@ -140,7 +140,8 @@ BtMessageHandle DefaultBtInteractive::receiveHandshake(bool quickReply) {
if(message->isExtendedMessagingEnabled()) { if(message->isExtendedMessagingEnabled()) {
peer_->setExtendedMessagingEnabled(true); peer_->setExtendedMessagingEnabled(true);
if(!utPexEnabled_) { if(!utPexEnabled_) {
extensionMessageRegistry_->removeExtension("ut_pex"); extensionMessageRegistry_->removeExtension
(ExtensionMessageRegistry::UT_PEX);
} }
A2_LOG_INFO(fmt(MSG_EXTENDED_MESSAGING_ENABLED, cuid_)); A2_LOG_INFO(fmt(MSG_EXTENDED_MESSAGING_ENABLED, cuid_));
} }
@ -472,7 +473,8 @@ void DefaultBtInteractive::addPeerExchangeMessage()
if(pexTimer_. if(pexTimer_.
difference(global::wallclock()) >= UTPexExtensionMessage::DEFAULT_INTERVAL) { difference(global::wallclock()) >= UTPexExtensionMessage::DEFAULT_INTERVAL) {
UTPexExtensionMessageHandle m UTPexExtensionMessageHandle m
(new UTPexExtensionMessage(peer_->getExtensionMessageID("ut_pex"))); (new UTPexExtensionMessage(peer_->getExtensionMessageID
(ExtensionMessageRegistry::UT_PEX)));
std::vector<SharedHandle<Peer> > activePeers; std::vector<SharedHandle<Peer> > activePeers;
peerStorage_->getActivePeers(activePeers); peerStorage_->getActivePeers(activePeers);
@ -508,7 +510,7 @@ void DefaultBtInteractive::doInteractionProcessing() {
// HandshakeExtensionMessage::doReceivedAction(). // HandshakeExtensionMessage::doReceivedAction().
pieceStorage_ = pieceStorage_ =
downloadContext_->getOwnerRequestGroup()->getPieceStorage(); downloadContext_->getOwnerRequestGroup()->getPieceStorage();
if(peer_->getExtensionMessageID("ut_metadata") && if(peer_->getExtensionMessageID(ExtensionMessageRegistry::UT_METADATA) &&
downloadContext_->getTotalLength() > 0) { downloadContext_->getTotalLength() > 0) {
size_t num = utMetadataRequestTracker_->avail(); size_t num = utMetadataRequestTracker_->avail();
if(num > 0) { if(num > 0) {
@ -549,7 +551,8 @@ void DefaultBtInteractive::doInteractionProcessing() {
addRequests(); addRequests();
} }
} }
if(peer_->getExtensionMessageID("ut_pex") && utPexEnabled_) { if(peer_->getExtensionMessageID(ExtensionMessageRegistry::UT_PEX) &&
utPexEnabled_) {
addPeerExchangeMessage(); addPeerExchangeMessage();
} }

View file

@ -33,6 +33,9 @@
*/ */
/* copyright --> */ /* copyright --> */
#include "DefaultExtensionMessageFactory.h" #include "DefaultExtensionMessageFactory.h"
#include <cstring>
#include "Peer.h" #include "Peer.h"
#include "DlAbortEx.h" #include "DlAbortEx.h"
#include "HandshakeExtensionMessage.h" #include "HandshakeExtensionMessage.h"
@ -81,19 +84,19 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
m->setDownloadContext(dctx_); m->setDownloadContext(dctx_);
return m; return m;
} else { } else {
std::string extensionName = registry_->getExtensionName(extensionMessageID); const char* extensionName = registry_->getExtensionName(extensionMessageID);
if(extensionName.empty()) { if(!extensionName) {
throw DL_ABORT_EX throw DL_ABORT_EX
(fmt("No extension registered for extended message ID %u", (fmt("No extension registered for extended message ID %u",
extensionMessageID)); extensionMessageID));
} }
if(extensionName == "ut_pex") { if(strcmp(extensionName, "ut_pex") == 0) {
// uTorrent compatible Peer-Exchange // uTorrent compatible Peer-Exchange
UTPexExtensionMessageHandle m = UTPexExtensionMessageHandle m =
UTPexExtensionMessage::create(data, length); UTPexExtensionMessage::create(data, length);
m->setPeerStorage(peerStorage_); m->setPeerStorage(peerStorage_);
return m; return m;
} else if(extensionName == "ut_metadata") { } else if(strcmp(extensionName, "ut_metadata") == 0) {
if(length == 0) { if(length == 0) {
throw DL_ABORT_EX throw DL_ABORT_EX
(fmt(MSG_TOO_SMALL_PAYLOAD_SIZE, (fmt(MSG_TOO_SMALL_PAYLOAD_SIZE,
@ -160,7 +163,7 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
throw DL_ABORT_EX throw DL_ABORT_EX
(fmt("Unsupported extension message received." (fmt("Unsupported extension message received."
" extensionMessageID=%u, extensionName=%s", " extensionMessageID=%u, extensionName=%s",
extensionMessageID, extensionName.c_str())); extensionMessageID, extensionName));
} }
} }
} }

View file

@ -2,7 +2,7 @@
/* /*
* aria2 - The high speed download utility * aria2 - The high speed download utility
* *
* Copyright (C) 2010 Tatsuhiro Tsujikawa * Copyright (C) 2012 Tatsuhiro Tsujikawa
* *
* This program is free software; you can redistribute it and/or modify * This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by * it under the terms of the GNU General Public License as published by
@ -33,46 +33,81 @@
*/ */
/* copyright --> */ /* copyright --> */
#include "ExtensionMessageRegistry.h" #include "ExtensionMessageRegistry.h"
#include "BtConstants.h"
#include "A2STR.h" #include <cstring>
#include <cassert>
namespace aria2 { namespace aria2 {
ExtensionMessageRegistry::ExtensionMessageRegistry() ExtensionMessageRegistry::ExtensionMessageRegistry()
{ : extensions_(MAX_EXTENSION)
extensions_["ut_pex"] = 8; {}
// http://www.bittorrent.org/beps/bep_0009.html
extensions_["ut_metadata"] = 9;
}
ExtensionMessageRegistry::~ExtensionMessageRegistry() {} ExtensionMessageRegistry::~ExtensionMessageRegistry() {}
uint8_t ExtensionMessageRegistry::getExtensionMessageID namespace {
(const std::string& name) const const char* EXTENSION_NAMES[] = {
"ut_metadata",
"ut_pex",
0
};
} // namespace
uint8_t ExtensionMessageRegistry::getExtensionMessageID(int key) const
{ {
Extensions::const_iterator itr = extensions_.find(name); assert(key < MAX_EXTENSION);
if(itr == extensions_.end()) { return extensions_[key];
return 0;
} else {
return (*itr).second;
}
} }
const std::string& ExtensionMessageRegistry::getExtensionName(uint8_t id) const const char* ExtensionMessageRegistry::getExtensionName(uint8_t id) const
{ {
for(Extensions::const_iterator itr = extensions_.begin(), int i;
eoi = extensions_.end(); itr != eoi; ++itr) { if(id == 0) {
const Extensions::value_type& p = *itr; return 0;
if(p.second == id) { }
return p.first; for(i = 0; i < MAX_EXTENSION; ++i) {
if(extensions_[i] == id) {
break;
} }
} }
return A2STR::NIL; return EXTENSION_NAMES[i];
} }
void ExtensionMessageRegistry::removeExtension(const std::string& name) void ExtensionMessageRegistry::setExtensionMessageID(int key, uint8_t id)
{ {
extensions_.erase(name); assert(key < MAX_EXTENSION);
extensions_[key] = id;
}
void ExtensionMessageRegistry::removeExtension(int key)
{
assert(key < MAX_EXTENSION);
extensions_[key] = 0;
}
void ExtensionMessageRegistry::setExtensions(const Extensions& extensions)
{
extensions_ = extensions;
}
const char* strBtExtension(int key)
{
if(key >= ExtensionMessageRegistry::MAX_EXTENSION) {
return 0;
} else {
return EXTENSION_NAMES[key];
}
}
int keyBtExtension(const char* name)
{
int i;
for(i = 0; i < ExtensionMessageRegistry::MAX_EXTENSION; ++i) {
if(strcmp(EXTENSION_NAMES[i], name) == 0) {
break;
}
}
return i;
} }
} // namespace aria2 } // namespace aria2

View file

@ -2,7 +2,7 @@
/* /*
* aria2 - The high speed download utility * aria2 - The high speed download utility
* *
* Copyright (C) 2009 Tatsuhiro Tsujikawa * Copyright (C) 2012 Tatsuhiro Tsujikawa
* *
* This program is free software; you can redistribute it and/or modify * This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by * it under the terms of the GNU General Public License as published by
@ -37,16 +37,27 @@
#include "common.h" #include "common.h"
#include <string> #include <vector>
#include "BtConstants.h"
namespace aria2 { namespace aria2 {
typedef std::vector<int> Extensions;
// This class stores mapping between BitTorrent entension name and its
// ID. The BitTorrent Extension Protocol is specified in BEP10. This
// class is defined to only stores extensions aria2 supports. See
// InterestingExtension for supported extensions.
//
// See also http://bittorrent.org/beps/bep_0010.html
class ExtensionMessageRegistry { class ExtensionMessageRegistry {
private:
Extensions extensions_;
public: public:
enum InterestingExtension {
UT_METADATA,
UT_PEX,
// The number of extensions.
MAX_EXTENSION
};
ExtensionMessageRegistry(); ExtensionMessageRegistry();
~ExtensionMessageRegistry(); ~ExtensionMessageRegistry();
@ -56,13 +67,38 @@ public:
return extensions_; return extensions_;
} }
uint8_t getExtensionMessageID(const std::string& name) const; void setExtensions(const Extensions& extensions);
const std::string& getExtensionName(uint8_t id) const; // Returns message ID corresponding the given |key|. The |key| must
// be one of InterestingExtension other than MAX_EXTENSION. If
// message ID is not defined, returns 0.
uint8_t getExtensionMessageID(int key) const;
void removeExtension(const std::string& name); // Returns extension name corresponding to the given |id|. If no
// extension is defined for the given |id|, returns NULL.
const char* getExtensionName(uint8_t id) const;
// Sets association of the |key| and |id|. The |key| must be one of
// InterestingExtension other than MAX_EXTENSION.
void setExtensionMessageID(int key, uint8_t id);
// Removes association of the |key|. The |key| must be one of
// InterestingExtension other than MAX_EXTENSION. After this call,
// getExtensionMessageID(key) returns 0.
void removeExtension(int key);
private:
Extensions extensions_;
}; };
// Returns the extension name corresponding to the given |key|. The
// |key| must be one of InterestingExtension other than MAX_EXTENSION.
const char* strBtExtension(int key);
// Returns extension key corresponding to the given extension |name|.
// If no such key exists, returns
// ExtensionMessageRegistry::MAX_EXTENSION.
int keyBtExtension(const char* name);
} // namespace aria2 } // namespace aria2
#endif // D_EXTENSION_MESSAGE_REGISTRY_H #endif // D_EXTENSION_MESSAGE_REGISTRY_H

View file

@ -68,10 +68,11 @@ std::string HandshakeExtensionMessage::getPayload()
dict.put("p", Integer::g(tcpPort_)); dict.put("p", Integer::g(tcpPort_));
} }
SharedHandle<Dict> extDict = Dict::g(); SharedHandle<Dict> extDict = Dict::g();
for(std::map<std::string, uint8_t>::const_iterator itr = extensions_.begin(), for(int i = 0; i < ExtensionMessageRegistry::MAX_EXTENSION; ++i) {
eoi = extensions_.end(); itr != eoi; ++itr) { int id = extreg_.getExtensionMessageID(i);
const std::map<std::string, uint8_t>::value_type& vt = *itr; if(id) {
extDict->put(vt.first, Integer::g(vt.second)); extDict->put(strBtExtension(i), Integer::g(id));
}
} }
dict.put("m", extDict); dict.put("m", extDict);
if(metadataSize_) { if(metadataSize_) {
@ -87,10 +88,11 @@ std::string HandshakeExtensionMessage::toString() const
util::percentEncode(clientVersion_).c_str(), util::percentEncode(clientVersion_).c_str(),
tcpPort_, tcpPort_,
static_cast<unsigned long>(metadataSize_))); static_cast<unsigned long>(metadataSize_)));
for(std::map<std::string, uint8_t>::const_iterator itr = extensions_.begin(), for(int i = 0; i < ExtensionMessageRegistry::MAX_EXTENSION; ++i) {
eoi = extensions_.end(); itr != eoi; ++itr) { int id = extreg_.getExtensionMessageID(i);
const std::map<std::string, uint8_t>::value_type& vt = *itr; if(id) {
s += fmt(", %s=%u", vt.first.c_str(), vt.second); s += fmt(", %s=%u", strBtExtension(i), id);
}
} }
return s; return s;
} }
@ -101,14 +103,15 @@ void HandshakeExtensionMessage::doReceivedAction()
peer_->setPort(tcpPort_); peer_->setPort(tcpPort_);
peer_->setIncomingPeer(false); peer_->setIncomingPeer(false);
} }
for(std::map<std::string, uint8_t>::const_iterator itr = extensions_.begin(), for(int i = 0; i < ExtensionMessageRegistry::MAX_EXTENSION; ++i) {
eoi = extensions_.end(); itr != eoi; ++itr) { int id = extreg_.getExtensionMessageID(i);
const std::map<std::string, uint8_t>::value_type& vt = *itr; if(id) {
peer_->setExtension(vt.first, vt.second); peer_->setExtension(i, id);
}
} }
SharedHandle<TorrentAttribute> attrs = bittorrent::getTorrentAttrs(dctx_); SharedHandle<TorrentAttribute> attrs = bittorrent::getTorrentAttrs(dctx_);
if(attrs->metadata.empty()) { if(attrs->metadata.empty()) {
if(!peer_->getExtensionMessageID("ut_metadata")) { if(!peer_->getExtensionMessageID(ExtensionMessageRegistry::UT_METADATA)) {
// TODO In metadataGetMode, if peer don't support metadata // TODO In metadataGetMode, if peer don't support metadata
// transfer, should we drop connection? There is a possibility // transfer, should we drop connection? There is a possibility
// that peer can still tell us peers using PEX. // that peer can still tell us peers using PEX.
@ -146,14 +149,19 @@ void HandshakeExtensionMessage::setPeer(const SharedHandle<Peer>& peer)
peer_ = peer; peer_ = peer;
} }
uint8_t HandshakeExtensionMessage::getExtensionMessageID(const std::string& name) const void HandshakeExtensionMessage::setExtension(int key, uint8_t id)
{ {
std::map<std::string, uint8_t>::const_iterator i = extensions_.find(name); extreg_.setExtensionMessageID(key, id);
if(i == extensions_.end()) { }
return 0;
} else { void HandshakeExtensionMessage::setExtensions(const Extensions& extensions)
return (*i).second; {
} extreg_.setExtensions(extensions);
}
uint8_t HandshakeExtensionMessage::getExtensionMessageID(int key) const
{
return extreg_.getExtensionMessageID(key);
} }
HandshakeExtensionMessageHandle HandshakeExtensionMessageHandle
@ -187,7 +195,13 @@ HandshakeExtensionMessage::create(const unsigned char* data, size_t length)
eoi = extDict->end(); i != eoi; ++i) { eoi = extDict->end(); i != eoi; ++i) {
const Integer* extId = downcast<Integer>((*i).second); const Integer* extId = downcast<Integer>((*i).second);
if(extId) { if(extId) {
msg->extensions_[(*i).first] = extId->i(); int key = keyBtExtension((*i).first.c_str());
if(key == ExtensionMessageRegistry::MAX_EXTENSION) {
A2_LOG_DEBUG(fmt("Unsupported BitTorrent extension %s=%" PRId64,
(*i).first.c_str(), extId->i()));
} else {
msg->setExtension(key, extId->i());
}
} }
} }
} }

View file

@ -37,9 +37,8 @@
#include "ExtensionMessage.h" #include "ExtensionMessage.h"
#include <map>
#include "BtConstants.h" #include "BtConstants.h"
#include "ExtensionMessageRegistry.h"
namespace aria2 { namespace aria2 {
@ -54,7 +53,7 @@ private:
size_t metadataSize_; size_t metadataSize_;
std::map<std::string, uint8_t> extensions_; ExtensionMessageRegistry extreg_;
SharedHandle<DownloadContext> dctx_; SharedHandle<DownloadContext> dctx_;
@ -117,17 +116,11 @@ public:
dctx_ = dctx; dctx_ = dctx;
} }
void setExtension(const std::string& name, uint8_t id) void setExtension(int key, uint8_t id);
{
extensions_[name] = id;
}
void setExtensions(const Extensions& extensions) void setExtensions(const Extensions& extensions);
{
extensions_ = extensions;
}
uint8_t getExtensionMessageID(const std::string& name) const; uint8_t getExtensionMessageID(int key) const;
void setPeer(const SharedHandle<Peer>& peer); void setPeer(const SharedHandle<Peer>& peer);

View file

@ -334,22 +334,22 @@ bool Peer::isGood() const
difference(global::wallclock()) >= BAD_CONDITION_INTERVAL; difference(global::wallclock()) >= BAD_CONDITION_INTERVAL;
} }
uint8_t Peer::getExtensionMessageID(const std::string& name) const uint8_t Peer::getExtensionMessageID(int key) const
{ {
assert(res_); assert(res_);
return res_->getExtensionMessageID(name); return res_->getExtensionMessageID(key);
} }
std::string Peer::getExtensionName(uint8_t id) const const char* Peer::getExtensionName(uint8_t id) const
{ {
assert(res_); assert(res_);
return res_->getExtensionName(id); return res_->getExtensionName(id);
} }
void Peer::setExtension(const std::string& name, uint8_t id) void Peer::setExtension(int key, uint8_t id)
{ {
assert(res_); assert(res_);
res_->addExtension(name, id); res_->addExtension(key, id);
} }
void Peer::setExtendedMessagingEnabled(bool enabled) void Peer::setExtendedMessagingEnabled(bool enabled)

View file

@ -283,11 +283,11 @@ public:
bool hasPiece(size_t index) const; bool hasPiece(size_t index) const;
uint8_t getExtensionMessageID(const std::string& name) const; uint8_t getExtensionMessageID(int key) const;
std::string getExtensionName(uint8_t id) const; const char* getExtensionName(uint8_t id) const;
void setExtension(const std::string& name, uint8_t id); void setExtension(int key, uint8_t id);
const Timer& getLastDownloadUpdate() const; const Timer& getLastDownloadUpdate() const;

View file

@ -120,6 +120,10 @@ PeerInteractionCommand::PeerInteractionCommand
SharedHandle<ExtensionMessageRegistry> exMsgRegistry SharedHandle<ExtensionMessageRegistry> exMsgRegistry
(new ExtensionMessageRegistry()); (new ExtensionMessageRegistry());
exMsgRegistry->setExtensionMessageID(ExtensionMessageRegistry::UT_PEX, 8);
// http://www.bittorrent.org/beps/bep_0009.html
exMsgRegistry->setExtensionMessageID(ExtensionMessageRegistry::UT_METADATA,
9);
SharedHandle<UTMetadataRequestFactory> utMetadataRequestFactory; SharedHandle<UTMetadataRequestFactory> utMetadataRequestFactory;
SharedHandle<UTMetadataRequestTracker> utMetadataRequestTracker; SharedHandle<UTMetadataRequestTracker> utMetadataRequestTracker;

View file

@ -192,32 +192,19 @@ void PeerSessionResource::extendedMessagingEnabled(bool b)
extendedMessagingEnabled_ = b; extendedMessagingEnabled_ = b;
} }
uint8_t uint8_t PeerSessionResource::getExtensionMessageID(int key) const
PeerSessionResource::getExtensionMessageID(const std::string& name) const
{ {
Extensions::const_iterator itr = extensions_.find(name); return extreg_.getExtensionMessageID(key);
if(itr == extensions_.end()) {
return 0;
} else {
return (*itr).second;
}
} }
std::string PeerSessionResource::getExtensionName(uint8_t id) const const char* PeerSessionResource::getExtensionName(uint8_t id) const
{ {
for(Extensions::const_iterator itr = extensions_.begin(), return extreg_.getExtensionName(id);
eoi = extensions_.end(); itr != eoi; ++itr) {
const Extensions::value_type& p = *itr;
if(p.second == id) {
return p.first;
}
}
return A2STR::NIL;
} }
void PeerSessionResource::addExtension(const std::string& name, uint8_t id) void PeerSessionResource::addExtension(int key, uint8_t id)
{ {
extensions_[name] = id; extreg_.setExtensionMessageID(key, id);
} }
void PeerSessionResource::dhtEnabled(bool b) void PeerSessionResource::dhtEnabled(bool b)

View file

@ -43,6 +43,7 @@
#include "BtConstants.h" #include "BtConstants.h"
#include "PeerStat.h" #include "PeerStat.h"
#include "TimerA2.h" #include "TimerA2.h"
#include "ExtensionMessageRegistry.h"
namespace aria2 { namespace aria2 {
@ -73,7 +74,7 @@ private:
// fast index set which localhost has sent to a peer. // fast index set which localhost has sent to a peer.
std::set<size_t> amAllowedIndexSet_; std::set<size_t> amAllowedIndexSet_;
bool extendedMessagingEnabled_; bool extendedMessagingEnabled_;
Extensions extensions_; ExtensionMessageRegistry extreg_;
bool dhtEnabled_; bool dhtEnabled_;
PeerStat peerStat_; PeerStat peerStat_;
@ -192,11 +193,11 @@ public:
void extendedMessagingEnabled(bool b); void extendedMessagingEnabled(bool b);
uint8_t getExtensionMessageID(const std::string& name) const; uint8_t getExtensionMessageID(int key) const;
std::string getExtensionName(uint8_t id) const; const char* getExtensionName(uint8_t id) const;
void addExtension(const std::string& name, uint8_t id); void addExtension(int key, uint8_t id);
bool dhtEnabled() const bool dhtEnabled() const
{ {

View file

@ -48,6 +48,7 @@
#include "DownloadContext.h" #include "DownloadContext.h"
#include "BtMessage.h" #include "BtMessage.h"
#include "PieceStorage.h" #include "PieceStorage.h"
#include "ExtensionMessageRegistry.h"
namespace aria2 { namespace aria2 {
@ -76,7 +77,8 @@ std::string UTMetadataRequestExtensionMessage::toString() const
void UTMetadataRequestExtensionMessage::doReceivedAction() void UTMetadataRequestExtensionMessage::doReceivedAction()
{ {
SharedHandle<TorrentAttribute> attrs = bittorrent::getTorrentAttrs(dctx_); SharedHandle<TorrentAttribute> attrs = bittorrent::getTorrentAttrs(dctx_);
uint8_t id = peer_->getExtensionMessageID("ut_metadata"); uint8_t id = peer_->getExtensionMessageID
(ExtensionMessageRegistry::UT_METADATA);
if(attrs->metadata.empty()) { if(attrs->metadata.empty()) {
SharedHandle<UTMetadataRejectExtensionMessage> m SharedHandle<UTMetadataRejectExtensionMessage> m
(new UTMetadataRejectExtensionMessage(id)); (new UTMetadataRejectExtensionMessage(id));

View file

@ -44,6 +44,7 @@
#include "Logger.h" #include "Logger.h"
#include "LogFactory.h" #include "LogFactory.h"
#include "fmt.h" #include "fmt.h"
#include "ExtensionMessageRegistry.h"
namespace aria2 { namespace aria2 {
@ -71,7 +72,7 @@ void UTMetadataRequestFactory::create
static_cast<unsigned long>(p->getIndex()))); static_cast<unsigned long>(p->getIndex())));
SharedHandle<UTMetadataRequestExtensionMessage> m SharedHandle<UTMetadataRequestExtensionMessage> m
(new UTMetadataRequestExtensionMessage (new UTMetadataRequestExtensionMessage
(peer_->getExtensionMessageID("ut_metadata"))); (peer_->getExtensionMessageID(ExtensionMessageRegistry::UT_METADATA)));
m->setIndex(p->getIndex()); m->setIndex(p->getIndex());
m->setDownloadContext(dctx_); m->setDownloadContext(dctx_);
m->setBtMessageDispatcher(dispatcher_); m->setBtMessageDispatcher(dispatcher_);

View file

@ -53,7 +53,7 @@ public:
peer_.reset(new Peer("192.168.0.1", 6969)); peer_.reset(new Peer("192.168.0.1", 6969));
peer_->allocateSessionResource(1024, 1024*1024); peer_->allocateSessionResource(1024, 1024*1024);
peer_->setExtension("ut_pex", 1); peer_->setExtension(ExtensionMessageRegistry::UT_PEX, 1);
registry_.reset(new ExtensionMessageRegistry()); registry_.reset(new ExtensionMessageRegistry());
@ -76,9 +76,9 @@ public:
factory_->setDownloadContext(dctx_); factory_->setDownloadContext(dctx_);
} }
std::string getExtensionMessageID(const std::string& name) std::string getExtensionMessageID(int key)
{ {
unsigned char id[1] = { registry_->getExtensionMessageID(name) }; unsigned char id[1] = { registry_->getExtensionMessageID(key) };
return std::string(&id[0], &id[1]); return std::string(&id[0], &id[1]);
} }
@ -103,7 +103,7 @@ CPPUNIT_TEST_SUITE_REGISTRATION(DefaultExtensionMessageFactoryTest);
void DefaultExtensionMessageFactoryTest::testCreateMessage_unknown() void DefaultExtensionMessageFactoryTest::testCreateMessage_unknown()
{ {
peer_->setExtension("foo", 255); peer_->setExtension(ExtensionMessageRegistry::UT_PEX, 255);
unsigned char id[1] = { 255 }; unsigned char id[1] = { 255 };
@ -139,7 +139,10 @@ void DefaultExtensionMessageFactoryTest::testCreateMessage_UTPex()
bittorrent::packcompact(c3, "192.168.0.2", 6882); bittorrent::packcompact(c3, "192.168.0.2", 6882);
bittorrent::packcompact(c4, "10.1.1.3",10000); bittorrent::packcompact(c4, "10.1.1.3",10000);
std::string data = getExtensionMessageID("ut_pex")+"d5:added12:"+ registry_->setExtensionMessageID(ExtensionMessageRegistry::UT_PEX, 1);
std::string data = getExtensionMessageID(ExtensionMessageRegistry::UT_PEX)
+"d5:added12:"+
std::string(&c1[0], &c1[6])+std::string(&c2[0], &c2[6])+ std::string(&c1[0], &c1[6])+std::string(&c2[0], &c2[6])+
"7:added.f2:207:dropped12:"+ "7:added.f2:207:dropped12:"+
std::string(&c3[0], &c3[6])+std::string(&c4[0], &c4[6])+ std::string(&c3[0], &c3[6])+std::string(&c4[0], &c4[6])+
@ -147,13 +150,17 @@ void DefaultExtensionMessageFactoryTest::testCreateMessage_UTPex()
SharedHandle<UTPexExtensionMessage> m = SharedHandle<UTPexExtensionMessage> m =
createMessage<UTPexExtensionMessage>(data); createMessage<UTPexExtensionMessage>(data);
CPPUNIT_ASSERT_EQUAL(registry_->getExtensionMessageID("ut_pex"), CPPUNIT_ASSERT_EQUAL(registry_->getExtensionMessageID
(ExtensionMessageRegistry::UT_PEX),
m->getExtensionMessageID()); m->getExtensionMessageID());
} }
void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataRequest() void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataRequest()
{ {
std::string data = getExtensionMessageID("ut_metadata")+ registry_->setExtensionMessageID(ExtensionMessageRegistry::UT_METADATA, 1);
std::string data = getExtensionMessageID
(ExtensionMessageRegistry::UT_METADATA)+
"d8:msg_typei0e5:piecei1ee"; "d8:msg_typei0e5:piecei1ee";
SharedHandle<UTMetadataRequestExtensionMessage> m = SharedHandle<UTMetadataRequestExtensionMessage> m =
createMessage<UTMetadataRequestExtensionMessage>(data); createMessage<UTMetadataRequestExtensionMessage>(data);
@ -162,7 +169,10 @@ void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataRequest()
void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataData() void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataData()
{ {
std::string data = getExtensionMessageID("ut_metadata")+ registry_->setExtensionMessageID(ExtensionMessageRegistry::UT_METADATA, 1);
std::string data = getExtensionMessageID
(ExtensionMessageRegistry::UT_METADATA)+
"d8:msg_typei1e5:piecei1e10:total_sizei300ee0000000000"; "d8:msg_typei1e5:piecei1e10:total_sizei300ee0000000000";
SharedHandle<UTMetadataDataExtensionMessage> m = SharedHandle<UTMetadataDataExtensionMessage> m =
createMessage<UTMetadataDataExtensionMessage>(data); createMessage<UTMetadataDataExtensionMessage>(data);
@ -173,7 +183,10 @@ void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataData()
void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataReject() void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataReject()
{ {
std::string data = getExtensionMessageID("ut_metadata")+ registry_->setExtensionMessageID(ExtensionMessageRegistry::UT_METADATA, 1);
std::string data = getExtensionMessageID
(ExtensionMessageRegistry::UT_METADATA)+
"d8:msg_typei2e5:piecei1ee"; "d8:msg_typei2e5:piecei1ee";
SharedHandle<UTMetadataRejectExtensionMessage> m = SharedHandle<UTMetadataRejectExtensionMessage> m =
createMessage<UTMetadataRejectExtensionMessage>(data); createMessage<UTMetadataRejectExtensionMessage>(data);

View file

@ -0,0 +1,75 @@
#include "ExtensionMessageRegistry.h"
#include <cppunit/extensions/HelperMacros.h>
namespace aria2 {
class ExtensionMessageRegistryTest:public CppUnit::TestFixture {
CPPUNIT_TEST_SUITE(ExtensionMessageRegistryTest);
CPPUNIT_TEST(testStrBtExtension);
CPPUNIT_TEST(testKeyBtExtension);
CPPUNIT_TEST(testGetExtensionMessageID);
CPPUNIT_TEST_SUITE_END();
public:
void testStrBtExtension();
void testKeyBtExtension();
void testGetExtensionMessageID();
};
CPPUNIT_TEST_SUITE_REGISTRATION( ExtensionMessageRegistryTest );
void ExtensionMessageRegistryTest::testStrBtExtension()
{
CPPUNIT_ASSERT_EQUAL(std::string("ut_pex"),
std::string(strBtExtension
(ExtensionMessageRegistry::UT_PEX)));
CPPUNIT_ASSERT_EQUAL(std::string("ut_metadata"),
std::string(strBtExtension
(ExtensionMessageRegistry::UT_METADATA)));
CPPUNIT_ASSERT(!strBtExtension(100));
}
void ExtensionMessageRegistryTest::testKeyBtExtension()
{
CPPUNIT_ASSERT_EQUAL((int)ExtensionMessageRegistry::UT_PEX,
keyBtExtension("ut_pex"));
CPPUNIT_ASSERT_EQUAL((int)ExtensionMessageRegistry::UT_METADATA,
keyBtExtension("ut_metadata"));
CPPUNIT_ASSERT_EQUAL((int)ExtensionMessageRegistry::MAX_EXTENSION,
keyBtExtension("unknown"));
}
void ExtensionMessageRegistryTest::testGetExtensionMessageID()
{
ExtensionMessageRegistry reg;
CPPUNIT_ASSERT_EQUAL((uint8_t)0, reg.getExtensionMessageID
(ExtensionMessageRegistry::UT_PEX));
CPPUNIT_ASSERT(!reg.getExtensionName(0));
CPPUNIT_ASSERT(!reg.getExtensionName(1));
CPPUNIT_ASSERT(!reg.getExtensionName(100));
reg.setExtensionMessageID(ExtensionMessageRegistry::UT_PEX, 1);
CPPUNIT_ASSERT_EQUAL(std::string("ut_pex"),
std::string(reg.getExtensionName(1)));
CPPUNIT_ASSERT_EQUAL((uint8_t)1, reg.getExtensionMessageID
(ExtensionMessageRegistry::UT_PEX));
reg.setExtensionMessageID(ExtensionMessageRegistry::UT_METADATA, 127);
CPPUNIT_ASSERT_EQUAL((uint8_t)127, reg.getExtensionMessageID
(ExtensionMessageRegistry::UT_METADATA));
CPPUNIT_ASSERT_EQUAL((uint8_t)1, reg.getExtensionMessageID
(ExtensionMessageRegistry::UT_PEX));
reg.removeExtension(ExtensionMessageRegistry::UT_PEX);
CPPUNIT_ASSERT_EQUAL((uint8_t)127, reg.getExtensionMessageID
(ExtensionMessageRegistry::UT_METADATA));
CPPUNIT_ASSERT_EQUAL((uint8_t)0, reg.getExtensionMessageID
(ExtensionMessageRegistry::UT_PEX));
CPPUNIT_ASSERT(!reg.getExtensionName(1));
}
} // namespace aria2

View file

@ -60,12 +60,12 @@ void HandshakeExtensionMessageTest::testGetBencodedData()
HandshakeExtensionMessage msg; HandshakeExtensionMessage msg;
msg.setClientVersion("aria2"); msg.setClientVersion("aria2");
msg.setTCPPort(6889); msg.setTCPPort(6889);
msg.setExtension("ut_pex", 1); msg.setExtension(ExtensionMessageRegistry::UT_PEX, 1);
msg.setExtension("a2_dht", 2); msg.setExtension(ExtensionMessageRegistry::UT_METADATA, 2);
msg.setMetadataSize(1024); msg.setMetadataSize(1024);
CPPUNIT_ASSERT_EQUAL CPPUNIT_ASSERT_EQUAL
(std::string("d" (std::string("d"
"1:md6:a2_dhti2e6:ut_pexi1ee" "1:md11:ut_metadatai2e6:ut_pexi1ee"
"13:metadata_sizei1024e" "13:metadata_sizei1024e"
"1:pi6889e" "1:pi6889e"
"1:v5:aria2" "1:v5:aria2"
@ -81,12 +81,12 @@ void HandshakeExtensionMessageTest::testToString()
HandshakeExtensionMessage msg; HandshakeExtensionMessage msg;
msg.setClientVersion("aria2"); msg.setClientVersion("aria2");
msg.setTCPPort(6889); msg.setTCPPort(6889);
msg.setExtension("ut_pex", 1); msg.setExtension(ExtensionMessageRegistry::UT_PEX, 1);
msg.setExtension("a2_dht", 2); msg.setExtension(ExtensionMessageRegistry::UT_METADATA, 2);
msg.setMetadataSize(1024); msg.setMetadataSize(1024);
CPPUNIT_ASSERT_EQUAL CPPUNIT_ASSERT_EQUAL
(std::string("handshake client=aria2, tcpPort=6889, metadataSize=1024," (std::string("handshake client=aria2, tcpPort=6889, metadataSize=1024,"
" a2_dht=2, ut_pex=1"), msg.toString()); " ut_metadata=2, ut_pex=1"), msg.toString());
} }
void HandshakeExtensionMessageTest::testDoReceivedAction() void HandshakeExtensionMessageTest::testDoReceivedAction()
@ -106,9 +106,8 @@ void HandshakeExtensionMessageTest::testDoReceivedAction()
HandshakeExtensionMessage msg; HandshakeExtensionMessage msg;
msg.setClientVersion("aria2"); msg.setClientVersion("aria2");
msg.setTCPPort(6889); msg.setTCPPort(6889);
msg.setExtension("ut_pex", 1); msg.setExtension(ExtensionMessageRegistry::UT_PEX, 1);
msg.setExtension("a2_dht", 2); msg.setExtension(ExtensionMessageRegistry::UT_METADATA, 3);
msg.setExtension("ut_metadata", 3);
msg.setMetadataSize(1024); msg.setMetadataSize(1024);
msg.setPeer(peer); msg.setPeer(peer);
msg.setDownloadContext(dctx); msg.setDownloadContext(dctx);
@ -116,8 +115,12 @@ void HandshakeExtensionMessageTest::testDoReceivedAction()
msg.doReceivedAction(); msg.doReceivedAction();
CPPUNIT_ASSERT_EQUAL((uint16_t)6889, peer->getPort()); CPPUNIT_ASSERT_EQUAL((uint16_t)6889, peer->getPort());
CPPUNIT_ASSERT_EQUAL((uint8_t)1, peer->getExtensionMessageID("ut_pex")); CPPUNIT_ASSERT_EQUAL((uint8_t)1,
CPPUNIT_ASSERT_EQUAL((uint8_t)2, peer->getExtensionMessageID("a2_dht")); peer->getExtensionMessageID
(ExtensionMessageRegistry::UT_PEX));
CPPUNIT_ASSERT_EQUAL((uint8_t)3,
peer->getExtensionMessageID
(ExtensionMessageRegistry::UT_METADATA));
CPPUNIT_ASSERT(peer->isSeeder()); CPPUNIT_ASSERT(peer->isSeeder());
CPPUNIT_ASSERT_EQUAL((size_t)1024, attrs->metadataSize); CPPUNIT_ASSERT_EQUAL((size_t)1024, attrs->metadataSize);
CPPUNIT_ASSERT_EQUAL((int64_t)1024, dctx->getTotalLength()); CPPUNIT_ASSERT_EQUAL((int64_t)1024, dctx->getTotalLength());
@ -134,13 +137,15 @@ void HandshakeExtensionMessageTest::testDoReceivedAction()
void HandshakeExtensionMessageTest::testCreate() void HandshakeExtensionMessageTest::testCreate()
{ {
std::string in = std::string in =
"0d1:pi6881e1:v5:aria21:md6:ut_pexi1ee13:metadata_sizei1024ee"; "0d1:pi6881e1:v5:aria21:md5:a2dhti2e6:ut_pexi1ee13:metadata_sizei1024ee";
SharedHandle<HandshakeExtensionMessage> m = SharedHandle<HandshakeExtensionMessage> m =
HandshakeExtensionMessage::create(reinterpret_cast<const unsigned char*>(in.c_str()), HandshakeExtensionMessage::create(reinterpret_cast<const unsigned char*>(in.c_str()),
in.size()); in.size());
CPPUNIT_ASSERT_EQUAL(std::string("aria2"), m->getClientVersion()); CPPUNIT_ASSERT_EQUAL(std::string("aria2"), m->getClientVersion());
CPPUNIT_ASSERT_EQUAL((uint16_t)6881, m->getTCPPort()); CPPUNIT_ASSERT_EQUAL((uint16_t)6881, m->getTCPPort());
CPPUNIT_ASSERT_EQUAL((uint8_t)1, m->getExtensionMessageID("ut_pex")); CPPUNIT_ASSERT_EQUAL((uint8_t)1,
m->getExtensionMessageID
(ExtensionMessageRegistry::UT_PEX));
CPPUNIT_ASSERT_EQUAL((size_t)1024, m->getMetadataSize()); CPPUNIT_ASSERT_EQUAL((size_t)1024, m->getMetadataSize());
try { try {
// bad payload format // bad payload format
@ -182,7 +187,9 @@ void HandshakeExtensionMessageTest::testCreate_stringnum()
// port number in string is not allowed // port number in string is not allowed
CPPUNIT_ASSERT_EQUAL((uint16_t)0, m->getTCPPort()); CPPUNIT_ASSERT_EQUAL((uint16_t)0, m->getTCPPort());
// extension ID in string is not allowed // extension ID in string is not allowed
CPPUNIT_ASSERT_EQUAL((uint8_t)0, m->getExtensionMessageID("ut_pex")); CPPUNIT_ASSERT_EQUAL((uint8_t)0,
m->getExtensionMessageID
(ExtensionMessageRegistry::UT_PEX));
} }
} // namespace aria2 } // namespace aria2

View file

@ -208,7 +208,8 @@ aria2c_SOURCES += BtAllowedFastMessageTest.cc\
LpdMessageReceiverTest.cc\ LpdMessageReceiverTest.cc\
Bencode2Test.cc\ Bencode2Test.cc\
PeerConnectionTest.cc\ PeerConnectionTest.cc\
ValueBaseBencodeParserTest.cc ValueBaseBencodeParserTest.cc\
ExtensionMessageRegistryTest.cc
endif # ENABLE_BITTORRENT endif # ENABLE_BITTORRENT
if ENABLE_METALINK if ENABLE_METALINK

View file

@ -137,12 +137,17 @@ void PeerSessionResourceTest::testGetExtensionMessageID()
{ {
PeerSessionResource res(1024, 1024*1024); PeerSessionResource res(1024, 1024*1024);
res.addExtension("a2", 9); res.addExtension(ExtensionMessageRegistry::UT_PEX, 9);
CPPUNIT_ASSERT_EQUAL((uint8_t)9, res.getExtensionMessageID("a2")); CPPUNIT_ASSERT_EQUAL((uint8_t)9,
CPPUNIT_ASSERT_EQUAL((uint8_t)0, res.getExtensionMessageID("non")); res.getExtensionMessageID
(ExtensionMessageRegistry::UT_PEX));
CPPUNIT_ASSERT_EQUAL((uint8_t)0,
res.getExtensionMessageID
(ExtensionMessageRegistry::UT_METADATA));
CPPUNIT_ASSERT_EQUAL(std::string("a2"), res.getExtensionName(9)); CPPUNIT_ASSERT_EQUAL(std::string("ut_pex"),
CPPUNIT_ASSERT_EQUAL(std::string(""), res.getExtensionName(10)); std::string(res.getExtensionName(9)));
CPPUNIT_ASSERT(!res.getExtensionName(10));
} }
void PeerSessionResourceTest::testFastExtensionEnabled() void PeerSessionResourceTest::testFastExtensionEnabled()

View file

@ -16,6 +16,7 @@
#include "PieceStorage.h" #include "PieceStorage.h"
#include "extension_message_test_helper.h" #include "extension_message_test_helper.h"
#include "DlAbortEx.h" #include "DlAbortEx.h"
#include "ExtensionMessageRegistry.h"
namespace aria2 { namespace aria2 {
@ -44,7 +45,7 @@ public:
dctx_->setAttribute(CTX_ATTR_BT, attrs); dctx_->setAttribute(CTX_ATTR_BT, attrs);
peer_.reset(new Peer("host", 6880)); peer_.reset(new Peer("host", 6880));
peer_->allocateSessionResource(0, 0); peer_->allocateSessionResource(0, 0);
peer_->setExtension("ut_metadata", 1); peer_->setExtension(ExtensionMessageRegistry::UT_METADATA, 1);
} }
template<typename T> template<typename T>