diff --git a/src/ByteArrayDiskWriter.cc b/src/ByteArrayDiskWriter.cc index 51974003..bb81dcd1 100644 --- a/src/ByteArrayDiskWriter.cc +++ b/src/ByteArrayDiskWriter.cc @@ -34,10 +34,14 @@ /* copyright --> */ #include "ByteArrayDiskWriter.h" #include "A2STR.h" +#include "DlAbortEx.h" +#include "fmt.h" namespace aria2 { -ByteArrayDiskWriter::ByteArrayDiskWriter() {} +ByteArrayDiskWriter::ByteArrayDiskWriter(size_t maxLength) + : maxLength_(maxLength) +{} ByteArrayDiskWriter::~ByteArrayDiskWriter() {} @@ -62,6 +66,10 @@ void ByteArrayDiskWriter::openExistingFile(uint64_t totalLength) void ByteArrayDiskWriter::writeData(const unsigned char* data, size_t dataLength, off_t position) { + if(position+dataLength > maxLength_) { + throw DL_ABORT_EX(fmt("Maximum length(%lu) exceeded.", + static_cast(maxLength_))); + } uint64_t length = size(); if(length < (uint64_t)position) { buf_.seekp(length, std::ios::beg); diff --git a/src/ByteArrayDiskWriter.h b/src/ByteArrayDiskWriter.h index fee6ddfb..c51d8200 100644 --- a/src/ByteArrayDiskWriter.h +++ b/src/ByteArrayDiskWriter.h @@ -43,10 +43,10 @@ namespace aria2 { class ByteArrayDiskWriter : public DiskWriter { private: std::stringstream buf_; - + size_t maxLength_; void clear(); public: - ByteArrayDiskWriter(); + ByteArrayDiskWriter(size_t maxLength = 5*1024*1024); virtual ~ByteArrayDiskWriter(); virtual void initAndOpenFile(uint64_t totalLength = 0);