diff --git a/src/export.cpp b/src/export.cpp index b242325..ca69bc4 100644 --- a/src/export.cpp +++ b/src/export.cpp @@ -1,20 +1,17 @@ #include "FileUtils.h" #include "utils/FileReader.h" #include "utils/FileReaderCompressed.h" +#include "utils/utils.h" #include -#include #include #include -std::vector openFiles; +std::forward_list> openFiles; std::map mountedWUHB; std::mutex mutex; void WUHBUtils_CleanUp() { std::lock_guard lock(mutex); - for (auto &file : openFiles) { - delete file; - } openFiles.clear(); for (const auto &[name, path] : mountedWUHB) { @@ -69,22 +66,24 @@ WUHBUtilsApiErrorType WUU_FileOpen(const char *name, uint32_t *outHandle) { return WUHB_UTILS_API_ERROR_INVALID_ARG; } std::lock_guard lock(mutex); - FileReader *reader; + std::unique_ptr reader; std::string path = std::string(name); std::string pathGZ = path + ".gz"; if (CheckFile(path.c_str())) { - reader = new (std::nothrow) FileReader(path); + reader = make_unique_nothrow(path); } else if (CheckFile(pathGZ.c_str())) { - reader = new (std::nothrow) FileReaderCompressed(pathGZ); + reader = make_unique_nothrow(pathGZ); } else { return WUHB_UTILS_API_ERROR_FILE_NOT_FOUND; } - if (reader == nullptr) { + + if (!reader || !reader->isReady()) { return WUHB_UTILS_API_ERROR_NO_MEMORY; } - openFiles.push_back(reader); - *outHandle = (uint32_t) reader; + *outHandle = reader->getHandle(); + openFiles.push_front(std::move(reader)); + return WUHB_UTILS_API_ERROR_NONE; } @@ -93,43 +92,22 @@ WUHBUtilsApiErrorType WUU_FileRead(uint32_t handle, uint8_t *buffer, uint32_t si return WUHB_UTILS_API_ERROR_INVALID_ARG; } std::lock_guard lock(mutex); - auto found = false; - FileReader *reader; - for (auto &cur : openFiles) { - if ((uint32_t) cur == handle) { - found = true; - reader = cur; - break; + for (auto &reader : openFiles) { + if ((uint32_t) reader.get() == (uint32_t) handle) { + *outRes = (int32_t) reader->read(buffer, size); + return WUHB_UTILS_API_ERROR_NONE; } } - if (!found) { - return WUHB_UTILS_API_ERROR_FILE_HANDLE_NOT_FOUND; - } - *outRes = (int32_t) reader->read(buffer, size); - - return WUHB_UTILS_API_ERROR_NONE; + return WUHB_UTILS_API_ERROR_FILE_HANDLE_NOT_FOUND; } WUHBUtilsApiErrorType WUU_FileClose(uint32_t handle) { - std::lock_guard lock(mutex); - auto count = 0; - auto found = false; - FileReader *reader; - for (auto &cur : openFiles) { - if ((uint32_t) cur == handle) { - found = true; - reader = cur; - break; - } - count++; + if (remove_locked_first_if(mutex, openFiles, [handle](auto &cur) { return cur->getHandle() == handle; })) { + return WUHB_UTILS_API_ERROR_NONE; } - if (!found) { - return WUHB_UTILS_API_ERROR_FILE_HANDLE_NOT_FOUND; - } - openFiles.erase(openFiles.begin() + count); - delete reader; - return WUHB_UTILS_API_ERROR_NONE; + + return WUHB_UTILS_API_ERROR_FILE_HANDLE_NOT_FOUND; } WUHBUtilsApiErrorType WUU_FileExists(const char *name, int32_t *outRes) { diff --git a/src/utils/FileReader.cpp b/src/utils/FileReader.cpp index 3754dad..4062e26 100644 --- a/src/utils/FileReader.cpp +++ b/src/utils/FileReader.cpp @@ -24,6 +24,14 @@ int64_t FileReader::read(uint8_t *buffer, uint32_t size) { return -2; } +FileReader::FileReader(uint8_t *buffer, uint32_t size) { + this->input_buffer = buffer; + this->input_size = size; + this->input_pos = 0; + this->isReadFromBuffer = true; + this->isReadFromFile = false; +} + FileReader::FileReader(std::string &path) { int fd; if ((fd = open(path.c_str(), O_RDONLY)) >= 0) { @@ -41,10 +49,7 @@ FileReader::~FileReader() { } } -FileReader::FileReader(uint8_t *buffer, uint32_t size) { - this->input_buffer = buffer; - this->input_size = size; - this->input_pos = 0; - this->isReadFromBuffer = true; - this->isReadFromFile = false; -} + +bool FileReader::isReady() { + return this->isReadFromFile || this->isReadFromBuffer; +} \ No newline at end of file diff --git a/src/utils/FileReader.h b/src/utils/FileReader.h index b6c019a..7748333 100644 --- a/src/utils/FileReader.h +++ b/src/utils/FileReader.h @@ -17,6 +17,12 @@ public: virtual int64_t read(uint8_t *buffer, uint32_t size); + virtual bool isReady(); + + virtual uint32_t getHandle() { + return reinterpret_cast(this); + } + private: bool isReadFromBuffer = false; uint8_t *input_buffer = nullptr; diff --git a/src/utils/utils.h b/src/utils/utils.h new file mode 100644 index 0000000..2eb7bf8 --- /dev/null +++ b/src/utils/utils.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include +#include +#include +#include + +template +std::unique_ptr make_unique_nothrow(Args &&...args) noexcept(noexcept(T(std::forward(args)...))) { + return std::unique_ptr(new (std::nothrow) T(std::forward(args)...)); +} + +template +std::shared_ptr make_shared_nothrow(Args &&...args) noexcept(noexcept(T(std::forward(args)...))) { + return std::shared_ptr(new (std::nothrow) T(std::forward(args)...)); +} + +template +bool remove_locked_first_if(std::mutex &mutex, std::forward_list &list, Predicate pred) { + std::lock_guard lock(mutex); + auto oit = list.before_begin(), it = std::next(oit); + while (it != list.end()) { + if (pred(*it)) { + list.erase_after(oit); + return true; + } + oit = it++; + } + return false; +}