diff --git a/src/LpdMessageReceiver.cc b/src/LpdMessageReceiver.cc index 0d344e4e..787b76f9 100644 --- a/src/LpdMessageReceiver.cc +++ b/src/LpdMessageReceiver.cc @@ -80,21 +80,37 @@ bool LpdMessageReceiver::init(const std::string& localAddr) SharedHandle LpdMessageReceiver::receiveMessage() { SharedHandle msg; - try { + while(1) { unsigned char buf[200]; std::pair peerAddr; - ssize_t length = socket_->readDataFrom(buf, sizeof(buf), peerAddr); - if(length == 0) { + ssize_t length; + try { + length = socket_->readDataFrom(buf, sizeof(buf), peerAddr); + if(length == 0) { + return msg; + } + } catch(RecoverableException& e) { + A2_LOG_INFO_EX("Failed to receive LPD message.", e); return msg; } HttpHeaderProcessor proc(HttpHeaderProcessor::SERVER_PARSER); - if(!proc.parse(buf, length)) { - msg.reset(new LpdMessage()); - return msg; + try { + if(!proc.parse(buf, length)) { + // UDP packet must contain whole HTTP header block. + continue; + } + } catch(RecoverableException& e) { + A2_LOG_INFO_EX("Failed to parse LPD message.", e); + continue; } const SharedHandle& header = proc.getResult(); const std::string& infoHashString = header->find(HttpHeader::INFOHASH); - uint16_t port = header->findAsInt(HttpHeader::PORT); + uint32_t port = 0; + if(!util::parseUIntNoThrow(port, header->find(HttpHeader::PORT)) || + port > UINT16_MAX || port == 0) { + A2_LOG_INFO(fmt("Bad LPD port=%u", port)); + continue; + } A2_LOG_INFO(fmt("LPD message received infohash=%s, port=%u from %s", infoHashString.c_str(), port, @@ -102,11 +118,10 @@ SharedHandle LpdMessageReceiver::receiveMessage() std::string infoHash; if(infoHashString.size() != 40 || (infoHash = util::fromHex(infoHashString.begin(), - infoHashString.end())).empty() || - port == 0) { - A2_LOG_INFO(fmt("LPD bad request. infohash=%s", infoHashString.c_str())); - msg.reset(new LpdMessage()); - return msg; + infoHashString.end())).empty()) { + A2_LOG_INFO(fmt("LPD bad request. infohash=%s", + infoHashString.c_str())); + continue; } SharedHandle peer(new Peer(peerAddr.first, port, false)); if(util::inPrivateAddress(peerAddr.first)) { @@ -114,10 +129,6 @@ SharedHandle LpdMessageReceiver::receiveMessage() } msg.reset(new LpdMessage(peer, infoHash)); return msg; - } catch(RecoverableException& e) { - A2_LOG_INFO_EX("Failed to receive LPD message.", e); - msg.reset(new LpdMessage()); - return msg; } } diff --git a/src/LpdReceiveMessageCommand.cc b/src/LpdReceiveMessageCommand.cc index 0a7c507b..81ad11cc 100644 --- a/src/LpdReceiveMessageCommand.cc +++ b/src/LpdReceiveMessageCommand.cc @@ -77,10 +77,6 @@ bool LpdReceiveMessageCommand::execute() if(!m) { break; } - if(!m->peer) { - // bad message - continue; - } SharedHandle reg = e_->getBtRegistry(); SharedHandle dctx = reg->getDownloadContext(m->infoHash); if(!dctx) { diff --git a/test/LpdMessageReceiverTest.cc b/test/LpdMessageReceiverTest.cc index ed42bfb2..ad807b26 100644 --- a/test/LpdMessageReceiverTest.cc +++ b/test/LpdMessageReceiverTest.cc @@ -67,9 +67,7 @@ void LpdMessageReceiverTest::testReceiveMessage() rcv.getSocket()->isReadable(5); msg = rcv.receiveMessage(); - CPPUNIT_ASSERT(msg); - CPPUNIT_ASSERT(!msg->peer); - CPPUNIT_ASSERT(msg->infoHash.empty()); + CPPUNIT_ASSERT(!msg); // Bad port request = @@ -81,9 +79,7 @@ void LpdMessageReceiverTest::testReceiveMessage() rcv.getSocket()->isReadable(5); msg = rcv.receiveMessage(); - CPPUNIT_ASSERT(msg); - CPPUNIT_ASSERT(!msg->peer); - CPPUNIT_ASSERT(msg->infoHash.empty()); + CPPUNIT_ASSERT(!msg); // No data available msg = rcv.receiveMessage();