diff --git a/src/BtConstants.h b/src/BtConstants.h index 5507f30f..609f8f7e 100644 --- a/src/BtConstants.h +++ b/src/BtConstants.h @@ -36,10 +36,7 @@ #define D_BT_CONSTANTS_H #include "common.h" -#include -#include - -typedef std::map Extensions; +#include #define INFO_HASH_LENGTH 20 diff --git a/src/DefaultBtInteractive.cc b/src/DefaultBtInteractive.cc index 7b9816bc..4e3948f9 100644 --- a/src/DefaultBtInteractive.cc +++ b/src/DefaultBtInteractive.cc @@ -140,7 +140,8 @@ BtMessageHandle DefaultBtInteractive::receiveHandshake(bool quickReply) { if(message->isExtendedMessagingEnabled()) { peer_->setExtendedMessagingEnabled(true); if(!utPexEnabled_) { - extensionMessageRegistry_->removeExtension("ut_pex"); + extensionMessageRegistry_->removeExtension + (ExtensionMessageRegistry::UT_PEX); } A2_LOG_INFO(fmt(MSG_EXTENDED_MESSAGING_ENABLED, cuid_)); } @@ -472,7 +473,8 @@ void DefaultBtInteractive::addPeerExchangeMessage() if(pexTimer_. difference(global::wallclock()) >= UTPexExtensionMessage::DEFAULT_INTERVAL) { UTPexExtensionMessageHandle m - (new UTPexExtensionMessage(peer_->getExtensionMessageID("ut_pex"))); + (new UTPexExtensionMessage(peer_->getExtensionMessageID + (ExtensionMessageRegistry::UT_PEX))); std::vector > activePeers; peerStorage_->getActivePeers(activePeers); @@ -508,7 +510,7 @@ void DefaultBtInteractive::doInteractionProcessing() { // HandshakeExtensionMessage::doReceivedAction(). pieceStorage_ = downloadContext_->getOwnerRequestGroup()->getPieceStorage(); - if(peer_->getExtensionMessageID("ut_metadata") && + if(peer_->getExtensionMessageID(ExtensionMessageRegistry::UT_METADATA) && downloadContext_->getTotalLength() > 0) { size_t num = utMetadataRequestTracker_->avail(); if(num > 0) { @@ -549,7 +551,8 @@ void DefaultBtInteractive::doInteractionProcessing() { addRequests(); } } - if(peer_->getExtensionMessageID("ut_pex") && utPexEnabled_) { + if(peer_->getExtensionMessageID(ExtensionMessageRegistry::UT_PEX) && + utPexEnabled_) { addPeerExchangeMessage(); } diff --git a/src/DefaultExtensionMessageFactory.cc b/src/DefaultExtensionMessageFactory.cc index a86765c0..c84c6789 100644 --- a/src/DefaultExtensionMessageFactory.cc +++ b/src/DefaultExtensionMessageFactory.cc @@ -33,6 +33,9 @@ */ /* copyright --> */ #include "DefaultExtensionMessageFactory.h" + +#include + #include "Peer.h" #include "DlAbortEx.h" #include "HandshakeExtensionMessage.h" @@ -81,19 +84,19 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t m->setDownloadContext(dctx_); return m; } else { - std::string extensionName = registry_->getExtensionName(extensionMessageID); - if(extensionName.empty()) { + const char* extensionName = registry_->getExtensionName(extensionMessageID); + if(!extensionName) { throw DL_ABORT_EX (fmt("No extension registered for extended message ID %u", extensionMessageID)); } - if(extensionName == "ut_pex") { + if(strcmp(extensionName, "ut_pex") == 0) { // uTorrent compatible Peer-Exchange UTPexExtensionMessageHandle m = UTPexExtensionMessage::create(data, length); m->setPeerStorage(peerStorage_); return m; - } else if(extensionName == "ut_metadata") { + } else if(strcmp(extensionName, "ut_metadata") == 0) { if(length == 0) { throw DL_ABORT_EX (fmt(MSG_TOO_SMALL_PAYLOAD_SIZE, @@ -160,7 +163,7 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t throw DL_ABORT_EX (fmt("Unsupported extension message received." " extensionMessageID=%u, extensionName=%s", - extensionMessageID, extensionName.c_str())); + extensionMessageID, extensionName)); } } } diff --git a/src/ExtensionMessageRegistry.cc b/src/ExtensionMessageRegistry.cc index b4100930..8197eb1e 100644 --- a/src/ExtensionMessageRegistry.cc +++ b/src/ExtensionMessageRegistry.cc @@ -2,7 +2,7 @@ /* * aria2 - The high speed download utility * - * Copyright (C) 2010 Tatsuhiro Tsujikawa + * Copyright (C) 2012 Tatsuhiro Tsujikawa * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -33,46 +33,81 @@ */ /* copyright --> */ #include "ExtensionMessageRegistry.h" -#include "BtConstants.h" -#include "A2STR.h" + +#include +#include namespace aria2 { ExtensionMessageRegistry::ExtensionMessageRegistry() -{ - extensions_["ut_pex"] = 8; - // http://www.bittorrent.org/beps/bep_0009.html - extensions_["ut_metadata"] = 9; -} + : extensions_(MAX_EXTENSION) +{} ExtensionMessageRegistry::~ExtensionMessageRegistry() {} -uint8_t ExtensionMessageRegistry::getExtensionMessageID -(const std::string& name) const +namespace { +const char* EXTENSION_NAMES[] = { + "ut_metadata", + "ut_pex", + 0 +}; +} // namespace + +uint8_t ExtensionMessageRegistry::getExtensionMessageID(int key) const { - Extensions::const_iterator itr = extensions_.find(name); - if(itr == extensions_.end()) { - return 0; - } else { - return (*itr).second; - } + assert(key < MAX_EXTENSION); + return extensions_[key]; } -const std::string& ExtensionMessageRegistry::getExtensionName(uint8_t id) const +const char* ExtensionMessageRegistry::getExtensionName(uint8_t id) const { - for(Extensions::const_iterator itr = extensions_.begin(), - eoi = extensions_.end(); itr != eoi; ++itr) { - const Extensions::value_type& p = *itr; - if(p.second == id) { - return p.first; + int i; + if(id == 0) { + return 0; + } + for(i = 0; i < MAX_EXTENSION; ++i) { + if(extensions_[i] == id) { + break; } } - return A2STR::NIL; + return EXTENSION_NAMES[i]; } -void ExtensionMessageRegistry::removeExtension(const std::string& name) +void ExtensionMessageRegistry::setExtensionMessageID(int key, uint8_t id) { - extensions_.erase(name); + assert(key < MAX_EXTENSION); + extensions_[key] = id; +} + +void ExtensionMessageRegistry::removeExtension(int key) +{ + assert(key < MAX_EXTENSION); + extensions_[key] = 0; +} + +void ExtensionMessageRegistry::setExtensions(const Extensions& extensions) +{ + extensions_ = extensions; +} + +const char* strBtExtension(int key) +{ + if(key >= ExtensionMessageRegistry::MAX_EXTENSION) { + return 0; + } else { + return EXTENSION_NAMES[key]; + } +} + +int keyBtExtension(const char* name) +{ + int i; + for(i = 0; i < ExtensionMessageRegistry::MAX_EXTENSION; ++i) { + if(strcmp(EXTENSION_NAMES[i], name) == 0) { + break; + } + } + return i; } } // namespace aria2 diff --git a/src/ExtensionMessageRegistry.h b/src/ExtensionMessageRegistry.h index 91b118ff..50babcdd 100644 --- a/src/ExtensionMessageRegistry.h +++ b/src/ExtensionMessageRegistry.h @@ -2,7 +2,7 @@ /* * aria2 - The high speed download utility * - * Copyright (C) 2009 Tatsuhiro Tsujikawa + * Copyright (C) 2012 Tatsuhiro Tsujikawa * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -37,16 +37,27 @@ #include "common.h" -#include - -#include "BtConstants.h" +#include namespace aria2 { +typedef std::vector Extensions; + +// This class stores mapping between BitTorrent entension name and its +// ID. The BitTorrent Extension Protocol is specified in BEP10. This +// class is defined to only stores extensions aria2 supports. See +// InterestingExtension for supported extensions. +// +// See also http://bittorrent.org/beps/bep_0010.html class ExtensionMessageRegistry { -private: - Extensions extensions_; public: + enum InterestingExtension { + UT_METADATA, + UT_PEX, + // The number of extensions. + MAX_EXTENSION + }; + ExtensionMessageRegistry(); ~ExtensionMessageRegistry(); @@ -56,13 +67,38 @@ public: return extensions_; } - uint8_t getExtensionMessageID(const std::string& name) const; + void setExtensions(const Extensions& extensions); - const std::string& getExtensionName(uint8_t id) const; + // Returns message ID corresponding the given |key|. The |key| must + // be one of InterestingExtension other than MAX_EXTENSION. If + // message ID is not defined, returns 0. + uint8_t getExtensionMessageID(int key) const; - void removeExtension(const std::string& name); + // Returns extension name corresponding to the given |id|. If no + // extension is defined for the given |id|, returns NULL. + const char* getExtensionName(uint8_t id) const; + + // Sets association of the |key| and |id|. The |key| must be one of + // InterestingExtension other than MAX_EXTENSION. + void setExtensionMessageID(int key, uint8_t id); + + // Removes association of the |key|. The |key| must be one of + // InterestingExtension other than MAX_EXTENSION. After this call, + // getExtensionMessageID(key) returns 0. + void removeExtension(int key); +private: + Extensions extensions_; }; +// Returns the extension name corresponding to the given |key|. The +// |key| must be one of InterestingExtension other than MAX_EXTENSION. +const char* strBtExtension(int key); + +// Returns extension key corresponding to the given extension |name|. +// If no such key exists, returns +// ExtensionMessageRegistry::MAX_EXTENSION. +int keyBtExtension(const char* name); + } // namespace aria2 #endif // D_EXTENSION_MESSAGE_REGISTRY_H diff --git a/src/HandshakeExtensionMessage.cc b/src/HandshakeExtensionMessage.cc index e0860d2a..63afd356 100644 --- a/src/HandshakeExtensionMessage.cc +++ b/src/HandshakeExtensionMessage.cc @@ -68,10 +68,11 @@ std::string HandshakeExtensionMessage::getPayload() dict.put("p", Integer::g(tcpPort_)); } SharedHandle extDict = Dict::g(); - for(std::map::const_iterator itr = extensions_.begin(), - eoi = extensions_.end(); itr != eoi; ++itr) { - const std::map::value_type& vt = *itr; - extDict->put(vt.first, Integer::g(vt.second)); + for(int i = 0; i < ExtensionMessageRegistry::MAX_EXTENSION; ++i) { + int id = extreg_.getExtensionMessageID(i); + if(id) { + extDict->put(strBtExtension(i), Integer::g(id)); + } } dict.put("m", extDict); if(metadataSize_) { @@ -87,10 +88,11 @@ std::string HandshakeExtensionMessage::toString() const util::percentEncode(clientVersion_).c_str(), tcpPort_, static_cast(metadataSize_))); - for(std::map::const_iterator itr = extensions_.begin(), - eoi = extensions_.end(); itr != eoi; ++itr) { - const std::map::value_type& vt = *itr; - s += fmt(", %s=%u", vt.first.c_str(), vt.second); + for(int i = 0; i < ExtensionMessageRegistry::MAX_EXTENSION; ++i) { + int id = extreg_.getExtensionMessageID(i); + if(id) { + s += fmt(", %s=%u", strBtExtension(i), id); + } } return s; } @@ -101,14 +103,15 @@ void HandshakeExtensionMessage::doReceivedAction() peer_->setPort(tcpPort_); peer_->setIncomingPeer(false); } - for(std::map::const_iterator itr = extensions_.begin(), - eoi = extensions_.end(); itr != eoi; ++itr) { - const std::map::value_type& vt = *itr; - peer_->setExtension(vt.first, vt.second); + for(int i = 0; i < ExtensionMessageRegistry::MAX_EXTENSION; ++i) { + int id = extreg_.getExtensionMessageID(i); + if(id) { + peer_->setExtension(i, id); + } } SharedHandle attrs = bittorrent::getTorrentAttrs(dctx_); if(attrs->metadata.empty()) { - if(!peer_->getExtensionMessageID("ut_metadata")) { + if(!peer_->getExtensionMessageID(ExtensionMessageRegistry::UT_METADATA)) { // TODO In metadataGetMode, if peer don't support metadata // transfer, should we drop connection? There is a possibility // that peer can still tell us peers using PEX. @@ -146,14 +149,19 @@ void HandshakeExtensionMessage::setPeer(const SharedHandle& peer) peer_ = peer; } -uint8_t HandshakeExtensionMessage::getExtensionMessageID(const std::string& name) const +void HandshakeExtensionMessage::setExtension(int key, uint8_t id) { - std::map::const_iterator i = extensions_.find(name); - if(i == extensions_.end()) { - return 0; - } else { - return (*i).second; - } + extreg_.setExtensionMessageID(key, id); +} + +void HandshakeExtensionMessage::setExtensions(const Extensions& extensions) +{ + extreg_.setExtensions(extensions); +} + +uint8_t HandshakeExtensionMessage::getExtensionMessageID(int key) const +{ + return extreg_.getExtensionMessageID(key); } HandshakeExtensionMessageHandle @@ -187,7 +195,13 @@ HandshakeExtensionMessage::create(const unsigned char* data, size_t length) eoi = extDict->end(); i != eoi; ++i) { const Integer* extId = downcast((*i).second); if(extId) { - msg->extensions_[(*i).first] = extId->i(); + int key = keyBtExtension((*i).first.c_str()); + if(key == ExtensionMessageRegistry::MAX_EXTENSION) { + A2_LOG_DEBUG(fmt("Unsupported BitTorrent extension %s=%" PRId64, + (*i).first.c_str(), extId->i())); + } else { + msg->setExtension(key, extId->i()); + } } } } diff --git a/src/HandshakeExtensionMessage.h b/src/HandshakeExtensionMessage.h index 50e353ba..d53c5c9d 100644 --- a/src/HandshakeExtensionMessage.h +++ b/src/HandshakeExtensionMessage.h @@ -37,9 +37,8 @@ #include "ExtensionMessage.h" -#include - #include "BtConstants.h" +#include "ExtensionMessageRegistry.h" namespace aria2 { @@ -54,7 +53,7 @@ private: size_t metadataSize_; - std::map extensions_; + ExtensionMessageRegistry extreg_; SharedHandle dctx_; @@ -117,17 +116,11 @@ public: dctx_ = dctx; } - void setExtension(const std::string& name, uint8_t id) - { - extensions_[name] = id; - } + void setExtension(int key, uint8_t id); - void setExtensions(const Extensions& extensions) - { - extensions_ = extensions; - } + void setExtensions(const Extensions& extensions); - uint8_t getExtensionMessageID(const std::string& name) const; + uint8_t getExtensionMessageID(int key) const; void setPeer(const SharedHandle& peer); diff --git a/src/Peer.cc b/src/Peer.cc index 4d5a7dfc..2a4948a4 100644 --- a/src/Peer.cc +++ b/src/Peer.cc @@ -334,22 +334,22 @@ bool Peer::isGood() const difference(global::wallclock()) >= BAD_CONDITION_INTERVAL; } -uint8_t Peer::getExtensionMessageID(const std::string& name) const +uint8_t Peer::getExtensionMessageID(int key) const { assert(res_); - return res_->getExtensionMessageID(name); + return res_->getExtensionMessageID(key); } -std::string Peer::getExtensionName(uint8_t id) const +const char* Peer::getExtensionName(uint8_t id) const { assert(res_); return res_->getExtensionName(id); } -void Peer::setExtension(const std::string& name, uint8_t id) +void Peer::setExtension(int key, uint8_t id) { assert(res_); - res_->addExtension(name, id); + res_->addExtension(key, id); } void Peer::setExtendedMessagingEnabled(bool enabled) diff --git a/src/Peer.h b/src/Peer.h index 6d4cf168..0937ba97 100644 --- a/src/Peer.h +++ b/src/Peer.h @@ -283,11 +283,11 @@ public: bool hasPiece(size_t index) const; - uint8_t getExtensionMessageID(const std::string& name) const; + uint8_t getExtensionMessageID(int key) const; - std::string getExtensionName(uint8_t id) const; + const char* getExtensionName(uint8_t id) const; - void setExtension(const std::string& name, uint8_t id); + void setExtension(int key, uint8_t id); const Timer& getLastDownloadUpdate() const; diff --git a/src/PeerInteractionCommand.cc b/src/PeerInteractionCommand.cc index 592c20d9..dd9dc7e1 100644 --- a/src/PeerInteractionCommand.cc +++ b/src/PeerInteractionCommand.cc @@ -120,6 +120,10 @@ PeerInteractionCommand::PeerInteractionCommand SharedHandle exMsgRegistry (new ExtensionMessageRegistry()); + exMsgRegistry->setExtensionMessageID(ExtensionMessageRegistry::UT_PEX, 8); + // http://www.bittorrent.org/beps/bep_0009.html + exMsgRegistry->setExtensionMessageID(ExtensionMessageRegistry::UT_METADATA, + 9); SharedHandle utMetadataRequestFactory; SharedHandle utMetadataRequestTracker; diff --git a/src/PeerSessionResource.cc b/src/PeerSessionResource.cc index 5d72f074..26dc3c24 100644 --- a/src/PeerSessionResource.cc +++ b/src/PeerSessionResource.cc @@ -192,32 +192,19 @@ void PeerSessionResource::extendedMessagingEnabled(bool b) extendedMessagingEnabled_ = b; } -uint8_t -PeerSessionResource::getExtensionMessageID(const std::string& name) const +uint8_t PeerSessionResource::getExtensionMessageID(int key) const { - Extensions::const_iterator itr = extensions_.find(name); - if(itr == extensions_.end()) { - return 0; - } else { - return (*itr).second; - } + return extreg_.getExtensionMessageID(key); } -std::string PeerSessionResource::getExtensionName(uint8_t id) const +const char* PeerSessionResource::getExtensionName(uint8_t id) const { - for(Extensions::const_iterator itr = extensions_.begin(), - eoi = extensions_.end(); itr != eoi; ++itr) { - const Extensions::value_type& p = *itr; - if(p.second == id) { - return p.first; - } - } - return A2STR::NIL; + return extreg_.getExtensionName(id); } -void PeerSessionResource::addExtension(const std::string& name, uint8_t id) +void PeerSessionResource::addExtension(int key, uint8_t id) { - extensions_[name] = id; + extreg_.setExtensionMessageID(key, id); } void PeerSessionResource::dhtEnabled(bool b) diff --git a/src/PeerSessionResource.h b/src/PeerSessionResource.h index aab6ebbc..c0d19261 100644 --- a/src/PeerSessionResource.h +++ b/src/PeerSessionResource.h @@ -43,6 +43,7 @@ #include "BtConstants.h" #include "PeerStat.h" #include "TimerA2.h" +#include "ExtensionMessageRegistry.h" namespace aria2 { @@ -73,7 +74,7 @@ private: // fast index set which localhost has sent to a peer. std::set amAllowedIndexSet_; bool extendedMessagingEnabled_; - Extensions extensions_; + ExtensionMessageRegistry extreg_; bool dhtEnabled_; PeerStat peerStat_; @@ -192,11 +193,11 @@ public: void extendedMessagingEnabled(bool b); - uint8_t getExtensionMessageID(const std::string& name) const; + uint8_t getExtensionMessageID(int key) const; - std::string getExtensionName(uint8_t id) const; + const char* getExtensionName(uint8_t id) const; - void addExtension(const std::string& name, uint8_t id); + void addExtension(int key, uint8_t id); bool dhtEnabled() const { diff --git a/src/UTMetadataRequestExtensionMessage.cc b/src/UTMetadataRequestExtensionMessage.cc index 04ace5a1..a12d088a 100644 --- a/src/UTMetadataRequestExtensionMessage.cc +++ b/src/UTMetadataRequestExtensionMessage.cc @@ -48,6 +48,7 @@ #include "DownloadContext.h" #include "BtMessage.h" #include "PieceStorage.h" +#include "ExtensionMessageRegistry.h" namespace aria2 { @@ -76,7 +77,8 @@ std::string UTMetadataRequestExtensionMessage::toString() const void UTMetadataRequestExtensionMessage::doReceivedAction() { SharedHandle attrs = bittorrent::getTorrentAttrs(dctx_); - uint8_t id = peer_->getExtensionMessageID("ut_metadata"); + uint8_t id = peer_->getExtensionMessageID + (ExtensionMessageRegistry::UT_METADATA); if(attrs->metadata.empty()) { SharedHandle m (new UTMetadataRejectExtensionMessage(id)); diff --git a/src/UTMetadataRequestFactory.cc b/src/UTMetadataRequestFactory.cc index 53d34e73..d7b2c8ad 100644 --- a/src/UTMetadataRequestFactory.cc +++ b/src/UTMetadataRequestFactory.cc @@ -44,6 +44,7 @@ #include "Logger.h" #include "LogFactory.h" #include "fmt.h" +#include "ExtensionMessageRegistry.h" namespace aria2 { @@ -71,7 +72,7 @@ void UTMetadataRequestFactory::create static_cast(p->getIndex()))); SharedHandle m (new UTMetadataRequestExtensionMessage - (peer_->getExtensionMessageID("ut_metadata"))); + (peer_->getExtensionMessageID(ExtensionMessageRegistry::UT_METADATA))); m->setIndex(p->getIndex()); m->setDownloadContext(dctx_); m->setBtMessageDispatcher(dispatcher_); diff --git a/test/DefaultExtensionMessageFactoryTest.cc b/test/DefaultExtensionMessageFactoryTest.cc index 81de2036..10541878 100644 --- a/test/DefaultExtensionMessageFactoryTest.cc +++ b/test/DefaultExtensionMessageFactoryTest.cc @@ -53,7 +53,7 @@ public: peer_.reset(new Peer("192.168.0.1", 6969)); peer_->allocateSessionResource(1024, 1024*1024); - peer_->setExtension("ut_pex", 1); + peer_->setExtension(ExtensionMessageRegistry::UT_PEX, 1); registry_.reset(new ExtensionMessageRegistry()); @@ -76,9 +76,9 @@ public: factory_->setDownloadContext(dctx_); } - std::string getExtensionMessageID(const std::string& name) + std::string getExtensionMessageID(int key) { - unsigned char id[1] = { registry_->getExtensionMessageID(name) }; + unsigned char id[1] = { registry_->getExtensionMessageID(key) }; return std::string(&id[0], &id[1]); } @@ -103,7 +103,7 @@ CPPUNIT_TEST_SUITE_REGISTRATION(DefaultExtensionMessageFactoryTest); void DefaultExtensionMessageFactoryTest::testCreateMessage_unknown() { - peer_->setExtension("foo", 255); + peer_->setExtension(ExtensionMessageRegistry::UT_PEX, 255); unsigned char id[1] = { 255 }; @@ -139,7 +139,10 @@ void DefaultExtensionMessageFactoryTest::testCreateMessage_UTPex() bittorrent::packcompact(c3, "192.168.0.2", 6882); bittorrent::packcompact(c4, "10.1.1.3",10000); - std::string data = getExtensionMessageID("ut_pex")+"d5:added12:"+ + registry_->setExtensionMessageID(ExtensionMessageRegistry::UT_PEX, 1); + + std::string data = getExtensionMessageID(ExtensionMessageRegistry::UT_PEX) + +"d5:added12:"+ std::string(&c1[0], &c1[6])+std::string(&c2[0], &c2[6])+ "7:added.f2:207:dropped12:"+ std::string(&c3[0], &c3[6])+std::string(&c4[0], &c4[6])+ @@ -147,13 +150,17 @@ void DefaultExtensionMessageFactoryTest::testCreateMessage_UTPex() SharedHandle m = createMessage(data); - CPPUNIT_ASSERT_EQUAL(registry_->getExtensionMessageID("ut_pex"), + CPPUNIT_ASSERT_EQUAL(registry_->getExtensionMessageID + (ExtensionMessageRegistry::UT_PEX), m->getExtensionMessageID()); } void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataRequest() { - std::string data = getExtensionMessageID("ut_metadata")+ + registry_->setExtensionMessageID(ExtensionMessageRegistry::UT_METADATA, 1); + + std::string data = getExtensionMessageID + (ExtensionMessageRegistry::UT_METADATA)+ "d8:msg_typei0e5:piecei1ee"; SharedHandle m = createMessage(data); @@ -162,7 +169,10 @@ void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataRequest() void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataData() { - std::string data = getExtensionMessageID("ut_metadata")+ + registry_->setExtensionMessageID(ExtensionMessageRegistry::UT_METADATA, 1); + + std::string data = getExtensionMessageID + (ExtensionMessageRegistry::UT_METADATA)+ "d8:msg_typei1e5:piecei1e10:total_sizei300ee0000000000"; SharedHandle m = createMessage(data); @@ -173,7 +183,10 @@ void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataData() void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataReject() { - std::string data = getExtensionMessageID("ut_metadata")+ + registry_->setExtensionMessageID(ExtensionMessageRegistry::UT_METADATA, 1); + + std::string data = getExtensionMessageID + (ExtensionMessageRegistry::UT_METADATA)+ "d8:msg_typei2e5:piecei1ee"; SharedHandle m = createMessage(data); diff --git a/test/ExtensionMessageRegistryTest.cc b/test/ExtensionMessageRegistryTest.cc new file mode 100644 index 00000000..8791ce69 --- /dev/null +++ b/test/ExtensionMessageRegistryTest.cc @@ -0,0 +1,75 @@ +#include "ExtensionMessageRegistry.h" + +#include + +namespace aria2 { + +class ExtensionMessageRegistryTest:public CppUnit::TestFixture { + + CPPUNIT_TEST_SUITE(ExtensionMessageRegistryTest); + CPPUNIT_TEST(testStrBtExtension); + CPPUNIT_TEST(testKeyBtExtension); + CPPUNIT_TEST(testGetExtensionMessageID); + CPPUNIT_TEST_SUITE_END(); +public: + void testStrBtExtension(); + void testKeyBtExtension(); + void testGetExtensionMessageID(); +}; + +CPPUNIT_TEST_SUITE_REGISTRATION( ExtensionMessageRegistryTest ); + +void ExtensionMessageRegistryTest::testStrBtExtension() +{ + CPPUNIT_ASSERT_EQUAL(std::string("ut_pex"), + std::string(strBtExtension + (ExtensionMessageRegistry::UT_PEX))); + CPPUNIT_ASSERT_EQUAL(std::string("ut_metadata"), + std::string(strBtExtension + (ExtensionMessageRegistry::UT_METADATA))); + CPPUNIT_ASSERT(!strBtExtension(100)); +} + +void ExtensionMessageRegistryTest::testKeyBtExtension() +{ + CPPUNIT_ASSERT_EQUAL((int)ExtensionMessageRegistry::UT_PEX, + keyBtExtension("ut_pex")); + CPPUNIT_ASSERT_EQUAL((int)ExtensionMessageRegistry::UT_METADATA, + keyBtExtension("ut_metadata")); + CPPUNIT_ASSERT_EQUAL((int)ExtensionMessageRegistry::MAX_EXTENSION, + keyBtExtension("unknown")); +} + +void ExtensionMessageRegistryTest::testGetExtensionMessageID() +{ + ExtensionMessageRegistry reg; + CPPUNIT_ASSERT_EQUAL((uint8_t)0, reg.getExtensionMessageID + (ExtensionMessageRegistry::UT_PEX)); + CPPUNIT_ASSERT(!reg.getExtensionName(0)); + CPPUNIT_ASSERT(!reg.getExtensionName(1)); + CPPUNIT_ASSERT(!reg.getExtensionName(100)); + + reg.setExtensionMessageID(ExtensionMessageRegistry::UT_PEX, 1); + + CPPUNIT_ASSERT_EQUAL(std::string("ut_pex"), + std::string(reg.getExtensionName(1))); + CPPUNIT_ASSERT_EQUAL((uint8_t)1, reg.getExtensionMessageID + (ExtensionMessageRegistry::UT_PEX)); + + reg.setExtensionMessageID(ExtensionMessageRegistry::UT_METADATA, 127); + + CPPUNIT_ASSERT_EQUAL((uint8_t)127, reg.getExtensionMessageID + (ExtensionMessageRegistry::UT_METADATA)); + CPPUNIT_ASSERT_EQUAL((uint8_t)1, reg.getExtensionMessageID + (ExtensionMessageRegistry::UT_PEX)); + + reg.removeExtension(ExtensionMessageRegistry::UT_PEX); + + CPPUNIT_ASSERT_EQUAL((uint8_t)127, reg.getExtensionMessageID + (ExtensionMessageRegistry::UT_METADATA)); + CPPUNIT_ASSERT_EQUAL((uint8_t)0, reg.getExtensionMessageID + (ExtensionMessageRegistry::UT_PEX)); + CPPUNIT_ASSERT(!reg.getExtensionName(1)); +} + +} // namespace aria2 diff --git a/test/HandshakeExtensionMessageTest.cc b/test/HandshakeExtensionMessageTest.cc index 064a006e..eea01d21 100644 --- a/test/HandshakeExtensionMessageTest.cc +++ b/test/HandshakeExtensionMessageTest.cc @@ -60,12 +60,12 @@ void HandshakeExtensionMessageTest::testGetBencodedData() HandshakeExtensionMessage msg; msg.setClientVersion("aria2"); msg.setTCPPort(6889); - msg.setExtension("ut_pex", 1); - msg.setExtension("a2_dht", 2); + msg.setExtension(ExtensionMessageRegistry::UT_PEX, 1); + msg.setExtension(ExtensionMessageRegistry::UT_METADATA, 2); msg.setMetadataSize(1024); CPPUNIT_ASSERT_EQUAL (std::string("d" - "1:md6:a2_dhti2e6:ut_pexi1ee" + "1:md11:ut_metadatai2e6:ut_pexi1ee" "13:metadata_sizei1024e" "1:pi6889e" "1:v5:aria2" @@ -81,12 +81,12 @@ void HandshakeExtensionMessageTest::testToString() HandshakeExtensionMessage msg; msg.setClientVersion("aria2"); msg.setTCPPort(6889); - msg.setExtension("ut_pex", 1); - msg.setExtension("a2_dht", 2); + msg.setExtension(ExtensionMessageRegistry::UT_PEX, 1); + msg.setExtension(ExtensionMessageRegistry::UT_METADATA, 2); msg.setMetadataSize(1024); CPPUNIT_ASSERT_EQUAL (std::string("handshake client=aria2, tcpPort=6889, metadataSize=1024," - " a2_dht=2, ut_pex=1"), msg.toString()); + " ut_metadata=2, ut_pex=1"), msg.toString()); } void HandshakeExtensionMessageTest::testDoReceivedAction() @@ -106,9 +106,8 @@ void HandshakeExtensionMessageTest::testDoReceivedAction() HandshakeExtensionMessage msg; msg.setClientVersion("aria2"); msg.setTCPPort(6889); - msg.setExtension("ut_pex", 1); - msg.setExtension("a2_dht", 2); - msg.setExtension("ut_metadata", 3); + msg.setExtension(ExtensionMessageRegistry::UT_PEX, 1); + msg.setExtension(ExtensionMessageRegistry::UT_METADATA, 3); msg.setMetadataSize(1024); msg.setPeer(peer); msg.setDownloadContext(dctx); @@ -116,8 +115,12 @@ void HandshakeExtensionMessageTest::testDoReceivedAction() msg.doReceivedAction(); CPPUNIT_ASSERT_EQUAL((uint16_t)6889, peer->getPort()); - CPPUNIT_ASSERT_EQUAL((uint8_t)1, peer->getExtensionMessageID("ut_pex")); - CPPUNIT_ASSERT_EQUAL((uint8_t)2, peer->getExtensionMessageID("a2_dht")); + CPPUNIT_ASSERT_EQUAL((uint8_t)1, + peer->getExtensionMessageID + (ExtensionMessageRegistry::UT_PEX)); + CPPUNIT_ASSERT_EQUAL((uint8_t)3, + peer->getExtensionMessageID + (ExtensionMessageRegistry::UT_METADATA)); CPPUNIT_ASSERT(peer->isSeeder()); CPPUNIT_ASSERT_EQUAL((size_t)1024, attrs->metadataSize); CPPUNIT_ASSERT_EQUAL((int64_t)1024, dctx->getTotalLength()); @@ -134,13 +137,15 @@ void HandshakeExtensionMessageTest::testDoReceivedAction() void HandshakeExtensionMessageTest::testCreate() { std::string in = - "0d1:pi6881e1:v5:aria21:md6:ut_pexi1ee13:metadata_sizei1024ee"; + "0d1:pi6881e1:v5:aria21:md5:a2dhti2e6:ut_pexi1ee13:metadata_sizei1024ee"; SharedHandle m = HandshakeExtensionMessage::create(reinterpret_cast(in.c_str()), in.size()); CPPUNIT_ASSERT_EQUAL(std::string("aria2"), m->getClientVersion()); CPPUNIT_ASSERT_EQUAL((uint16_t)6881, m->getTCPPort()); - CPPUNIT_ASSERT_EQUAL((uint8_t)1, m->getExtensionMessageID("ut_pex")); + CPPUNIT_ASSERT_EQUAL((uint8_t)1, + m->getExtensionMessageID + (ExtensionMessageRegistry::UT_PEX)); CPPUNIT_ASSERT_EQUAL((size_t)1024, m->getMetadataSize()); try { // bad payload format @@ -182,7 +187,9 @@ void HandshakeExtensionMessageTest::testCreate_stringnum() // port number in string is not allowed CPPUNIT_ASSERT_EQUAL((uint16_t)0, m->getTCPPort()); // extension ID in string is not allowed - CPPUNIT_ASSERT_EQUAL((uint8_t)0, m->getExtensionMessageID("ut_pex")); + CPPUNIT_ASSERT_EQUAL((uint8_t)0, + m->getExtensionMessageID + (ExtensionMessageRegistry::UT_PEX)); } } // namespace aria2 diff --git a/test/Makefile.am b/test/Makefile.am index 1fb52ae4..ed55ccba 100644 --- a/test/Makefile.am +++ b/test/Makefile.am @@ -208,7 +208,8 @@ aria2c_SOURCES += BtAllowedFastMessageTest.cc\ LpdMessageReceiverTest.cc\ Bencode2Test.cc\ PeerConnectionTest.cc\ - ValueBaseBencodeParserTest.cc + ValueBaseBencodeParserTest.cc\ + ExtensionMessageRegistryTest.cc endif # ENABLE_BITTORRENT if ENABLE_METALINK diff --git a/test/PeerSessionResourceTest.cc b/test/PeerSessionResourceTest.cc index 8a6d350b..40930871 100644 --- a/test/PeerSessionResourceTest.cc +++ b/test/PeerSessionResourceTest.cc @@ -137,12 +137,17 @@ void PeerSessionResourceTest::testGetExtensionMessageID() { PeerSessionResource res(1024, 1024*1024); - res.addExtension("a2", 9); - CPPUNIT_ASSERT_EQUAL((uint8_t)9, res.getExtensionMessageID("a2")); - CPPUNIT_ASSERT_EQUAL((uint8_t)0, res.getExtensionMessageID("non")); + res.addExtension(ExtensionMessageRegistry::UT_PEX, 9); + CPPUNIT_ASSERT_EQUAL((uint8_t)9, + res.getExtensionMessageID + (ExtensionMessageRegistry::UT_PEX)); + CPPUNIT_ASSERT_EQUAL((uint8_t)0, + res.getExtensionMessageID + (ExtensionMessageRegistry::UT_METADATA)); - CPPUNIT_ASSERT_EQUAL(std::string("a2"), res.getExtensionName(9)); - CPPUNIT_ASSERT_EQUAL(std::string(""), res.getExtensionName(10)); + CPPUNIT_ASSERT_EQUAL(std::string("ut_pex"), + std::string(res.getExtensionName(9))); + CPPUNIT_ASSERT(!res.getExtensionName(10)); } void PeerSessionResourceTest::testFastExtensionEnabled() diff --git a/test/UTMetadataRequestExtensionMessageTest.cc b/test/UTMetadataRequestExtensionMessageTest.cc index 34a07f8e..27dd391d 100644 --- a/test/UTMetadataRequestExtensionMessageTest.cc +++ b/test/UTMetadataRequestExtensionMessageTest.cc @@ -16,6 +16,7 @@ #include "PieceStorage.h" #include "extension_message_test_helper.h" #include "DlAbortEx.h" +#include "ExtensionMessageRegistry.h" namespace aria2 { @@ -44,7 +45,7 @@ public: dctx_->setAttribute(CTX_ATTR_BT, attrs); peer_.reset(new Peer("host", 6880)); peer_->allocateSessionResource(0, 0); - peer_->setExtension("ut_metadata", 1); + peer_->setExtension(ExtensionMessageRegistry::UT_METADATA, 1); } template