Support UDP tracker

It shares UDP listening port with IPv4 DHT. At the moment, in order to
enable UDP tracker support, enable IPv4 DHT.
This commit is contained in:
Tatsuhiro Tsujikawa 2013-02-25 00:56:49 +09:00
parent b782a56b1c
commit d68741697a
29 changed files with 2271 additions and 102 deletions

View file

@ -40,9 +40,12 @@
#include <string>
#include "a2time.h"
#include "SharedHandle.h"
namespace aria2 {
class UDPTrackerRequest;
class BtAnnounce {
public:
virtual ~BtAnnounce() {}
@ -65,6 +68,10 @@ public:
*/
virtual std::string getAnnounceUrl() = 0;
virtual SharedHandle<UDPTrackerRequest>
createUDPTrackerRequest(const std::string& remoteAddr, uint16_t remotePort,
uint16_t localPort) = 0;
/**
* Tells that the announce process has just started.
*/
@ -96,6 +103,9 @@ public:
virtual void processAnnounceResponse(const unsigned char* trackerResponse,
size_t trackerResponseLength) = 0;
virtual void processUDPTrackerResponse
(const SharedHandle<UDPTrackerRequest>& req) = 0;
/**
* Returns true if no more announce is needed.
*/

View file

@ -42,12 +42,14 @@
#include "BtProgressInfoFile.h"
#include "bittorrent_helper.h"
#include "LpdMessageReceiver.h"
#include "UDPTrackerClient.h"
#include "NullHandle.h"
namespace aria2 {
BtRegistry::BtRegistry()
: tcpPort_(0)
: tcpPort_(0),
udpPort_(0)
{}
BtRegistry::~BtRegistry() {}
@ -107,6 +109,12 @@ void BtRegistry::setLpdMessageReceiver
lpdMessageReceiver_ = receiver;
}
void BtRegistry::setUDPTrackerClient
(const SharedHandle<UDPTrackerClient>& tracker)
{
udpTrackerClient_ = tracker;
}
BtObject::BtObject
(const SharedHandle<DownloadContext>& downloadContext,
const SharedHandle<PieceStorage>& pieceStorage,

View file

@ -51,6 +51,7 @@ class BtRuntime;
class BtProgressInfoFile;
class DownloadContext;
class LpdMessageReceiver;
class UDPTrackerClient;
struct BtObject {
SharedHandle<DownloadContext> downloadContext;
@ -80,7 +81,10 @@ class BtRegistry {
private:
std::map<a2_gid_t, SharedHandle<BtObject> > pool_;
uint16_t tcpPort_;
// This is IPv4 port for DHT and UDP tracker. No IPv6 udpPort atm.
uint16_t udpPort_;
SharedHandle<LpdMessageReceiver> lpdMessageReceiver_;
SharedHandle<UDPTrackerClient> udpTrackerClient_;
public:
BtRegistry();
~BtRegistry();
@ -118,11 +122,26 @@ public:
return tcpPort_;
}
void setUdpPort(uint16_t port)
{
udpPort_ = port;
}
uint16_t getUdpPort() const
{
return udpPort_;
}
void setLpdMessageReceiver(const SharedHandle<LpdMessageReceiver>& receiver);
const SharedHandle<LpdMessageReceiver>& getLpdMessageReceiver() const
{
return lpdMessageReceiver_;
}
void setUDPTrackerClient(const SharedHandle<UDPTrackerClient>& tracker);
const SharedHandle<UDPTrackerClient>& getUDPTrackerClient() const
{
return udpTrackerClient_;
}
};
} // namespace aria2

View file

@ -67,6 +67,7 @@
#include "DHTMessageReceiver.h"
#include "DHTMessageFactory.h"
#include "DHTMessageCallback.h"
#include "UDPTrackerClient.h"
#include "BtProgressInfoFile.h"
#include "BtAnnounce.h"
#include "BtRuntime.h"

View file

@ -46,10 +46,16 @@
#include "LogFactory.h"
#include "DHTMessageCallback.h"
#include "DHTNode.h"
#include "DHTConnection.h"
#include "UDPTrackerClient.h"
#include "UDPTrackerRequest.h"
#include "fmt.h"
#include "wallclock.h"
namespace aria2 {
// TODO This name of this command is misleading, because now it also
// handles UDP trackers as well as DHT.
DHTInteractionCommand::DHTInteractionCommand(cuid_t cuid, DownloadEngine* e)
: Command(cuid),
e_(e)
@ -77,23 +83,60 @@ void DHTInteractionCommand::disableReadCheckSocket(const SharedHandle<SocketCore
bool DHTInteractionCommand::execute()
{
if(e_->getRequestGroupMan()->downloadFinished() || e_->isHaltRequested()) {
// We need to keep this command alive while TrackerWatcherCommand
// needs this.
if(e_->getRequestGroupMan()->downloadFinished() ||
(e_->isHaltRequested() && udpTrackerClient_->getNumWatchers() == 0)) {
return true;
} else if(e_->isForceHaltRequested()) {
udpTrackerClient_->failAll();
return true;
}
taskQueue_->executeTask();
std::string remoteAddr;
uint16_t remotePort;
unsigned char data[64*1024];
try {
while(1) {
SharedHandle<DHTMessage> m = receiver_->receiveMessage();
if(!m) {
ssize_t length = connection_->receiveMessage(data, sizeof(data),
remoteAddr, remotePort);
if(length <= 0) {
break;
}
if(data[0] == 'd') {
// udp tracker response does not start with 'd', so assume
// this message belongs to DHT. nothrow.
receiver_->receiveMessage(remoteAddr, remotePort, data, length);
} else {
// this may be udp tracker response. nothrow.
udpTrackerClient_->receiveReply(data, length, remoteAddr, remotePort,
global::wallclock());
}
}
} catch(RecoverableException& e) {
A2_LOG_INFO_EX("Exception thrown while receiving UDP message.", e);
}
receiver_->handleTimeout();
try {
udpTrackerClient_->handleTimeout(global::wallclock());
dispatcher_->sendMessages();
while(!udpTrackerClient_->getPendingRequests().empty()) {
// no throw
ssize_t length = udpTrackerClient_->createRequest(data, sizeof(data),
remoteAddr, remotePort,
global::wallclock());
if(length == -1) {
break;
}
try {
// throw
connection_->sendMessage(data, length, remoteAddr, remotePort);
udpTrackerClient_->requestSent(global::wallclock());
} catch(RecoverableException& e) {
A2_LOG_ERROR_EX(EX_EXCEPTION_CAUGHT, e);
A2_LOG_INFO_EX("Exception thrown while sending UDP tracker request.", e);
udpTrackerClient_->requestFail(UDPT_ERR_NETWORK);
}
}
e_->addCommand(this);
return false;
@ -114,4 +157,16 @@ void DHTInteractionCommand::setTaskQueue(const SharedHandle<DHTTaskQueue>& taskQ
taskQueue_ = taskQueue;
}
void DHTInteractionCommand::setConnection
(const SharedHandle<DHTConnection>& connection)
{
connection_ = connection;
}
void DHTInteractionCommand::setUDPTrackerClient
(const SharedHandle<UDPTrackerClient>& udpTrackerClient)
{
udpTrackerClient_ = udpTrackerClient;
}
} // namespace aria2

View file

@ -45,6 +45,8 @@ class DHTMessageReceiver;
class DHTTaskQueue;
class DownloadEngine;
class SocketCore;
class DHTConnection;
class UDPTrackerClient;
class DHTInteractionCommand:public Command {
private:
@ -53,6 +55,8 @@ private:
SharedHandle<DHTMessageReceiver> receiver_;
SharedHandle<DHTTaskQueue> taskQueue_;
SharedHandle<SocketCore> readCheckSocket_;
SharedHandle<DHTConnection> connection_;
SharedHandle<UDPTrackerClient> udpTrackerClient_;
public:
DHTInteractionCommand(cuid_t cuid, DownloadEngine* e);
@ -69,6 +73,11 @@ public:
void setMessageReceiver(const SharedHandle<DHTMessageReceiver>& receiver);
void setTaskQueue(const SharedHandle<DHTTaskQueue>& taskQueue);
void setConnection(const SharedHandle<DHTConnection>& connection);
void setUDPTrackerClient
(const SharedHandle<UDPTrackerClient>& udpTrackerClient);
};
} // namespace aria2

View file

@ -63,18 +63,11 @@ DHTMessageReceiver::DHTMessageReceiver
DHTMessageReceiver::~DHTMessageReceiver() {}
SharedHandle<DHTMessage> DHTMessageReceiver::receiveMessage()
SharedHandle<DHTMessage> DHTMessageReceiver::receiveMessage
(const std::string& remoteAddr, uint16_t remotePort, unsigned char *data,
size_t length)
{
std::string remoteAddr;
uint16_t remotePort;
unsigned char data[64*1024];
try {
ssize_t length = connection_->receiveMessage(data, sizeof(data),
remoteAddr,
remotePort);
if(length <= 0) {
return SharedHandle<DHTMessage>();
}
bool isReply = false;
SharedHandle<ValueBase> decoded = bencode2::decode(data, length);
const Dict* dict = downcast<Dict>(decoded);
@ -87,13 +80,13 @@ SharedHandle<DHTMessage> DHTMessageReceiver::receiveMessage()
} else {
A2_LOG_INFO(fmt("Malformed DHT message. Missing 'y' key. From:%s:%u",
remoteAddr.c_str(), remotePort));
return handleUnknownMessage(data, sizeof(data), remoteAddr, remotePort);
return handleUnknownMessage(data, length, remoteAddr, remotePort);
}
} else {
A2_LOG_INFO(fmt("Malformed DHT message. This is not a bencoded directory."
" From:%s:%u",
remoteAddr.c_str(), remotePort));
return handleUnknownMessage(data, sizeof(data), remoteAddr, remotePort);
return handleUnknownMessage(data, length, remoteAddr, remotePort);
}
if(isReply) {
std::pair<SharedHandle<DHTResponseMessage>,
@ -101,7 +94,7 @@ SharedHandle<DHTMessage> DHTMessageReceiver::receiveMessage()
tracker_->messageArrived(dict, remoteAddr, remotePort);
if(!p.first) {
// timeout or malicious? message
return handleUnknownMessage(data, sizeof(data), remoteAddr, remotePort);
return handleUnknownMessage(data, length, remoteAddr, remotePort);
}
onMessageReceived(p.first);
if(p.second) {
@ -114,14 +107,14 @@ SharedHandle<DHTMessage> DHTMessageReceiver::receiveMessage()
if(*message->getLocalNode() == *message->getRemoteNode()) {
// drop message from localnode
A2_LOG_INFO("Received DHT message from localnode.");
return handleUnknownMessage(data, sizeof(data), remoteAddr, remotePort);
return handleUnknownMessage(data, length, remoteAddr, remotePort);
}
onMessageReceived(message);
return message;
}
} catch(RecoverableException& e) {
A2_LOG_INFO_EX("Exception thrown while receiving DHT message.", e);
return handleUnknownMessage(data, sizeof(data), remoteAddr, remotePort);
return handleUnknownMessage(data, length, remoteAddr, remotePort);
}
}

View file

@ -69,7 +69,9 @@ public:
~DHTMessageReceiver();
SharedHandle<DHTMessage> receiveMessage();
SharedHandle<DHTMessage> receiveMessage
(const std::string& remoteAddr, uint16_t remotePort, unsigned char *data,
size_t length);
void handleTimeout();

View file

@ -61,6 +61,8 @@
#include "DHTRegistry.h"
#include "DHTBucketRefreshTask.h"
#include "DHTMessageCallback.h"
#include "UDPTrackerClient.h"
#include "BtRegistry.h"
#include "prefs.h"
#include "Option.h"
#include "SocketCore.h"
@ -176,6 +178,8 @@ void DHTSetup::setup
factory->setTokenTracker(tokenTracker.get());
factory->setLocalNode(localNode);
// For now, UDPTrackerClient was enabled along with DHT
SharedHandle<UDPTrackerClient> udpTrackerClient(new UDPTrackerClient());
// assign them into DHTRegistry
if(family == AF_INET) {
DHTRegistry::getMutableData().localNode = localNode;
@ -187,6 +191,8 @@ void DHTSetup::setup
DHTRegistry::getMutableData().messageDispatcher = dispatcher;
DHTRegistry::getMutableData().messageReceiver = receiver;
DHTRegistry::getMutableData().messageFactory = factory;
e->getBtRegistry()->setUDPTrackerClient(udpTrackerClient);
e->getBtRegistry()->setUdpPort(localNode->getPort());
} else {
DHTRegistry::getMutableData6().localNode = localNode;
DHTRegistry::getMutableData6().routingTable = routingTable;
@ -244,6 +250,8 @@ void DHTSetup::setup
command->setMessageReceiver(receiver);
command->setTaskQueue(taskQueue);
command->setReadCheckSocket(connection->getSocket());
command->setConnection(connection);
command->setUDPTrackerClient(udpTrackerClient);
tempCommands->push_back(command);
}
{
@ -282,12 +290,15 @@ void DHTSetup::setup
}
commands.insert(commands.end(), tempCommands->begin(), tempCommands->end());
tempCommands->clear();
} catch(RecoverableException& e) {
} catch(RecoverableException& ex) {
A2_LOG_ERROR_EX(fmt("Exception caught while initializing DHT functionality."
" DHT is disabled."),
e);
ex);
if(family == AF_INET) {
DHTRegistry::clearData();
e->getBtRegistry()->setUDPTrackerClient
(SharedHandle<UDPTrackerClient>());
e->getBtRegistry()->setUdpPort(0);
} else {
DHTRegistry::clearData6();
}

View file

@ -52,6 +52,8 @@
#include "bittorrent_helper.h"
#include "wallclock.h"
#include "uri.h"
#include "UDPTrackerRequest.h"
#include "SocketCore.h"
namespace aria2 {
@ -115,7 +117,7 @@ bool uriHasQuery(const std::string& uri)
}
} // namespace
std::string DefaultBtAnnounce::getAnnounceUrl() {
bool DefaultBtAnnounce::adjustAnnounceList() {
if(isStoppedAnnounceReady()) {
if(!announceList_.currentTierAcceptsStoppedEvent()) {
announceList_.moveToStoppedAllowedTier();
@ -135,6 +137,13 @@ std::string DefaultBtAnnounce::getAnnounceUrl() {
announceList_.setEvent(AnnounceTier::STARTED_AFTER_COMPLETION);
}
} else {
return false;
}
return true;
}
std::string DefaultBtAnnounce::getAnnounceUrl() {
if(!adjustAnnounceList()) {
return A2STR::NIL;
}
int numWant = 50;
@ -193,6 +202,60 @@ std::string DefaultBtAnnounce::getAnnounceUrl() {
return uri;
}
SharedHandle<UDPTrackerRequest> DefaultBtAnnounce::createUDPTrackerRequest
(const std::string& remoteAddr, uint16_t remotePort, uint16_t localPort)
{
if(!adjustAnnounceList()) {
return SharedHandle<UDPTrackerRequest>();
}
NetStat& stat = downloadContext_->getNetStat();
int64_t left =
pieceStorage_->getTotalLength()-pieceStorage_->getCompletedLength();
SharedHandle<UDPTrackerRequest> req(new UDPTrackerRequest());
req->remoteAddr = remoteAddr;
req->remotePort = remotePort;
req->action = UDPT_ACT_ANNOUNCE;
req->infohash = bittorrent::getTorrentAttrs(downloadContext_)->infoHash;
const unsigned char* peerId = bittorrent::getStaticPeerId();
req->peerId.assign(peerId, peerId + PEER_ID_LENGTH);
req->downloaded = stat.getSessionDownloadLength();
req->left = left;
req->uploaded = stat.getSessionUploadLength();
switch(announceList_.getEvent()) {
case AnnounceTier::STARTED:
case AnnounceTier::STARTED_AFTER_COMPLETION:
req->event = UDPT_EVT_STARTED;
break;
case AnnounceTier::STOPPED:
req->event = UDPT_EVT_STOPPED;
break;
case AnnounceTier::COMPLETED:
req->event = UDPT_EVT_COMPLETED;
break;
default:
req->event = 0;
}
if(!option_->blank(PREF_BT_EXTERNAL_IP)) {
unsigned char dest[16];
if(net::getBinAddr(dest, option_->get(PREF_BT_EXTERNAL_IP)) == 4) {
memcpy(&req->ip, dest, 4);
} else {
req->ip = 0;
}
} else {
req->ip = 0;
}
req->key = randomizer_->getRandomNumber(INT32_MAX);
int numWant = 50;
if(!btRuntime_->lessThanMinPeers() || btRuntime_->isHalt()) {
numWant = 0;
}
req->numWant = numWant;
req->port = localPort;
req->extensions = 0;
return req;
}
void DefaultBtAnnounce::announceStart() {
++trackers_;
}
@ -287,6 +350,30 @@ DefaultBtAnnounce::processAnnounceResponse(const unsigned char* trackerResponse,
}
}
void DefaultBtAnnounce::processUDPTrackerResponse
(const SharedHandle<UDPTrackerRequest>& req)
{
const SharedHandle<UDPTrackerReply>& reply = req->reply;
A2_LOG_DEBUG("Now processing UDP tracker response.");
if(reply->interval > 0) {
minInterval_ = reply->interval;
A2_LOG_DEBUG(fmt("Min interval:%ld", static_cast<long int>(minInterval_)));
interval_ = minInterval_;
}
complete_ = reply->seeders;
A2_LOG_DEBUG(fmt("Complete:%d", reply->seeders));
incomplete_ = reply->leechers;
A2_LOG_DEBUG(fmt("Incomplete:%d", reply->leechers));
if(!btRuntime_->isHalt() && btRuntime_->lessThanMinPeers()) {
for(std::vector<std::pair<std::string, uint16_t> >::iterator i =
reply->peers.begin(), eoi = reply->peers.end(); i != eoi;
++i) {
peerStorage_->addPeer(SharedHandle<Peer>(new Peer((*i).first,
(*i).second)));
}
}
}
bool DefaultBtAnnounce::noMoreAnnounce() {
return (trackers_ == 0 &&
btRuntime_->isHalt() &&

View file

@ -66,6 +66,8 @@ private:
SharedHandle<PieceStorage> pieceStorage_;
SharedHandle<PeerStorage> peerStorage_;
uint16_t tcpPort_;
bool adjustAnnounceList();
public:
DefaultBtAnnounce(const SharedHandle<DownloadContext>& downloadContext,
const Option* option);
@ -103,6 +105,10 @@ public:
virtual std::string getAnnounceUrl();
virtual SharedHandle<UDPTrackerRequest>
createUDPTrackerRequest(const std::string& remoteAddr, uint16_t remotePort,
uint16_t localPort);
virtual void announceStart();
virtual void announceSuccess();
@ -116,6 +122,9 @@ public:
virtual void processAnnounceResponse(const unsigned char* trackerResponse,
size_t trackerResponseLength);
virtual void processUDPTrackerResponse
(const SharedHandle<UDPTrackerRequest>& req);
virtual bool noMoreAnnounce();
virtual void shuffleAnnounce();

View file

@ -85,7 +85,7 @@ volatile sig_atomic_t globalHaltRequested = 0;
DownloadEngine::DownloadEngine(const SharedHandle<EventPoll>& eventPoll)
: eventPoll_(eventPoll),
haltRequested_(false),
haltRequested_(0),
noWait_(false),
refreshInterval_(DEFAULT_REFRESH_INTERVAL),
cookieStorage_(new CookieStorage()),
@ -239,13 +239,13 @@ void DownloadEngine::afterEachIteration()
void DownloadEngine::requestHalt()
{
haltRequested_ = true;
haltRequested_ = std::max(haltRequested_, 1);
requestGroupMan_->halt();
}
void DownloadEngine::requestForceHalt()
{
haltRequested_ = true;
haltRequested_ = std::max(haltRequested_, 2);
requestGroupMan_->forceHalt();
}

View file

@ -79,7 +79,7 @@ private:
SharedHandle<StatCalc> statCalc_;
bool haltRequested_;
int haltRequested_;
class SocketPoolEntry {
private:
@ -230,6 +230,11 @@ public:
return haltRequested_;
}
bool isForceHaltRequested() const
{
return haltRequested_ >= 2;
}
void requestHalt();
void requestForceHalt();

View file

@ -522,7 +522,10 @@ SRCS += PeerAbstractCommand.cc PeerAbstractCommand.h\
ValueBaseBencodeParser.h\
BencodeDiskWriter.h\
BencodeDiskWriterFactory.h\
MemoryBencodePreDownloadHandler.h
MemoryBencodePreDownloadHandler.h\
UDPTrackerClient.cc UDPTrackerClient.h\
UDPTrackerRequest.cc UDPTrackerRequest.h\
NameResolveCommand.cc NameResolveCommand.h
endif # ENABLE_BITTORRENT
if ENABLE_METALINK

193
src/NameResolveCommand.cc Normal file
View file

@ -0,0 +1,193 @@
/* <!-- copyright */
/*
* aria2 - The high speed download utility
*
* Copyright (C) 2013 Tatsuhiro Tsujikawa
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*
* In addition, as a special exception, the copyright holders give
* permission to link the code of portions of this program with the
* OpenSSL library under certain conditions as described in each
* individual source file, and distribute linked combinations
* including the two.
* You must obey the GNU General Public License in all respects
* for all of the code used other than OpenSSL. If you modify
* file(s) with this exception, you may extend this exception to your
* version of the file(s), but you are not obligated to do so. If you
* do not wish to do so, delete this exception statement from your
* version. If you delete this exception statement from all source
* files in the program, then also delete it here.
*/
/* copyright --> */
#include "NameResolveCommand.h"
#include "DownloadEngine.h"
#include "NameResolver.h"
#include "prefs.h"
#include "message.h"
#include "util.h"
#include "Option.h"
#include "RequestGroupMan.h"
#include "Logger.h"
#include "LogFactory.h"
#include "fmt.h"
#include "UDPTrackerRequest.h"
#include "UDPTrackerClient.h"
#include "BtRegistry.h"
#ifdef ENABLE_ASYNC_DNS
#include "AsyncNameResolver.h"
#endif // ENABLE_ASYNC_DNS
namespace aria2 {
NameResolveCommand::NameResolveCommand
(cuid_t cuid, DownloadEngine* e,
const SharedHandle<UDPTrackerRequest>& req)
: Command(cuid),
e_(e),
req_(req)
{
setStatus(Command::STATUS_ONESHOT_REALTIME);
}
NameResolveCommand::~NameResolveCommand()
{
#ifdef ENABLE_ASYNC_DNS
disableNameResolverCheck(resolver_);
#endif // ENABLE_ASYNC_DNS
}
bool NameResolveCommand::execute()
{
// This is UDP tracker specific, but we need to keep this command
// alive until force shutdown is
// commencing. RequestGroupMan::downloadFinished() is useless here
// at the moment.
if(e_->isForceHaltRequested()) {
onShutdown();
return true;
}
#ifdef ENABLE_ASYNC_DNS
if(!resolver_) {
int family = AF_INET;
resolver_.reset(new AsyncNameResolver(family
#ifdef HAVE_ARES_ADDR_NODE
, e_->getAsyncDNSServers()
#endif // HAVE_ARES_ADDR_NODE
));
}
#endif // ENABLE_ASYNC_DNS
std::string hostname = req_->remoteAddr;
std::vector<std::string> res;
if(util::isNumericHost(hostname)) {
res.push_back(hostname);
} else {
#ifdef ENABLE_ASYNC_DNS
if(e_->getOption()->getAsBool(PREF_ASYNC_DNS)) {
try {
if(resolveHostname(hostname, resolver_)) {
res = resolver_->getResolvedAddresses();
} else {
e_->addCommand(this);
return false;
}
} catch(RecoverableException& e) {
A2_LOG_ERROR_EX(EX_EXCEPTION_CAUGHT, e);
}
} else
#endif // ENABLE_ASYNC_DNS
{
NameResolver resolver;
resolver.setSocktype(SOCK_DGRAM);
try {
resolver.resolve(res, hostname);
} catch(RecoverableException& e) {
A2_LOG_ERROR_EX(EX_EXCEPTION_CAUGHT, e);
}
}
}
if(res.empty()) {
onFailure();
} else {
onSuccess(res, e_);
}
return true;
}
void NameResolveCommand::onShutdown()
{
req_->state = UDPT_STA_COMPLETE;
req_->error = UDPT_ERR_SHUTDOWN;
}
void NameResolveCommand::onFailure()
{
req_->state = UDPT_STA_COMPLETE;
req_->error = UDPT_ERR_NETWORK;
}
void NameResolveCommand::onSuccess
(const std::vector<std::string>& addrs, DownloadEngine* e)
{
req_->remoteAddr = addrs[0];
e->getBtRegistry()->getUDPTrackerClient()->addRequest(req_);
}
#ifdef ENABLE_ASYNC_DNS
bool NameResolveCommand::resolveHostname
(const std::string& hostname,
const SharedHandle<AsyncNameResolver>& resolver)
{
switch(resolver->getStatus()) {
case AsyncNameResolver::STATUS_READY:
A2_LOG_INFO(fmt(MSG_RESOLVING_HOSTNAME,
getCuid(),
hostname.c_str()));
resolver->resolve(hostname);
setNameResolverCheck(resolver);
return false;
case AsyncNameResolver::STATUS_SUCCESS:
A2_LOG_INFO(fmt(MSG_NAME_RESOLUTION_COMPLETE,
getCuid(),
resolver->getHostname().c_str(),
resolver->getResolvedAddresses().front().c_str()));
return true;
break;
case AsyncNameResolver::STATUS_ERROR:
throw DL_ABORT_EX
(fmt(MSG_NAME_RESOLUTION_FAILED,
getCuid(),
hostname.c_str(),
resolver->getError().c_str()));
default:
return false;
}
}
void NameResolveCommand::setNameResolverCheck
(const SharedHandle<AsyncNameResolver>& resolver)
{
e_->addNameResolverCheck(resolver, this);
}
void NameResolveCommand::disableNameResolverCheck
(const SharedHandle<AsyncNameResolver>& resolver)
{
e_->deleteNameResolverCheck(resolver, this);
}
#endif // ENABLE_ASYNC_DNS
} // namespace aria2

87
src/NameResolveCommand.h Normal file
View file

@ -0,0 +1,87 @@
/* <!-- copyright */
/*
* aria2 - The high speed download utility
*
* Copyright (C) 2013 Tatsuhiro Tsujikawa
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*
* In addition, as a special exception, the copyright holders give
* permission to link the code of portions of this program with the
* OpenSSL library under certain conditions as described in each
* individual source file, and distribute linked combinations
* including the two.
* You must obey the GNU General Public License in all respects
* for all of the code used other than OpenSSL. If you modify
* file(s) with this exception, you may extend this exception to your
* version of the file(s), but you are not obligated to do so. If you
* do not wish to do so, delete this exception statement from your
* version. If you delete this exception statement from all source
* files in the program, then also delete it here.
*/
/* copyright --> */
#ifndef D_NAME_RESOLVE_COMMAND_H
#define D_NAME_RESOLVE_COMMAND_H
#include "Command.h"
#include <string>
#include <vector>
#include "SharedHandle.h"
// TODO Make this class generic.
namespace aria2 {
class DownloadEngine;
#ifdef ENABLE_ASYNC_DNS
class AsyncNameResolver;
#endif // ENABLE_ASYNC_DNS
class UDPTrackerRequest;
class NameResolveCommand:public Command {
private:
DownloadEngine* e_;
#ifdef ENABLE_ASYNC_DNS
SharedHandle<AsyncNameResolver> resolver_;
#endif // ENABLE_ASYNC_DNS
#ifdef ENABLE_ASYNC_DNS
bool resolveHostname(const std::string& hostname,
const SharedHandle<AsyncNameResolver>& resolver);
void setNameResolverCheck(const SharedHandle<AsyncNameResolver>& resolver);
void disableNameResolverCheck(const SharedHandle<AsyncNameResolver>& resolver);
#endif // ENABLE_ASYNC_DNS
SharedHandle<UDPTrackerRequest> req_;
void onShutdown();
void onFailure();
void onSuccess
(const std::vector<std::string>& addrs, DownloadEngine* e);
public:
NameResolveCommand(cuid_t cuid, DownloadEngine* e,
const SharedHandle<UDPTrackerRequest>& req);
virtual ~NameResolveCommand();
virtual bool execute();
};
} // namespace aria2
#endif // D_NAME_RESOVE_COMMAND_H

View file

@ -63,32 +63,156 @@
#include "a2functional.h"
#include "util.h"
#include "fmt.h"
#include "UDPTrackerRequest.h"
#include "UDPTrackerClient.h"
#include "BtRegistry.h"
#include "NameResolveCommand.h"
namespace aria2 {
HTTPAnnRequest::HTTPAnnRequest(const SharedHandle<RequestGroup>& rg)
: rg_(rg)
{}
HTTPAnnRequest::~HTTPAnnRequest()
{}
bool HTTPAnnRequest::stopped() const
{
return rg_->getNumCommand() == 0;
}
bool HTTPAnnRequest::success() const
{
return rg_->downloadFinished();
}
void HTTPAnnRequest::stop(DownloadEngine* e)
{
rg_->setForceHaltRequested(true);
}
bool HTTPAnnRequest::issue(DownloadEngine* e)
{
try {
std::vector<Command*>* commands = new std::vector<Command*>();
auto_delete_container<std::vector<Command*> > commandsDel(commands);
rg_->createInitialCommand(*commands, e);
e->addCommand(*commands);
e->setNoWait(true);
commands->clear();
A2_LOG_DEBUG("added tracker request command");
return true;
} catch(RecoverableException& ex) {
A2_LOG_ERROR_EX(EX_EXCEPTION_CAUGHT, ex);
return false;
}
}
bool HTTPAnnRequest::processResponse
(const SharedHandle<BtAnnounce>& btAnnounce)
{
try {
std::stringstream strm;
unsigned char data[2048];
rg_->getPieceStorage()->getDiskAdaptor()->openFile();
while(1) {
ssize_t dataLength = rg_->getPieceStorage()->
getDiskAdaptor()->readData(data, sizeof(data), strm.tellp());
if(dataLength == 0) {
break;
}
strm.write(reinterpret_cast<const char*>(data), dataLength);
}
std::string res = strm.str();
btAnnounce->processAnnounceResponse
(reinterpret_cast<const unsigned char*>(res.c_str()), res.size());
return true;
} catch(RecoverableException& e) {
A2_LOG_ERROR_EX(EX_EXCEPTION_CAUGHT, e);
return false;
}
}
UDPAnnRequest::UDPAnnRequest(const SharedHandle<UDPTrackerRequest>& req)
: req_(req)
{}
UDPAnnRequest::~UDPAnnRequest()
{}
bool UDPAnnRequest::stopped() const
{
return !req_ || req_->state == UDPT_STA_COMPLETE;
}
bool UDPAnnRequest::success() const
{
return req_ && req_->state == UDPT_STA_COMPLETE &&
req_->error == UDPT_ERR_SUCCESS;
}
void UDPAnnRequest::stop(DownloadEngine* e)
{
if(req_) {
req_.reset();
}
}
bool UDPAnnRequest::issue(DownloadEngine* e)
{
if(req_) {
NameResolveCommand* command = new NameResolveCommand
(e->newCUID(), e, req_);
e->addCommand(command);
e->setNoWait(true);
return true;
} else {
return false;
}
}
bool UDPAnnRequest::processResponse
(const SharedHandle<BtAnnounce>& btAnnounce)
{
if(req_) {
btAnnounce->processUDPTrackerResponse(req_);
return true;
} else {
return false;
}
}
TrackerWatcherCommand::TrackerWatcherCommand
(cuid_t cuid, RequestGroup* requestGroup, DownloadEngine* e)
: Command(cuid),
requestGroup_(requestGroup),
e_(e)
e_(e),
udpTrackerClient_(e_->getBtRegistry()->getUDPTrackerClient())
{
requestGroup_->increaseNumCommand();
if(udpTrackerClient_) {
udpTrackerClient_->increaseWatchers();
}
}
TrackerWatcherCommand::~TrackerWatcherCommand()
{
requestGroup_->decreaseNumCommand();
if(udpTrackerClient_) {
udpTrackerClient_->decreaseWatchers();
}
}
bool TrackerWatcherCommand::execute() {
if(requestGroup_->isForceHaltRequested()) {
if(!trackerRequestGroup_) {
if(!trackerRequest_) {
return true;
} else if(trackerRequestGroup_->getNumCommand() == 0 ||
trackerRequestGroup_->downloadFinished()) {
} else if(trackerRequest_->stopped() ||
trackerRequest_->success()) {
return true;
} else {
trackerRequestGroup_->setForceHaltRequested(true);
trackerRequest_->stop(e_);
e_->setRefreshInterval(0);
e_->addCommand(this);
return false;
@ -98,44 +222,32 @@ bool TrackerWatcherCommand::execute() {
A2_LOG_DEBUG("no more announce");
return true;
}
if(!trackerRequestGroup_) {
trackerRequestGroup_ = createAnnounce();
if(trackerRequestGroup_) {
try {
std::vector<Command*>* commands = new std::vector<Command*>();
auto_delete_container<std::vector<Command*> > commandsDel(commands);
trackerRequestGroup_->createInitialCommand(*commands, e_);
e_->addCommand(*commands);
commands->clear();
A2_LOG_DEBUG("added tracker request command");
} catch(RecoverableException& ex) {
A2_LOG_ERROR_EX(EX_EXCEPTION_CAUGHT, ex);
if(!trackerRequest_) {
trackerRequest_ = createAnnounce(e_);
if(trackerRequest_) {
trackerRequest_->issue(e_);
}
}
} else if(trackerRequestGroup_->getNumCommand() == 0) {
} else if(trackerRequest_->stopped()) {
// We really want to make sure that tracker request has finished
// by checking getNumCommand() == 0. Because we reset
// trackerRequestGroup_, if it is still used in other Command, we
// will get Segmentation fault.
if(trackerRequestGroup_->downloadFinished()) {
try {
std::string trackerResponse = getTrackerResponse(trackerRequestGroup_);
processTrackerResponse(trackerResponse);
if(trackerRequest_->success()) {
if(trackerRequest_->processResponse(btAnnounce_)) {
btAnnounce_->announceSuccess();
btAnnounce_->resetAnnounce();
} catch(RecoverableException& ex) {
A2_LOG_ERROR_EX(EX_EXCEPTION_CAUGHT, ex);
addConnection();
} else {
btAnnounce_->announceFailure();
if(btAnnounce_->isAllAnnounceFailed()) {
btAnnounce_->resetAnnounce();
}
}
trackerRequestGroup_.reset();
trackerRequest_.reset();
} else {
// handle errors here
btAnnounce_->announceFailure(); // inside it, trackers = 0.
trackerRequestGroup_.reset();
trackerRequest_.reset();
if(btAnnounce_->isAllAnnounceFailed()) {
btAnnounce_->resetAnnounce();
}
@ -145,30 +257,8 @@ bool TrackerWatcherCommand::execute() {
return false;
}
std::string TrackerWatcherCommand::getTrackerResponse
(const SharedHandle<RequestGroup>& requestGroup)
void TrackerWatcherCommand::addConnection()
{
std::stringstream strm;
unsigned char data[2048];
requestGroup->getPieceStorage()->getDiskAdaptor()->openFile();
while(1) {
ssize_t dataLength = requestGroup->getPieceStorage()->
getDiskAdaptor()->readData(data, sizeof(data), strm.tellp());
if(dataLength == 0) {
break;
}
strm.write(reinterpret_cast<const char*>(data), dataLength);
}
return strm.str();
}
// TODO we have to deal with the exception thrown By BtAnnounce
void TrackerWatcherCommand::processTrackerResponse
(const std::string& trackerResponse)
{
btAnnounce_->processAnnounceResponse
(reinterpret_cast<const unsigned char*>(trackerResponse.c_str()),
trackerResponse.size());
while(!btRuntime_->isHalt() && btRuntime_->lessThanMinPeers()) {
if(!peerStorage_->isPeerAvailable()) {
break;
@ -180,8 +270,8 @@ void TrackerWatcherCommand::processTrackerResponse
break;
}
PeerInitiateConnectionCommand* command;
command = new PeerInitiateConnectionCommand(ncuid, requestGroup_, peer, e_,
btRuntime_);
command = new PeerInitiateConnectionCommand(ncuid, requestGroup_, peer,
e_, btRuntime_);
command->setPeerStorage(peerStorage_);
command->setPieceStorage(pieceStorage_);
e_->addCommand(command);
@ -190,13 +280,48 @@ void TrackerWatcherCommand::processTrackerResponse
}
}
SharedHandle<RequestGroup> TrackerWatcherCommand::createAnnounce() {
SharedHandle<RequestGroup> rg;
if(btAnnounce_->isAnnounceReady()) {
rg = createRequestGroup(btAnnounce_->getAnnounceUrl());
btAnnounce_->announceStart(); // inside it, trackers++.
SharedHandle<AnnRequest>
TrackerWatcherCommand::createAnnounce(DownloadEngine* e)
{
SharedHandle<AnnRequest> treq;
while(!btAnnounce_->isAllAnnounceFailed() &&
btAnnounce_->isAnnounceReady()) {
std::string uri = btAnnounce_->getAnnounceUrl();
uri_split_result res;
memset(&res, 0, sizeof(res));
if(uri_split(&res, uri.c_str()) == 0) {
// Without UDP tracker support, send it to normal tracker flow
// and make it fail.
if(udpTrackerClient_ &&
uri::getFieldString(res, USR_SCHEME, uri.c_str()) == "udp") {
uint16_t localPort;
localPort = e->getBtRegistry()->getUdpPort();
treq = createUDPAnnRequest
(uri::getFieldString(res, USR_HOST, uri.c_str()), res.port,
localPort);
} else {
treq = createHTTPAnnRequest(btAnnounce_->getAnnounceUrl());
}
return rg;
btAnnounce_->announceStart(); // inside it, trackers++.
break;
} else {
btAnnounce_->announceFailure();
}
}
if(btAnnounce_->isAllAnnounceFailed()) {
btAnnounce_->resetAnnounce();
}
return treq;
}
SharedHandle<AnnRequest>
TrackerWatcherCommand::createUDPAnnRequest(const std::string& host,
uint16_t port,
uint16_t localPort)
{
SharedHandle<UDPTrackerRequest> req =
btAnnounce_->createUDPTrackerRequest(host, port, localPort);
return SharedHandle<AnnRequest>(new UDPAnnRequest(req));
}
namespace {
@ -219,8 +344,8 @@ bool backupTrackerIsAvailable
}
} // namespace
SharedHandle<RequestGroup>
TrackerWatcherCommand::createRequestGroup(const std::string& uri)
SharedHandle<AnnRequest>
TrackerWatcherCommand::createHTTPAnnRequest(const std::string& uri)
{
std::vector<std::string> uris;
uris.push_back(uri);
@ -261,7 +386,7 @@ TrackerWatcherCommand::createRequestGroup(const std::string& uri)
dctx->setAcceptMetalink(false);
A2_LOG_INFO(fmt("Creating tracker request group GID#%s",
GroupId::toHex(rg->getGID()).c_str()));
return rg;
return SharedHandle<AnnRequest>(new HTTPAnnRequest(rg));
}
void TrackerWatcherCommand::setBtRuntime

View file

@ -50,6 +50,50 @@ class PieceStorage;
class BtRuntime;
class BtAnnounce;
class Option;
class UDPTrackerRequest;
class UDPTrackerClient;
class AnnRequest {
public:
virtual ~AnnRequest() {}
// Returns true if tracker request is finished, regardless of the
// outcome.
virtual bool stopped() const = 0;
// Returns true if tracker request is successful.
virtual bool success() const = 0;
// Returns true if issuing request is successful.
virtual bool issue(DownloadEngine* e) = 0;
// Stop this request.
virtual void stop(DownloadEngine* e) = 0;
// Returns true if processing tracker response is successful.
virtual bool processResponse(const SharedHandle<BtAnnounce>& btAnnounce) = 0;
};
class HTTPAnnRequest:public AnnRequest {
public:
HTTPAnnRequest(const SharedHandle<RequestGroup>& rg);
virtual ~HTTPAnnRequest();
virtual bool stopped() const;
virtual bool success() const;
virtual bool issue(DownloadEngine* e);
virtual void stop(DownloadEngine* e);
virtual bool processResponse(const SharedHandle<BtAnnounce>& btAnnounce);
private:
SharedHandle<RequestGroup> rg_;
};
class UDPAnnRequest:public AnnRequest {
public:
UDPAnnRequest(const SharedHandle<UDPTrackerRequest>& req);
virtual ~UDPAnnRequest();
virtual bool stopped() const;
virtual bool success() const;
virtual bool issue(DownloadEngine* e);
virtual void stop(DownloadEngine* e);
virtual bool processResponse(const SharedHandle<BtAnnounce>& btAnnounce);
private:
SharedHandle<UDPTrackerRequest> req_;
};
class TrackerWatcherCommand : public Command
{
@ -58,6 +102,8 @@ private:
DownloadEngine* e_;
SharedHandle<UDPTrackerClient> udpTrackerClient_;
SharedHandle<PeerStorage> peerStorage_;
SharedHandle<PieceStorage> pieceStorage_;
@ -66,16 +112,20 @@ private:
SharedHandle<BtAnnounce> btAnnounce_;
SharedHandle<RequestGroup> trackerRequestGroup_;
SharedHandle<AnnRequest> trackerRequest_;
/**
* Returns a command for announce request. Returns 0 if no announce request
* is needed.
*/
SharedHandle<RequestGroup> createRequestGroup(const std::string& url);
SharedHandle<AnnRequest>
createHTTPAnnRequest(const std::string& uri);
std::string getTrackerResponse(const SharedHandle<RequestGroup>& requestGroup);
SharedHandle<AnnRequest>
createUDPAnnRequest(const std::string& host, uint16_t port,
uint16_t localPort);
void processTrackerResponse(const std::string& response);
void addConnection();
const SharedHandle<Option>& getOption() const;
public:
@ -85,7 +135,7 @@ public:
virtual ~TrackerWatcherCommand();
SharedHandle<RequestGroup> createAnnounce();
SharedHandle<AnnRequest> createAnnounce(DownloadEngine* e);
virtual bool execute();

616
src/UDPTrackerClient.cc Normal file
View file

@ -0,0 +1,616 @@
/* <!-- copyright */
/*
* aria2 - The high speed download utility
*
* Copyright (C) 2013 Tatsuhiro Tsujikawa
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*
* In addition, as a special exception, the copyright holders give
* permission to link the code of portions of this program with the
* OpenSSL library under certain conditions as described in each
* individual source file, and distribute linked combinations
* including the two.
* You must obey the GNU General Public License in all respects
* for all of the code used other than OpenSSL. If you modify
* file(s) with this exception, you may extend this exception to your
* version of the file(s), but you are not obligated to do so. If you
* do not wish to do so, delete this exception statement from your
* version. If you delete this exception statement from all source
* files in the program, then also delete it here.
*/
/* copyright --> */
#include "UDPTrackerClient.h"
#include "UDPTrackerRequest.h"
#include "bittorrent_helper.h"
#include "util.h"
#include "LogFactory.h"
#include "SimpleRandomizer.h"
#include "fmt.h"
namespace aria2 {
UDPTrackerClient::UDPTrackerClient()
{}
namespace {
template<typename InputIterator>
void failRequest(InputIterator first, InputIterator last, int error)
{
for(; first != last; ++first) {
(*first)->state = UDPT_STA_COMPLETE;
(*first)->error = error;
}
}
} // namespace
namespace {
int32_t generateTransactionId()
{
return SimpleRandomizer::getInstance()->getRandomNumber(INT32_MAX);
}
} // namespace
namespace {
void logInvalidLength(const std::string& remoteAddr, uint16_t remotePort,
int action, unsigned long expected, unsigned long actual)
{
A2_LOG_INFO(fmt("UDPT received %s reply from %s:%u invalid length "
"expected:%lu, actual:%lu",
getUDPTrackerActionStr(action),
remoteAddr.c_str(), remotePort, expected, actual));
}
} // namespace
namespace {
void logInvalidTransaction(const std::string& remoteAddr, uint16_t remotePort,
int action, int32_t transactionId)
{
A2_LOG_INFO(fmt("UDPT received %s reply from %s:%u invalid transaction_id=%d",
getUDPTrackerActionStr(action),
remoteAddr.c_str(), remotePort, transactionId));
}
} // namespace
namespace {
void logTooShortLength(const std::string& remoteAddr, uint16_t remotePort,
int action,
unsigned long minLength, unsigned long actual)
{
A2_LOG_INFO(fmt("UDPT received %s reply from %s:%u length too short "
"min:%lu, actual:%lu",
getUDPTrackerActionStr(action),
remoteAddr.c_str(), remotePort, minLength, actual));
}
} // namespace
UDPTrackerClient::~UDPTrackerClient()
{
// Make all contained requests fail
int error = UDPT_ERR_SHUTDOWN;
failRequest(inflightRequests_.begin(), inflightRequests_.end(), error);
failRequest(pendingRequests_.begin(), pendingRequests_.end(), error);
failRequest(connectRequests_.begin(), connectRequests_.end(), error);
}
namespace {
struct CollectAddrPortMatch {
bool operator()(const SharedHandle<UDPTrackerRequest>& req) const
{
if(req->remoteAddr == remoteAddr && req->remotePort == remotePort) {
dest.push_back(req);
return true;
} else {
return false;
}
}
std::vector<SharedHandle<UDPTrackerRequest> >& dest;
std::string remoteAddr;
uint16_t remotePort;
CollectAddrPortMatch(std::vector<SharedHandle<UDPTrackerRequest> >& dest,
const std::string& remoteAddr, uint16_t remotePort)
: dest(dest), remoteAddr(remoteAddr), remotePort(remotePort)
{}
};
} // namespace
int UDPTrackerClient::receiveReply
(const unsigned char* data, size_t length, const std::string& remoteAddr,
uint16_t remotePort, const Timer& now)
{
int32_t action = bittorrent::getIntParam(data, 0);
switch(action) {
case UDPT_ACT_CONNECT: {
if(length != 16) {
logInvalidLength(remoteAddr, remotePort, action, 16, length);
return -1;
}
int32_t transactionId = bittorrent::getIntParam(data, 4);
SharedHandle<UDPTrackerRequest> req =
findInflightRequest(remoteAddr, remotePort, transactionId, true);
if(!req) {
logInvalidTransaction(remoteAddr, remotePort, action, transactionId);
return -1;
}
req->state = UDPT_STA_COMPLETE;
int64_t connectionId = bittorrent::getLLIntParam(data, 8);
A2_LOG_INFO(fmt("UDPT received CONNECT reply from %s:%u transaction_id=%u,"
"connection_id=%"PRId64, remoteAddr.c_str(), remotePort,
transactionId, connectionId));
UDPTrackerConnection c(UDPT_CST_CONNECTED, connectionId, now);
connectionIdCache_[std::make_pair(remoteAddr, remotePort)] = c;
// Now we have connecion ID, push requests which are waiting for
// it.
std::vector<SharedHandle<UDPTrackerRequest> > reqs;
connectRequests_.erase(std::remove_if
(connectRequests_.begin(), connectRequests_.end(),
CollectAddrPortMatch(reqs, remoteAddr, remotePort)),
connectRequests_.end());
pendingRequests_.insert(pendingRequests_.begin(),
reqs.begin(), reqs.end());
break;
}
case UDPT_ACT_ANNOUNCE: {
if(length < 20) {
logTooShortLength(remoteAddr, remotePort, action, 20, length);
return - 1;
}
int32_t transactionId = bittorrent::getIntParam(data, 4);
SharedHandle<UDPTrackerRequest> req =
findInflightRequest(remoteAddr, remotePort, transactionId, true);
if(!req) {
logInvalidTransaction(remoteAddr, remotePort, action, transactionId);
return -1;
}
req->state = UDPT_STA_COMPLETE;
req->reply.reset(new UDPTrackerReply());
req->reply->action = action;
req->reply->transactionId = transactionId;
req->reply->interval = bittorrent::getIntParam(data, 8);
req->reply->leechers = bittorrent::getIntParam(data, 12);
req->reply->seeders = bittorrent::getIntParam(data, 16);
int numPeers = 0;
for(size_t i = 20; i < length; i += 6) {
std::pair<std::string, uint16_t> hostport =
bittorrent::unpackcompact(data+i, AF_INET);
if(!hostport.first.empty()) {
req->reply->peers.push_back(hostport);
++numPeers;
}
}
A2_LOG_INFO(fmt("UDPT received ANNOUNCE reply from %s:%u transaction_id=%u,"
"connection_id=%"PRId64", event=%s, infohash=%s, "
"interval=%d, leechers=%d, "
"seeders=%d, num_peers=%d", remoteAddr.c_str(), remotePort,
transactionId, req->connectionId,
getUDPTrackerEventStr(req->event),
util::toHex(req->infohash).c_str(),
req->reply->interval, req->reply->leechers,
req->reply->seeders, numPeers));
break;
}
case UDPT_ACT_ERROR: {
if(length < 8) {
logTooShortLength(remoteAddr, remotePort, action, 8, length);
return -1;
}
int32_t transactionId = bittorrent::getIntParam(data, 4);
SharedHandle<UDPTrackerRequest> req =
findInflightRequest(remoteAddr, remotePort, transactionId, true);
if(!req) {
logInvalidTransaction(remoteAddr, remotePort, action, transactionId);
return -1;
}
std::string errorString(data+8, data+length);
errorString = util::encodeNonUtf8(errorString);
req->state = UDPT_STA_COMPLETE;
req->error = UDPT_ERR_TRACKER;
A2_LOG_INFO(fmt("UDPT received ERROR reply from %s:%u transaction_id=%u,"
"connection_id=%"PRId64", action=%d, error_string=%s",
remoteAddr.c_str(),
remotePort, transactionId, req->connectionId, action,
errorString.c_str()));
if(req->action == UDPT_ACT_CONNECT) {
failConnect(req->remoteAddr, req->remotePort, UDPT_ERR_TRACKER);
}
break;
}
case UDPT_ACT_SCRAPE:
A2_LOG_INFO("unexpected scrape action reply");
return -1;
default:
A2_LOG_INFO("unknown action reply");
return -1;
}
return 0;
}
ssize_t UDPTrackerClient::createRequest
(unsigned char* data, size_t length, std::string& remoteAddr,
uint16_t& remotePort, const Timer& now)
{
if(pendingRequests_.empty()) {
return -1;
}
while(!pendingRequests_.empty()) {
const SharedHandle<UDPTrackerRequest>& req = pendingRequests_.front();
if(req->action == UDPT_ACT_CONNECT) {
ssize_t rv;
rv = createUDPTrackerConnect(data, length, remoteAddr, remotePort, req);
return rv;
}
UDPTrackerConnection* c = getConnectionId(req->remoteAddr,
req->remotePort,
now);
if(!c) {
SharedHandle<UDPTrackerRequest> creq(new UDPTrackerRequest());
creq->action = UDPT_ACT_CONNECT;
creq->remoteAddr = req->remoteAddr;
creq->remotePort = req->remotePort;
creq->transactionId = generateTransactionId();
pendingRequests_.push_front(creq);
ssize_t rv;
rv = createUDPTrackerConnect(data, length, remoteAddr, remotePort, creq);
return rv;
}
if(c->state == UDPT_CST_CONNECTING) {
connectRequests_.push_back(req);
pendingRequests_.pop_front();
continue;
}
req->connectionId = c->connectionId;
req->transactionId = generateTransactionId();
ssize_t rv;
rv = createUDPTrackerAnnounce(data, length, remoteAddr, remotePort, req);
return rv;
}
return -1;
}
void UDPTrackerClient::requestSent(const Timer& now)
{
if(pendingRequests_.empty()) {
A2_LOG_WARN("pendingRequests_ is empty");
return;
}
const SharedHandle<UDPTrackerRequest>& req = pendingRequests_.front();
switch(req->action) {
case UDPT_ACT_CONNECT:
A2_LOG_INFO(fmt("UDPT sent CONNECT to %s:%u transaction_id=%u",
req->remoteAddr.c_str(), req->remotePort,
req->transactionId));
break;
case UDPT_ACT_ANNOUNCE:
A2_LOG_INFO(fmt("UDPT sent ANNOUNCE to %s:%u transaction_id=%u, "
"connection_id=%"PRId64", event=%s, infohash=%s",
req->remoteAddr.c_str(), req->remotePort,
req->transactionId, req->connectionId,
getUDPTrackerEventStr(req->event),
util::toHex(req->infohash).c_str()));
break;
default:
// unreachable
assert(0);
}
req->dispatched = now;
switch(req->action) {
case UDPT_ACT_CONNECT: {
connectionIdCache_[std::make_pair(req->remoteAddr, req->remotePort)]
= UDPTrackerConnection();
break;
}
}
inflightRequests_.push_back(req);
pendingRequests_.pop_front();
}
void UDPTrackerClient::requestFail(int error)
{
if(pendingRequests_.empty()) {
A2_LOG_WARN("pendingRequests_ is empty");
return;
}
const SharedHandle<UDPTrackerRequest>& req = pendingRequests_.front();
switch(req->action) {
case UDPT_ACT_CONNECT:
A2_LOG_INFO(fmt("UDPT fail CONNECT to %s:%u transaction_id=%u",
req->remoteAddr.c_str(), req->remotePort,
req->transactionId));
failConnect(req->remoteAddr, req->remotePort, error);
break;
case UDPT_ACT_ANNOUNCE:
A2_LOG_INFO(fmt("UDPT fail ANNOUNCE to %s:%u transaction_id=%u, "
"connection_id=%"PRId64", event=%s, infohash=%s",
req->remoteAddr.c_str(), req->remotePort,
req->transactionId, req->connectionId,
getUDPTrackerEventStr(req->event),
util::toHex(req->infohash).c_str()));
break;
default:
// unreachable
assert(0);
}
req->state = UDPT_STA_COMPLETE;
req->error = error;
pendingRequests_.pop_front();
}
void UDPTrackerClient::addRequest(const SharedHandle<UDPTrackerRequest>& req)
{
req->state = UDPT_STA_PENDING;
req->error = UDPT_ERR_SUCCESS;
pendingRequests_.push_back(req);
}
namespace {
struct TimeoutCheck {
bool operator()(const SharedHandle<UDPTrackerRequest>& req) const
{
int t = req->dispatched.difference(now);
if(req->failCount == 0) {
if(t >= 15) {
switch(req->action) {
case UDPT_ACT_CONNECT:
A2_LOG_INFO(fmt("UDPT resend CONNECT to %s:%u transaction_id=%u",
req->remoteAddr.c_str(), req->remotePort,
req->transactionId));
break;
case UDPT_ACT_ANNOUNCE:
A2_LOG_INFO(fmt("UDPT resend ANNOUNCE to %s:%u transaction_id=%u, "
"connection_id=%"PRId64", event=%s, infohash=%s",
req->remoteAddr.c_str(), req->remotePort,
req->transactionId, req->connectionId,
getUDPTrackerEventStr(req->event),
util::toHex(req->infohash).c_str()));
break;
default:
// unreachable
assert(0);
}
++req->failCount;
dest.push_back(req);
return true;
} else {
return false;
}
} else {
if(t >= 60) {
switch(req->action) {
case UDPT_ACT_CONNECT:
A2_LOG_INFO(fmt("UDPT timeout CONNECT to %s:%u transaction_id=%u",
req->remoteAddr.c_str(), req->remotePort,
req->transactionId));
client->failConnect(req->remoteAddr, req->remotePort,
UDPT_ERR_TIMEOUT);
break;
case UDPT_ACT_ANNOUNCE:
A2_LOG_INFO(fmt("UDPT timeout ANNOUNCE to %s:%u transaction_id=%u, "
"connection_id=%"PRId64", event=%s, infohash=%s",
req->remoteAddr.c_str(), req->remotePort,
req->transactionId, req->connectionId,
getUDPTrackerEventStr(req->event),
util::toHex(req->infohash).c_str()));
break;
default:
// unreachable
assert(0);
}
++req->failCount;
req->state = UDPT_STA_COMPLETE;
req->error = UDPT_ERR_TIMEOUT;
return true;
} else {
return false;
}
}
}
std::vector<SharedHandle<UDPTrackerRequest> >& dest;
UDPTrackerClient* client;
const Timer& now;
TimeoutCheck(std::vector<SharedHandle<UDPTrackerRequest> >& dest,
UDPTrackerClient* client,
const Timer& now)
: dest(dest), client(client), now(now)
{}
};
} // namespace
void UDPTrackerClient::handleTimeout(const Timer& now)
{
std::vector<SharedHandle<UDPTrackerRequest> > dest;
inflightRequests_.erase(std::remove_if(inflightRequests_.begin(),
inflightRequests_.end(),
TimeoutCheck(dest, this, now)),
inflightRequests_.end());
pendingRequests_.insert(pendingRequests_.begin(), dest.begin(), dest.end());
}
SharedHandle<UDPTrackerRequest> UDPTrackerClient::findInflightRequest
(const std::string& remoteAddr, uint16_t remotePort, int32_t transactionId,
bool remove)
{
SharedHandle<UDPTrackerRequest> res;
for(std::deque<SharedHandle<UDPTrackerRequest> >::iterator i =
inflightRequests_.begin(), eoi = inflightRequests_.end(); i != eoi;
++i) {
if((*i)->remoteAddr == remoteAddr && (*i)->remotePort == remotePort &&
(*i)->transactionId == transactionId) {
res = *i;
if(remove) {
inflightRequests_.erase(i);
}
break;
}
}
return res;
}
UDPTrackerConnection* UDPTrackerClient::getConnectionId
(const std::string& remoteAddr, uint16_t remotePort, const Timer& now)
{
std::map<std::pair<std::string, uint16_t>,
UDPTrackerConnection>::iterator i =
connectionIdCache_.find(std::make_pair(remoteAddr, remotePort));
if(i == connectionIdCache_.end()) {
return 0;
}
if((*i).second.state == UDPT_CST_CONNECTED &&
(*i).second.lastUpdated.difference(now) > 60) {
connectionIdCache_.erase(i);
return 0;
} else {
return &(*i).second;
}
}
namespace {
struct FailConnectDelete {
bool operator()(const SharedHandle<UDPTrackerRequest>& req) const
{
if(req->action == UDPT_ACT_ANNOUNCE &&
req->remoteAddr == remoteAddr && req->remotePort == remotePort) {
A2_LOG_INFO(fmt("Force fail infohash=%s",
util::toHex(req->infohash).c_str()));
req->state = UDPT_STA_COMPLETE;
req->error = error;
return true;
} else {
return false;
}
}
std::string remoteAddr;
uint16_t remotePort;
int error;
FailConnectDelete(const std::string& remoteAddr, uint16_t remotePort,
int error)
: remoteAddr(remoteAddr), remotePort(remotePort), error(error)
{}
};
} // namespace
void UDPTrackerClient::failConnect(const std::string& remoteAddr,
uint16_t remotePort, int error)
{
connectionIdCache_.erase(std::make_pair(remoteAddr, remotePort));
// Fail all requests which are waiting for connecion ID of the host.
connectRequests_.erase(std::remove_if(connectRequests_.begin(),
connectRequests_.end(),
FailConnectDelete
(remoteAddr, remotePort, error)),
connectRequests_.end());
pendingRequests_.erase(std::remove_if(pendingRequests_.begin(),
pendingRequests_.end(),
FailConnectDelete
(remoteAddr, remotePort, error)),
pendingRequests_.end());
}
void UDPTrackerClient::failAll()
{
int error = UDPT_ERR_SHUTDOWN;
failRequest(inflightRequests_.begin(), inflightRequests_.end(), error);
failRequest(pendingRequests_.begin(), pendingRequests_.end(), error);
failRequest(connectRequests_.begin(), connectRequests_.end(), error);
}
void UDPTrackerClient::increaseWatchers()
{
++numWatchers_;
}
void UDPTrackerClient::decreaseWatchers()
{
--numWatchers_;
}
ssize_t createUDPTrackerConnect
(unsigned char* data, size_t length,
std::string& remoteAddr, uint16_t& remotePort,
const SharedHandle<UDPTrackerRequest>& req)
{
assert(length >= 16);
remoteAddr = req->remoteAddr;
remotePort = req->remotePort;
bittorrent::setLLIntParam(data, UDPT_INITIAL_CONNECTION_ID);
bittorrent::setIntParam(data+8, req->action);
bittorrent::setIntParam(data+12, req->transactionId);
return 16;
}
ssize_t createUDPTrackerAnnounce
(unsigned char* data, size_t length,
std::string& remoteAddr, uint16_t& remotePort,
const SharedHandle<UDPTrackerRequest>& req)
{
assert(length >= 100);
remoteAddr = req->remoteAddr;
remotePort = req->remotePort;
bittorrent::setLLIntParam(data, req->connectionId);
bittorrent::setIntParam(data+8, req->action);
bittorrent::setIntParam(data+12, req->transactionId);
memcpy(data+16, req->infohash.c_str(), req->infohash.size());
memcpy(data+36, req->peerId.c_str(), req->peerId.size());
bittorrent::setLLIntParam(data+56, req->downloaded);
bittorrent::setLLIntParam(data+64, req->left);
bittorrent::setLLIntParam(data+72, req->uploaded);
bittorrent::setIntParam(data+80, req->event);
// ip is already network-byte order
memcpy(data+84, &req->ip, sizeof(req->ip));
bittorrent::setIntParam(data+88, req->key);
bittorrent::setIntParam(data+92, req->numWant);
bittorrent::setShortIntParam(data+96, req->port);
// extensions is always 0
bittorrent::setShortIntParam(data+98, 0);
return 100;
}
const char* getUDPTrackerActionStr(int action)
{
switch(action) {
case UDPT_ACT_CONNECT:
return "CONNECT";
case UDPT_ACT_ANNOUNCE:
return "ANNOUNCE";
case UDPT_ACT_ERROR:
return "ERROR";
default:
return "(unknown)";
}
}
const char* getUDPTrackerEventStr(int event)
{
switch(event) {
case UDPT_EVT_NONE:
return "NONE";
case UDPT_EVT_COMPLETED:
return "COMPLETED";
case UDPT_EVT_STARTED:
return "STARTED";
case UDPT_EVT_STOPPED:
return "STOPPED";
default:
return "(unknown)";
}
}
} // namespace aria2

171
src/UDPTrackerClient.h Normal file
View file

@ -0,0 +1,171 @@
/* <!-- copyright */
/*
* aria2 - The high speed download utility
*
* Copyright (C) 2013 Tatsuhiro Tsujikawa
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*
* In addition, as a special exception, the copyright holders give
* permission to link the code of portions of this program with the
* OpenSSL library under certain conditions as described in each
* individual source file, and distribute linked combinations
* including the two.
* You must obey the GNU General Public License in all respects
* for all of the code used other than OpenSSL. If you modify
* file(s) with this exception, you may extend this exception to your
* version of the file(s), but you are not obligated to do so. If you
* do not wish to do so, delete this exception statement from your
* version. If you delete this exception statement from all source
* files in the program, then also delete it here.
*/
/* copyright --> */
#ifndef D_UDP_TRACKER_CLIENT_H
#define D_UDP_TRACKER_CLIENT_H
#include "common.h"
#include <string>
#include <deque>
#include <map>
#include "SharedHandle.h"
#include "TimerA2.h"
namespace aria2 {
#define UDPT_INITIAL_CONNECTION_ID 0x41727101980LL
class UDPTrackerRequest;
enum UDPTrackerConnectionState {
UDPT_CST_CONNECTING,
UDPT_CST_CONNECTED
};
struct UDPTrackerConnection {
int state;
int64_t connectionId;
Timer lastUpdated;
UDPTrackerConnection()
: state(UDPT_CST_CONNECTING),
connectionId(UDPT_INITIAL_CONNECTION_ID),
lastUpdated(0)
{}
UDPTrackerConnection(int state, int64_t connectionId,
const Timer& lastUpdated)
: state(state),
connectionId(connectionId),
lastUpdated(lastUpdated)
{}
};
class UDPTrackerClient {
public:
UDPTrackerClient();
~UDPTrackerClient();
int receiveReply
(const unsigned char* data, size_t length, const std::string& remoteAddr,
uint16_t remotePort, const Timer& now);
// Creates data frame for the next pending request. This function
// always processes first entry of pendingRequests_. If the data is
// sent successfully, call requestSent(). Otherwise call
// requestFail().
ssize_t createRequest
(unsigned char* data, size_t length, std::string& remoteAddr,
uint16_t& remotePort, const Timer& now);
// Tells this object that first entry of pendingRequests_ is
// successfully sent.
void requestSent(const Timer& now);
// Tells this object that first entry of pendingRequests_ is not
// successfully sent. The |error| should indicate error situation.
void requestFail(int error);
void addRequest(const SharedHandle<UDPTrackerRequest>& req);
// Handles timeout for inflight requests.
void handleTimeout(const Timer& now);
const std::deque<SharedHandle<UDPTrackerRequest> >&
getPendingRequests() const
{
return pendingRequests_;
}
const std::deque<SharedHandle<UDPTrackerRequest> >&
getConnectRequests() const
{
return connectRequests_;
}
const std::deque<SharedHandle<UDPTrackerRequest> >&
getInflightRequests() const
{
return inflightRequests_;
}
bool noRequest() const
{
return pendingRequests_.empty() && connectRequests_.empty() &&
getInflightRequests().empty();
}
// Makes all contained requests fail.
void failAll();
int getNumWatchers() const
{
return numWatchers_;
}
void increaseWatchers();
void decreaseWatchers();
// Actually private function, but made public, to be used by unnamed
// function.
void failConnect(const std::string& remoteAddr, uint16_t remotePort,
int error);
private:
SharedHandle<UDPTrackerRequest> findInflightRequest
(const std::string& remoteAddr, uint16_t remotePort, int32_t transactionId,
bool remove);
UDPTrackerConnection* getConnectionId
(const std::string& remoteAddr, uint16_t remotePort, const Timer& now);
std::map<std::pair<std::string, uint16_t>,
UDPTrackerConnection> connectionIdCache_;
std::deque<SharedHandle<UDPTrackerRequest> > inflightRequests_;
std::deque<SharedHandle<UDPTrackerRequest> > pendingRequests_;
std::deque<SharedHandle<UDPTrackerRequest> > connectRequests_;
int numWatchers_;
};
ssize_t createUDPTrackerConnect
(unsigned char* data, size_t length, std::string& remoteAddr,
uint16_t& remotePort, const SharedHandle<UDPTrackerRequest>& req);
ssize_t createUDPTrackerAnnounce
(unsigned char* data, size_t length, std::string& remoteAddr,
uint16_t& remotePort, const SharedHandle<UDPTrackerRequest>& req);
const char* getUDPTrackerActionStr(int action);
const char* getUDPTrackerEventStr(int event);
} // namespace aria2
#endif // D_UDP_TRACKER_CLIENT_H

51
src/UDPTrackerRequest.cc Normal file
View file

@ -0,0 +1,51 @@
/* <!-- copyright */
/*
* aria2 - The high speed download utility
*
* Copyright (C) 2013 Tatsuhiro Tsujikawa
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*
* In addition, as a special exception, the copyright holders give
* permission to link the code of portions of this program with the
* OpenSSL library under certain conditions as described in each
* individual source file, and distribute linked combinations
* including the two.
* You must obey the GNU General Public License in all respects
* for all of the code used other than OpenSSL. If you modify
* file(s) with this exception, you may extend this exception to your
* version of the file(s), but you are not obligated to do so. If you
* do not wish to do so, delete this exception statement from your
* version. If you delete this exception statement from all source
* files in the program, then also delete it here.
*/
/* copyright --> */
#include "UDPTrackerRequest.h"
namespace aria2 {
UDPTrackerReply::UDPTrackerReply()
: action(0), transactionId(0), interval(0), leechers(0), seeders(0)
{}
UDPTrackerRequest::UDPTrackerRequest()
: remotePort(0), action(UDPT_ACT_CONNECT), transactionId(0), downloaded(0),
left(0), uploaded(0), event(UDPT_EVT_NONE), ip(0), key(0), numWant(0),
port(0), extensions(0), state(UDPT_STA_PENDING), error(UDPT_ERR_SUCCESS),
dispatched(0),
failCount(0)
{}
} // namespace aria2

112
src/UDPTrackerRequest.h Normal file
View file

@ -0,0 +1,112 @@
/* <!-- copyright */
/*
* aria2 - The high speed download utility
*
* Copyright (C) 2013 Tatsuhiro Tsujikawa
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*
* In addition, as a special exception, the copyright holders give
* permission to link the code of portions of this program with the
* OpenSSL library under certain conditions as described in each
* individual source file, and distribute linked combinations
* including the two.
* You must obey the GNU General Public License in all respects
* for all of the code used other than OpenSSL. If you modify
* file(s) with this exception, you may extend this exception to your
* version of the file(s), but you are not obligated to do so. If you
* do not wish to do so, delete this exception statement from your
* version. If you delete this exception statement from all source
* files in the program, then also delete it here.
*/
/* copyright --> */
#ifndef D_UDP_TRACKER_REQUEST_H
#define D_UDP_TRACKER_REQUEST_H
#include "common.h"
#include <string>
#include <vector>
#include "SharedHandle.h"
#include "TimerA2.h"
namespace aria2 {
enum UDPTrackerAction {
UDPT_ACT_CONNECT = 0,
UDPT_ACT_ANNOUNCE = 1,
UDPT_ACT_SCRAPE = 2,
UDPT_ACT_ERROR = 3
};
enum UDPTrackerError {
UDPT_ERR_SUCCESS,
UDPT_ERR_TRACKER,
UDPT_ERR_TIMEOUT,
UDPT_ERR_NETWORK,
UDPT_ERR_SHUTDOWN
};
enum UDPTrackerState {
UDPT_STA_PENDING,
UDPT_STA_COMPLETE
};
enum UDPTrackerEvent {
UDPT_EVT_NONE = 0,
UDPT_EVT_COMPLETED = 1,
UDPT_EVT_STARTED = 2,
UDPT_EVT_STOPPED = 3
};
struct UDPTrackerReply {
int32_t action;
int32_t transactionId;
int32_t interval;
int32_t leechers;
int32_t seeders;
std::vector<std::pair<std::string, uint16_t> > peers;
UDPTrackerReply();
};
struct UDPTrackerRequest {
std::string remoteAddr;
uint16_t remotePort;
int64_t connectionId;
int32_t action;
int32_t transactionId;
std::string infohash;
std::string peerId;
int64_t downloaded;
int64_t left;
int64_t uploaded;
int32_t event;
uint32_t ip;
uint32_t key;
int32_t numWant;
uint16_t port;
uint16_t extensions;
int state;
int error;
Timer dispatched;
int failCount;
SharedHandle<UDPTrackerReply> reply;
UDPTrackerRequest();
};
} // namespace aria2
#endif // D_UDP_TRACKER_REQUEST_H

View file

@ -731,6 +731,13 @@ uint8_t getId(const unsigned char* msg)
return msg[0];
}
uint64_t getLLIntParam(const unsigned char* msg, size_t pos)
{
uint64_t nParam;
memcpy(&nParam, msg+pos, sizeof(nParam));
return ntoh64(nParam);
}
uint32_t getIntParam(const unsigned char* msg, size_t pos)
{
uint32_t nParam;
@ -798,6 +805,12 @@ void checkBitfield
}
}
void setLLIntParam(unsigned char* dest, uint64_t param)
{
uint64_t nParam = hton64(param);
memcpy(dest, &nParam, sizeof(nParam));
}
void setIntParam(unsigned char* dest, uint32_t param)
{
uint32_t nParam = htonl(param);

View file

@ -160,6 +160,11 @@ getInfoHash(const SharedHandle<DownloadContext>& downloadContext);
std::string
getInfoHashString(const SharedHandle<DownloadContext>& downloadContext);
// Returns 8bytes unsigned integer located at offset pos. The integer
// in msg is network byte order. This function converts it into host
// byte order and returns it.
uint64_t getLLIntParam(const unsigned char* msg, size_t pos);
// Returns 4bytes unsigned integer located at offset pos. The integer
// in msg is network byte order. This function converts it into host
// byte order and returns it.
@ -170,6 +175,10 @@ uint32_t getIntParam(const unsigned char* msg, size_t pos);
// byte order and returns it.
uint16_t getShortIntParam(const unsigned char* msg, size_t pos);
// Put param at location pointed by dest. param is converted into
// network byte order.
void setLLIntParam(unsigned char* dest, uint64_t param);
// Put param at location pointed by dest. param is converted into
// network byte order.
void setIntParam(unsigned char* dest, uint32_t param);

View file

@ -11,6 +11,7 @@
#include "BtRuntime.h"
#include "FileEntry.h"
#include "bittorrent_helper.h"
#include "UDPTrackerRequest.h"
namespace aria2 {

View file

@ -18,6 +18,8 @@
#include "DownloadContext.h"
#include "bittorrent_helper.h"
#include "array_fun.h"
#include "UDPTrackerRequest.h"
#include "SocketCore.h"
namespace aria2 {
@ -34,6 +36,7 @@ class DefaultBtAnnounceTest:public CppUnit::TestFixture {
CPPUNIT_TEST(testProcessAnnounceResponse_malformed);
CPPUNIT_TEST(testProcessAnnounceResponse_failureReason);
CPPUNIT_TEST(testProcessAnnounceResponse);
CPPUNIT_TEST(testProcessUDPTrackerResponse);
CPPUNIT_TEST_SUITE_END();
private:
SharedHandle<DownloadContext> dctx_;
@ -87,6 +90,7 @@ public:
void testProcessAnnounceResponse_malformed();
void testProcessAnnounceResponse_failureReason();
void testProcessAnnounceResponse();
void testProcessUDPTrackerResponse();
};
@ -197,24 +201,49 @@ void DefaultBtAnnounceTest::testGetAnnounceUrl()
btAnnounce.setBtRuntime(btRuntime_);
btAnnounce.setRandomizer(SharedHandle<Randomizer>(new FixedNumberRandomizer()));
btAnnounce.setTcpPort(6989);
SharedHandle<UDPTrackerRequest> req;
CPPUNIT_ASSERT_EQUAL(std::string("http://localhost/announce?info_hash=%01%23Eg%89%AB%CD%EF%01%23Eg%89%AB%CD%EF%01%23Eg&peer_id=%2Daria2%2Dultrafastdltl&uploaded=1572864&downloaded=1310720&left=1572864&compact=1&key=fastdltl&numwant=50&no_peer_id=1&port=6989&event=started&supportcrypto=1"), btAnnounce.getAnnounceUrl());
req = btAnnounce.createUDPTrackerRequest("localhost", 80, 6989);
CPPUNIT_ASSERT_EQUAL(std::string("localhost"), req->remoteAddr);
CPPUNIT_ASSERT_EQUAL((uint16_t)80, req->remotePort);
CPPUNIT_ASSERT_EQUAL((int)UDPT_ACT_ANNOUNCE, req->action);
CPPUNIT_ASSERT_EQUAL(bittorrent::getInfoHashString(dctx_),
util::toHex(req->infohash));
CPPUNIT_ASSERT_EQUAL(std::string("-aria2-ultrafastdltl"), req->peerId);
CPPUNIT_ASSERT_EQUAL((int64_t)1310720, req->downloaded);
CPPUNIT_ASSERT_EQUAL((int64_t)1572864, req->left);
CPPUNIT_ASSERT_EQUAL((int64_t)1572864, req->uploaded);
CPPUNIT_ASSERT_EQUAL((int)UDPT_EVT_STARTED, req->event);
CPPUNIT_ASSERT_EQUAL((uint32_t)0, req->ip);
CPPUNIT_ASSERT_EQUAL((int32_t)50, req->numWant);
CPPUNIT_ASSERT_EQUAL((uint16_t)6989, req->port);
CPPUNIT_ASSERT_EQUAL((uint16_t)0, req->extensions);
btAnnounce.announceSuccess();
CPPUNIT_ASSERT_EQUAL(std::string("http://localhost/announce?info_hash=%01%23Eg%89%AB%CD%EF%01%23Eg%89%AB%CD%EF%01%23Eg&peer_id=%2Daria2%2Dultrafastdltl&uploaded=1572864&downloaded=1310720&left=1572864&compact=1&key=fastdltl&numwant=50&no_peer_id=1&port=6989&supportcrypto=1"), btAnnounce.getAnnounceUrl());
req = btAnnounce.createUDPTrackerRequest("localhost", 80, 6989);
CPPUNIT_ASSERT_EQUAL((int)UDPT_ACT_ANNOUNCE, req->action);
CPPUNIT_ASSERT_EQUAL((int)UDPT_EVT_NONE, req->event);
btAnnounce.announceSuccess();
pieceStorage_->setAllDownloadFinished(true);
CPPUNIT_ASSERT_EQUAL(std::string("http://localhost/announce?info_hash=%01%23Eg%89%AB%CD%EF%01%23Eg%89%AB%CD%EF%01%23Eg&peer_id=%2Daria2%2Dultrafastdltl&uploaded=1572864&downloaded=1310720&left=1572864&compact=1&key=fastdltl&numwant=50&no_peer_id=1&port=6989&event=completed&supportcrypto=1"), btAnnounce.getAnnounceUrl());
req = btAnnounce.createUDPTrackerRequest("localhost", 80, 6989);
CPPUNIT_ASSERT_EQUAL((int)UDPT_ACT_ANNOUNCE, req->action);
CPPUNIT_ASSERT_EQUAL((int)UDPT_EVT_COMPLETED, req->event);
btAnnounce.announceSuccess();
btRuntime_->setHalt(true);
CPPUNIT_ASSERT_EQUAL(std::string("http://localhost/announce?info_hash=%01%23Eg%89%AB%CD%EF%01%23Eg%89%AB%CD%EF%01%23Eg&peer_id=%2Daria2%2Dultrafastdltl&uploaded=1572864&downloaded=1310720&left=1572864&compact=1&key=fastdltl&numwant=0&no_peer_id=1&port=6989&event=stopped&supportcrypto=1"), btAnnounce.getAnnounceUrl());
req = btAnnounce.createUDPTrackerRequest("localhost", 80, 6989);
CPPUNIT_ASSERT_EQUAL((int)UDPT_ACT_ANNOUNCE, req->action);
CPPUNIT_ASSERT_EQUAL((int)UDPT_EVT_STOPPED, req->event);
}
void DefaultBtAnnounceTest::testGetAnnounceUrl_withQuery()
@ -262,6 +291,13 @@ void DefaultBtAnnounceTest::testGetAnnounceUrl_externalIP()
"key=fastdltl&numwant=50&no_peer_id=1&port=6989&event=started&"
"supportcrypto=1&ip=192.168.1.1"),
btAnnounce.getAnnounceUrl());
SharedHandle<UDPTrackerRequest> req;
req = btAnnounce.createUDPTrackerRequest("localhost", 80, 6989);
char host[NI_MAXHOST];
int rv = inetNtop(AF_INET, &req->ip, host, sizeof(host));
CPPUNIT_ASSERT_EQUAL(0, rv);
CPPUNIT_ASSERT_EQUAL(std::string("192.168.1.1"), std::string(host));
}
void DefaultBtAnnounceTest::testIsAllAnnounceFailed()
@ -411,4 +447,34 @@ void DefaultBtAnnounceTest::testProcessAnnounceResponse()
peer->getIPAddress());
}
void DefaultBtAnnounceTest::testProcessUDPTrackerResponse()
{
SharedHandle<UDPTrackerRequest> req(new UDPTrackerRequest());
req->action = UDPT_ACT_ANNOUNCE;
SharedHandle<UDPTrackerReply> reply(new UDPTrackerReply());
reply->interval = 1800;
reply->leechers = 200;
reply->seeders = 100;
for(int i = 0; i < 2; ++i) {
reply->peers.push_back(std::make_pair("192.168.0."+util::uitos(i+1),
6890+i));
}
req->reply = reply;
DefaultBtAnnounce an(dctx_, option_);
an.setPeerStorage(peerStorage_);
an.setBtRuntime(btRuntime_);
an.processUDPTrackerResponse(req);
CPPUNIT_ASSERT_EQUAL((time_t)1800, an.getInterval());
CPPUNIT_ASSERT_EQUAL((time_t)1800, an.getMinInterval());
CPPUNIT_ASSERT_EQUAL(100, an.getComplete());
CPPUNIT_ASSERT_EQUAL(200, an.getIncomplete());
CPPUNIT_ASSERT_EQUAL((size_t)2, peerStorage_->getUnusedPeers().size());
for(int i = 0; i < 2; ++i) {
SharedHandle<Peer> peer;
peer = peerStorage_->getUnusedPeers()[i];
CPPUNIT_ASSERT_EQUAL("192.168.0."+util::uitos(i+1), peer->getIPAddress());
CPPUNIT_ASSERT_EQUAL((uint16_t)(6890+i), peer->getPort());
}
}
} // namespace aria2

View file

@ -214,7 +214,8 @@ aria2c_SOURCES += BtAllowedFastMessageTest.cc\
Bencode2Test.cc\
PeerConnectionTest.cc\
ValueBaseBencodeParserTest.cc\
ExtensionMessageRegistryTest.cc
ExtensionMessageRegistryTest.cc\
UDPTrackerClientTest.cc
endif # ENABLE_BITTORRENT
if ENABLE_METALINK

View file

@ -26,6 +26,12 @@ public:
return announceUrl;
}
virtual SharedHandle<UDPTrackerRequest>
createUDPTrackerRequest(const std::string& remoteAddr, uint16_t remotePort,
uint16_t localPort) {
return SharedHandle<UDPTrackerRequest>();
}
void setAnnounceUrl(const std::string& url) {
this->announceUrl = url;
}
@ -45,6 +51,9 @@ public:
virtual void processAnnounceResponse(const unsigned char* trackerResponse,
size_t trackerResponseLength) {}
virtual void processUDPTrackerResponse
(const SharedHandle<UDPTrackerRequest>& req) {}
virtual bool noMoreAnnounce() {
return false;
}

View file

@ -0,0 +1,453 @@
#include "UDPTrackerClient.h"
#include <cstring>
#include <cppunit/extensions/HelperMacros.h>
#include "TestUtil.h"
#include "UDPTrackerRequest.h"
#include "bittorrent_helper.h"
#include "wallclock.h"
namespace aria2 {
class UDPTrackerClientTest:public CppUnit::TestFixture {
CPPUNIT_TEST_SUITE(UDPTrackerClientTest);
CPPUNIT_TEST(testCreateUDPTrackerConnect);
CPPUNIT_TEST(testCreateUDPTrackerAnnounce);
CPPUNIT_TEST(testConnectFollowedByAnnounce);
CPPUNIT_TEST(testRequestFailure);
CPPUNIT_TEST(testTimeout);
CPPUNIT_TEST_SUITE_END();
public:
void setUp()
{
}
void testCreateUDPTrackerConnect();
void testCreateUDPTrackerAnnounce();
void testConnectFollowedByAnnounce();
void testRequestFailure();
void testTimeout();
};
CPPUNIT_TEST_SUITE_REGISTRATION( UDPTrackerClientTest );
namespace {
SharedHandle<UDPTrackerRequest> createAnnounce(const std::string& remoteAddr,
uint16_t remotePort,
int32_t transactionId)
{
SharedHandle<UDPTrackerRequest> req(new UDPTrackerRequest());
req->connectionId = INT64_MAX;
req->action = UDPT_ACT_ANNOUNCE;
req->remoteAddr = remoteAddr;
req->remotePort = remotePort;
req->transactionId = transactionId;
req->infohash = "bittorrent-infohash-";
req->peerId = "bittorrent-peer-id--";
req->downloaded = INT64_MAX - 1;
req->left = INT64_MAX - 2;
req->uploaded = INT64_MAX - 3;
req->event = UDPT_EVT_STARTED;
req->ip = 0;
req->key = 1000000007;
req->numWant = 50;
req->port = 6889;
req->extensions = 0;
return req;
}
} // namespace
namespace {
ssize_t createErrorReply(unsigned char* data, size_t len,
int32_t transactionId, const std::string& errorString)
{
bittorrent::setIntParam(data, UDPT_ACT_ERROR);
bittorrent::setIntParam(data+4, transactionId);
memcpy(data+8, errorString.c_str(), errorString.size());
return 8+errorString.size();
}
} // namespace
namespace {
ssize_t createConnectReply(unsigned char* data, size_t len,
uint64_t connectionId, int32_t transactionId)
{
bittorrent::setIntParam(data, UDPT_ACT_CONNECT);
bittorrent::setIntParam(data+4, transactionId);
bittorrent::setLLIntParam(data+8, connectionId);
return 16;
}
} // namespace
namespace {
ssize_t createAnnounceReply(unsigned char*data, size_t len,
int32_t transactionId, int numPeers = 0)
{
bittorrent::setIntParam(data, UDPT_ACT_ANNOUNCE);
bittorrent::setIntParam(data+4, transactionId);
bittorrent::setIntParam(data+8, 1800);
bittorrent::setIntParam(data+12, 100);
bittorrent::setIntParam(data+16, 256);
for(int i = 0; i < numPeers; ++i) {
bittorrent::packcompact(data+20+6*i, "192.168.0."+util::uitos(i+1),
6990+i);
}
return 20 + 6 * numPeers;
}
} // namespace
void UDPTrackerClientTest::testCreateUDPTrackerConnect()
{
unsigned char data[16];
std::string remoteAddr;
uint16_t remotePort = 0;
SharedHandle<UDPTrackerRequest> req(new UDPTrackerRequest());
req->action = UDPT_ACT_CONNECT;
req->remoteAddr = "192.168.0.1";
req->remotePort = 6991;
req->transactionId = 1000000009;
ssize_t rv = createUDPTrackerConnect(data, sizeof(data), remoteAddr,
remotePort, req);
CPPUNIT_ASSERT_EQUAL((ssize_t)16, rv);
CPPUNIT_ASSERT_EQUAL(req->remoteAddr, remoteAddr);
CPPUNIT_ASSERT_EQUAL(req->remotePort, remotePort);
CPPUNIT_ASSERT_EQUAL((int64_t)UDPT_INITIAL_CONNECTION_ID,
(int64_t)bittorrent::getLLIntParam(data, 0));
CPPUNIT_ASSERT_EQUAL((int)req->action, (int)bittorrent::getIntParam(data, 8));
CPPUNIT_ASSERT_EQUAL(req->transactionId,
(int32_t)bittorrent::getIntParam(data, 12));
}
void UDPTrackerClientTest::testCreateUDPTrackerAnnounce()
{
unsigned char data[100];
std::string remoteAddr;
uint16_t remotePort = 0;
SharedHandle<UDPTrackerRequest> req(createAnnounce("192.168.0.1", 6991,
1000000009));
ssize_t rv = createUDPTrackerAnnounce(data, sizeof(data), remoteAddr,
remotePort, req);
CPPUNIT_ASSERT_EQUAL((ssize_t)100, rv);
CPPUNIT_ASSERT_EQUAL(req->connectionId,
(int64_t)bittorrent::getLLIntParam(data, 0));
CPPUNIT_ASSERT_EQUAL((int)req->action, (int)bittorrent::getIntParam(data, 8));
CPPUNIT_ASSERT_EQUAL(req->transactionId,
(int32_t)bittorrent::getIntParam(data, 12));
CPPUNIT_ASSERT_EQUAL(req->infohash, std::string(&data[16], &data[36]));
CPPUNIT_ASSERT_EQUAL(req->peerId, std::string(&data[36], &data[56]));
CPPUNIT_ASSERT_EQUAL(req->downloaded,
(int64_t)bittorrent::getLLIntParam(data, 56));
CPPUNIT_ASSERT_EQUAL(req->left,
(int64_t)bittorrent::getLLIntParam(data, 64));
CPPUNIT_ASSERT_EQUAL(req->uploaded,
(int64_t)bittorrent::getLLIntParam(data, 72));
CPPUNIT_ASSERT_EQUAL(req->event, (int32_t)bittorrent::getIntParam(data, 80));
CPPUNIT_ASSERT_EQUAL(req->ip, bittorrent::getIntParam(data, 84));
CPPUNIT_ASSERT_EQUAL(req->key, bittorrent::getIntParam(data, 88));
CPPUNIT_ASSERT_EQUAL(req->numWant,
(int32_t)bittorrent::getIntParam(data, 92));
CPPUNIT_ASSERT_EQUAL(req->port, bittorrent::getShortIntParam(data, 96));
CPPUNIT_ASSERT_EQUAL(req->extensions, bittorrent::getShortIntParam(data, 98));
}
void UDPTrackerClientTest::testConnectFollowedByAnnounce()
{
ssize_t rv;
UDPTrackerClient tr;
unsigned char data[100];
std::string remoteAddr;
uint16_t remotePort;
Timer now;
SharedHandle<UDPTrackerRequest> req1(createAnnounce("192.168.0.1", 6991, 0));
SharedHandle<UDPTrackerRequest> req2(createAnnounce("192.168.0.1", 6991, 0));
req2->infohash = "bittorrent-infohash2";
tr.addRequest(req1);
tr.addRequest(req2);
CPPUNIT_ASSERT_EQUAL((size_t)2, tr.getPendingRequests().size());
rv = tr.createRequest(data, sizeof(data), remoteAddr, remotePort, now);
// CONNECT request was inserted
CPPUNIT_ASSERT_EQUAL((size_t)3, tr.getPendingRequests().size());
CPPUNIT_ASSERT_EQUAL((ssize_t)16, rv);
CPPUNIT_ASSERT_EQUAL(req1->remoteAddr, remoteAddr);
CPPUNIT_ASSERT_EQUAL(req1->remotePort, remotePort);
CPPUNIT_ASSERT_EQUAL((int64_t)UDPT_INITIAL_CONNECTION_ID,
(int64_t)bittorrent::getLLIntParam(data, 0));
int32_t transactionId = bittorrent::getIntParam(data, 12);
rv = tr.createRequest(data, sizeof(data), remoteAddr, remotePort, now);
// Duplicate CONNECT request was not inserted
CPPUNIT_ASSERT_EQUAL((size_t)3, tr.getPendingRequests().size());
CPPUNIT_ASSERT_EQUAL((ssize_t)16, rv);
tr.requestSent(now);
// CONNECT request was moved to inflight
CPPUNIT_ASSERT_EQUAL((size_t)2, tr.getPendingRequests().size());
rv = tr.createRequest(data, sizeof(data), remoteAddr, remotePort, now);
// Now all pending requests were moved to connect
CPPUNIT_ASSERT_EQUAL((ssize_t)-1, rv);
CPPUNIT_ASSERT(tr.getPendingRequests().empty());
int64_t connectionId = 12345;
rv = createConnectReply(data, sizeof(data), connectionId, transactionId);
rv = tr.receiveReply(data, rv, req1->remoteAddr, req1->remotePort, now);
CPPUNIT_ASSERT_EQUAL(0, (int)rv);
// Now 2 requests get back to pending
CPPUNIT_ASSERT_EQUAL((size_t)2, tr.getPendingRequests().size());
rv = tr.createRequest(data, sizeof(data), remoteAddr, remotePort, now);
// Creates announce for req1
CPPUNIT_ASSERT_EQUAL((ssize_t)100, rv);
CPPUNIT_ASSERT_EQUAL((size_t)2, tr.getPendingRequests().size());
CPPUNIT_ASSERT_EQUAL(connectionId,
(int64_t)bittorrent::getLLIntParam(data, 0));
CPPUNIT_ASSERT_EQUAL((int)UDPT_ACT_ANNOUNCE,
(int)bittorrent::getIntParam(data, 8));
CPPUNIT_ASSERT_EQUAL(req1->infohash,
std::string(&data[16], &data[36]));
rv = tr.createRequest(data, sizeof(data), remoteAddr, remotePort, now);
// Don't duplicate same request data
CPPUNIT_ASSERT_EQUAL((ssize_t)100, rv);
CPPUNIT_ASSERT_EQUAL((size_t)2, tr.getPendingRequests().size());
int32_t transactionId1 = bittorrent::getIntParam(data, 12);
tr.requestSent(now);
CPPUNIT_ASSERT_EQUAL((size_t)1, tr.getPendingRequests().size());
rv = tr.createRequest(data, sizeof(data), remoteAddr, remotePort, now);
int32_t transactionId2 = bittorrent::getIntParam(data, 12);
// Creates announce for req2
CPPUNIT_ASSERT_EQUAL((ssize_t)100, rv);
CPPUNIT_ASSERT_EQUAL((size_t)1, tr.getPendingRequests().size());
CPPUNIT_ASSERT_EQUAL(connectionId,
(int64_t)bittorrent::getLLIntParam(data, 0));
CPPUNIT_ASSERT_EQUAL((int)UDPT_ACT_ANNOUNCE,
(int)bittorrent::getIntParam(data, 8));
CPPUNIT_ASSERT_EQUAL(req2->infohash,
std::string(&data[16], &data[36]));
tr.requestSent(now);
// Now all requests are inflight
CPPUNIT_ASSERT_EQUAL((size_t)0, tr.getPendingRequests().size());
// Reply for req2
rv = createAnnounceReply(data, sizeof(data), transactionId2);
rv = tr.receiveReply(data, rv, req2->remoteAddr, req2->remotePort, now);
CPPUNIT_ASSERT_EQUAL(0, (int)rv);
CPPUNIT_ASSERT_EQUAL((int)UDPT_STA_COMPLETE, req2->state);
CPPUNIT_ASSERT_EQUAL((int)UDPT_ERR_SUCCESS, req2->error);
// Reply for req1
rv = createAnnounceReply(data, sizeof(data), transactionId1, 2);
rv = tr.receiveReply(data, rv, req1->remoteAddr, req1->remotePort, now);
CPPUNIT_ASSERT_EQUAL(0, (int)rv);
CPPUNIT_ASSERT_EQUAL((int)UDPT_STA_COMPLETE, req1->state);
CPPUNIT_ASSERT_EQUAL((int)UDPT_ERR_SUCCESS, req1->error);
CPPUNIT_ASSERT_EQUAL((size_t)2, req1->reply->peers.size());
for(int i = 0; i < 2; ++i) {
CPPUNIT_ASSERT_EQUAL("192.168.0."+util::uitos(i+1),
req1->reply->peers[i].first);
CPPUNIT_ASSERT_EQUAL((uint16_t)(6990+i), req1->reply->peers[i].second);
}
// Since we have connection ID, next announce request can be sent
// immediately
SharedHandle<UDPTrackerRequest> req3(createAnnounce("192.168.0.1", 6991, 0));
req3->infohash = "bittorrent-infohash3";
tr.addRequest(req3);
rv = tr.createRequest(data, sizeof(data), remoteAddr, remotePort, now);
CPPUNIT_ASSERT_EQUAL((ssize_t)100, rv);
CPPUNIT_ASSERT_EQUAL(req3->infohash,
std::string(&data[16], &data[36]));
tr.requestSent(now);
SharedHandle<UDPTrackerRequest> req4(createAnnounce("192.168.0.1", 6991, 0));
req4->infohash = "bittorrent-infohash4";
tr.addRequest(req4);
Timer future = now;
future.advance(3600);
rv = tr.createRequest(data, sizeof(data), remoteAddr, remotePort,
future);
// connection ID is stale because of the timeout
CPPUNIT_ASSERT_EQUAL((ssize_t)16, rv);
CPPUNIT_ASSERT_EQUAL((int64_t)UDPT_INITIAL_CONNECTION_ID,
(int64_t)bittorrent::getLLIntParam(data, 0));
}
void UDPTrackerClientTest::testRequestFailure()
{
ssize_t rv;
UDPTrackerClient tr;
unsigned char data[100];
std::string remoteAddr;
uint16_t remotePort;
Timer now;
{
SharedHandle<UDPTrackerRequest> req1
(createAnnounce("192.168.0.1", 6991, 0));
SharedHandle<UDPTrackerRequest> req2
(createAnnounce("192.168.0.1", 6991, 0));
tr.addRequest(req1);
tr.addRequest(req2);
rv = tr.createRequest(data, sizeof(data), remoteAddr, remotePort, now);
CPPUNIT_ASSERT_EQUAL((int)UDPT_ACT_CONNECT,
(int)bittorrent::getIntParam(data, 8));
tr.requestFail(UDPT_ERR_NETWORK);
CPPUNIT_ASSERT_EQUAL((int)UDPT_STA_COMPLETE, req1->state);
CPPUNIT_ASSERT_EQUAL((int)UDPT_ERR_NETWORK, req1->error);
CPPUNIT_ASSERT_EQUAL((int)UDPT_STA_COMPLETE, req2->state);
CPPUNIT_ASSERT_EQUAL((int)UDPT_ERR_NETWORK, req2->error);
CPPUNIT_ASSERT(tr.getConnectRequests().empty());
CPPUNIT_ASSERT(tr.getPendingRequests().empty());
CPPUNIT_ASSERT(tr.getInflightRequests().empty());
}
{
SharedHandle<UDPTrackerRequest> req1
(createAnnounce("192.168.0.1", 6991, 0));
SharedHandle<UDPTrackerRequest> req2
(createAnnounce("192.168.0.1", 6991, 0));
tr.addRequest(req1);
tr.addRequest(req2);
rv = tr.createRequest(data, sizeof(data), remoteAddr, remotePort, now);
CPPUNIT_ASSERT_EQUAL((int)UDPT_ACT_CONNECT,
(int)bittorrent::getIntParam(data, 8));
int32_t transactionId = bittorrent::getIntParam(data, 12);
tr.requestSent(now);
rv = createErrorReply(data, sizeof(data), transactionId, "error");
rv = tr.receiveReply(data, rv, req1->remoteAddr, req1->remotePort, now);
CPPUNIT_ASSERT_EQUAL((int)UDPT_STA_COMPLETE, req1->state);
CPPUNIT_ASSERT_EQUAL((int)UDPT_ERR_TRACKER, req1->error);
CPPUNIT_ASSERT_EQUAL((int)UDPT_STA_COMPLETE, req2->state);
CPPUNIT_ASSERT_EQUAL((int)UDPT_ERR_TRACKER, req2->error);
CPPUNIT_ASSERT(tr.getConnectRequests().empty());
CPPUNIT_ASSERT(tr.getPendingRequests().empty());
CPPUNIT_ASSERT(tr.getInflightRequests().empty());
}
{
SharedHandle<UDPTrackerRequest> req1
(createAnnounce("192.168.0.1", 6991, 0));
tr.addRequest(req1);
rv = tr.createRequest(data, sizeof(data), remoteAddr, remotePort, now);
CPPUNIT_ASSERT_EQUAL((ssize_t)16, rv);
CPPUNIT_ASSERT_EQUAL((int)UDPT_ACT_CONNECT,
(int)bittorrent::getIntParam(data, 8));
int32_t transactionId = bittorrent::getIntParam(data, 12);
tr.requestSent(now);
int64_t connectionId = 12345;
rv = createConnectReply(data, sizeof(data), connectionId, transactionId);
rv = tr.receiveReply(data, rv, req1->remoteAddr, req1->remotePort, now);
CPPUNIT_ASSERT_EQUAL(0, (int)rv);
rv = tr.createRequest(data, sizeof(data), remoteAddr, remotePort, now);
CPPUNIT_ASSERT_EQUAL((int)UDPT_ACT_ANNOUNCE,
(int)bittorrent::getIntParam(data, 8));
transactionId = bittorrent::getIntParam(data, 12);
tr.requestSent(now);
rv = createErrorReply(data, sizeof(data), transactionId, "announce error");
rv = tr.receiveReply(data, rv, req1->remoteAddr, req1->remotePort, now);
CPPUNIT_ASSERT_EQUAL((int)UDPT_STA_COMPLETE, req1->state);
CPPUNIT_ASSERT_EQUAL((int)UDPT_ERR_TRACKER, req1->error);
CPPUNIT_ASSERT(tr.getConnectRequests().empty());
CPPUNIT_ASSERT(tr.getPendingRequests().empty());
CPPUNIT_ASSERT(tr.getInflightRequests().empty());
}
}
void UDPTrackerClientTest::testTimeout()
{
ssize_t rv;
unsigned char data[100];
std::string remoteAddr;
uint16_t remotePort;
Timer now;
UDPTrackerClient tr;
{
SharedHandle<UDPTrackerRequest> req1
(createAnnounce("192.168.0.1", 6991, 0));
SharedHandle<UDPTrackerRequest> req2
(createAnnounce("192.168.0.1", 6991, 0));
tr.addRequest(req1);
tr.addRequest(req2);
rv = tr.createRequest(data, sizeof(data), remoteAddr, remotePort, now);
CPPUNIT_ASSERT_EQUAL((int)UDPT_ACT_CONNECT,
(int)bittorrent::getIntParam(data, 8));
tr.requestSent(now);
now.advance(20);
// 15 seconds 1st stage timeout passed
tr.handleTimeout(now);
CPPUNIT_ASSERT(tr.getConnectRequests().empty());
CPPUNIT_ASSERT_EQUAL((size_t)3, tr.getPendingRequests().size());
CPPUNIT_ASSERT(tr.getInflightRequests().empty());
rv = tr.createRequest(data, sizeof(data), remoteAddr, remotePort, now);
// CONNECT request was inserted
CPPUNIT_ASSERT_EQUAL((int)UDPT_ACT_CONNECT,
(int)bittorrent::getIntParam(data, 8));
tr.requestSent(now);
now.advance(65);
// 60 seconds 2nd stage timeout passed
tr.handleTimeout(now);
CPPUNIT_ASSERT(tr.getConnectRequests().empty());
CPPUNIT_ASSERT(tr.getPendingRequests().empty());
CPPUNIT_ASSERT(tr.getInflightRequests().empty());
CPPUNIT_ASSERT_EQUAL((int)UDPT_STA_COMPLETE, req1->state);
CPPUNIT_ASSERT_EQUAL((int)UDPT_ERR_TIMEOUT, req1->error);
CPPUNIT_ASSERT_EQUAL((int)UDPT_STA_COMPLETE, req2->state);
CPPUNIT_ASSERT_EQUAL((int)UDPT_ERR_TIMEOUT, req2->error);
}
{
SharedHandle<UDPTrackerRequest> req1
(createAnnounce("192.168.0.1", 6991, 0));
tr.addRequest(req1);
rv = tr.createRequest(data, sizeof(data), remoteAddr, remotePort, now);
CPPUNIT_ASSERT_EQUAL((ssize_t)16, rv);
CPPUNIT_ASSERT_EQUAL((int)UDPT_ACT_CONNECT,
(int)bittorrent::getIntParam(data, 8));
int32_t transactionId = bittorrent::getIntParam(data, 12);
tr.requestSent(now);
int64_t connectionId = 12345;
rv = createConnectReply(data, sizeof(data), connectionId, transactionId);
rv = tr.receiveReply(data, rv, req1->remoteAddr, req1->remotePort, now);
CPPUNIT_ASSERT_EQUAL(0, (int)rv);
rv = tr.createRequest(data, sizeof(data), remoteAddr, remotePort, now);
CPPUNIT_ASSERT_EQUAL((int)UDPT_ACT_ANNOUNCE,
(int)bittorrent::getIntParam(data, 8));
tr.requestSent(now);
now.advance(20);
// 15 seconds 1st stage timeout passed
tr.handleTimeout(now);
CPPUNIT_ASSERT(tr.getConnectRequests().empty());
CPPUNIT_ASSERT_EQUAL((size_t)1, tr.getPendingRequests().size());
CPPUNIT_ASSERT(tr.getInflightRequests().empty());
rv = tr.createRequest(data, sizeof(data), remoteAddr, remotePort, now);
CPPUNIT_ASSERT_EQUAL((int)UDPT_ACT_ANNOUNCE,
(int)bittorrent::getIntParam(data, 8));
tr.requestSent(now);
now.advance(65);
// 60 seconds 2nd stage timeout passed
tr.handleTimeout(now);
CPPUNIT_ASSERT(tr.getConnectRequests().empty());
CPPUNIT_ASSERT(tr.getPendingRequests().empty());
CPPUNIT_ASSERT(tr.getInflightRequests().empty());
CPPUNIT_ASSERT_EQUAL((int)UDPT_STA_COMPLETE, req1->state);
CPPUNIT_ASSERT_EQUAL((int)UDPT_ERR_TIMEOUT, req1->error);
}
}
} // namespace aria2