From 38355c1c3e028da0c9fe9c422b146c483046852f Mon Sep 17 00:00:00 2001 From: Maschell Date: Sat, 16 Mar 2024 19:52:37 +0100 Subject: [PATCH] Some more cleanup --- src/utils/TcpReceiver.cpp | 199 +++++++++++++++++++------------------- src/utils/TcpReceiver.h | 3 +- 2 files changed, 104 insertions(+), 98 deletions(-) diff --git a/src/utils/TcpReceiver.cpp b/src/utils/TcpReceiver.cpp index 152da3c..f108a00 100644 --- a/src/utils/TcpReceiver.cpp +++ b/src/utils/TcpReceiver.cpp @@ -245,109 +245,87 @@ TcpReceiver::eLoadResults TcpReceiver::loadBinary(void *data, uint32_t fileSize) return UNSUPPORTED_FORMAT; } -TcpReceiver::eLoadResults TcpReceiver::uncompressIfNeeded(const uint8_t *haxx, uint32_t fileSize, uint32_t fileSizeUnc, std::unique_ptr &&in_data, std::unique_ptr &out_data, uint32_t &fileSizeOut) { - // Do we need to unzip this thing? - if (haxx[4] > 0 || haxx[5] > 4) { - std::unique_ptr inflatedData; - uint8_t *in_data_raw = in_data.get(); - // We need to unzip... - if (in_data_raw[0] == 'P' && in_data_raw[1] == 'K' && in_data_raw[2] == 0x03 && in_data_raw[3] == 0x04) { - // Section is compressed, inflate - inflatedData = make_unique_nothrow(fileSizeUnc); - if (!inflatedData) { - DEBUG_FUNCTION_LINE_ERR("malloc failed"); - return NOT_ENOUGH_MEMORY; - } - - int32_t ret; - z_stream s = {}; - - s.zalloc = Z_NULL; - s.zfree = Z_NULL; - s.opaque = Z_NULL; - - ret = inflateInit(&s); - if (ret != Z_OK) { - DEBUG_FUNCTION_LINE_ERR("inflateInit failed %i", ret); - return FILE_UNCOMPRESS_ERROR; - } - - s.avail_in = fileSize; - s.next_in = (Bytef *) inflatedData.get(); - - s.avail_out = fileSizeUnc; - s.next_out = (Bytef *) inflatedData.get(); - - ret = inflate(&s, Z_FINISH); - if (ret != Z_OK && ret != Z_STREAM_END) { - DEBUG_FUNCTION_LINE_ERR("inflate failed %i", ret); - return FILE_UNCOMPRESS_ERROR; - } - - inflateEnd(&s); - fileSizeOut = fileSizeUnc; - out_data = std::move(inflatedData); - return SUCCESS; - } else { - // Section is compressed, inflate - inflatedData = make_unique_nothrow(fileSizeUnc); - if (!inflatedData) { - DEBUG_FUNCTION_LINE_ERR("malloc failed"); - return NOT_ENOUGH_MEMORY; - } - - uLongf f = fileSizeUnc; - int32_t result = uncompress((Bytef *) inflatedData.get(), &f, (Bytef *) in_data_raw, fileSize); - if (result != Z_OK) { - DEBUG_FUNCTION_LINE_ERR("uncompress failed %i", result); - return FILE_UNCOMPRESS_ERROR; - } - - fileSizeUnc = f; - fileSizeOut = fileSizeUnc; - out_data = std::move(inflatedData); - return SUCCESS; +std::unique_ptr TcpReceiver::uncompressData(uint32_t fileSize, uint32_t fileSizeUnc, std::unique_ptr &&inData, uint32_t &fileSizeOut, eLoadResults &err) { + std::unique_ptr inflatedData; + uint8_t *in_data_raw = inData.get(); + // We need to unzip... + if (in_data_raw[0] == 'P' && in_data_raw[1] == 'K' && in_data_raw[2] == 0x03 && in_data_raw[3] == 0x04) { + // Section is compressed, inflate + inflatedData = make_unique_nothrow(fileSizeUnc); + if (!inflatedData) { + DEBUG_FUNCTION_LINE_ERR("malloc failed"); + err = NOT_ENOUGH_MEMORY; + return {}; } + + int32_t ret; + z_stream s = {}; + + s.zalloc = Z_NULL; + s.zfree = Z_NULL; + s.opaque = Z_NULL; + + ret = inflateInit(&s); + if (ret != Z_OK) { + DEBUG_FUNCTION_LINE_ERR("inflateInit failed %i", ret); + err = FILE_UNCOMPRESS_ERROR; + return {}; + } + + s.avail_in = fileSize; + s.next_in = (Bytef *) inflatedData.get(); + + s.avail_out = fileSizeUnc; + s.next_out = (Bytef *) inflatedData.get(); + + ret = inflate(&s, Z_FINISH); + if (ret != Z_OK && ret != Z_STREAM_END) { + DEBUG_FUNCTION_LINE_ERR("inflate failed %i", ret); + err = FILE_UNCOMPRESS_ERROR; + return {}; + } + + inflateEnd(&s); + } else { + // Section is compressed, inflate + inflatedData = make_unique_nothrow(fileSizeUnc); + if (!inflatedData) { + DEBUG_FUNCTION_LINE_ERR("malloc failed"); + err = NOT_ENOUGH_MEMORY; + return {}; + } + + uLongf f = fileSizeUnc; + int32_t result = uncompress((Bytef *) inflatedData.get(), &f, (Bytef *) in_data_raw, fileSize); + if (result != Z_OK) { + DEBUG_FUNCTION_LINE_ERR("uncompress failed %i", result); + err = FILE_UNCOMPRESS_ERROR; + return {}; + } + + fileSizeUnc = f; } - fileSizeOut = fileSize; - out_data = std::move(in_data); - return SUCCESS; + + fileSizeOut = fileSizeUnc; + err = SUCCESS; + return inflatedData; } -TcpReceiver::eLoadResults TcpReceiver::loadToMemory(int32_t clientSocket, uint32_t ipAddress) { - DEBUG_FUNCTION_LINE("Loading file from ip %08X", ipAddress); - - uint32_t fileSize = 0; - uint32_t fileSizeUnc = 0; - unsigned char haxx[8] = {}; - //skip haxx - if (recvwait(clientSocket, haxx, sizeof(haxx)) != 0) { - return RECV_ERROR; - } - if (recvwait(clientSocket, (unsigned char *) &fileSize, sizeof(fileSize)) != 0) { - return RECV_ERROR; - } - - if (haxx[4] > 0 || haxx[5] > 4) { - if (recvwait(clientSocket, (unsigned char *) &fileSizeUnc, sizeof(fileSizeUnc)) != 0) { // Compressed protocol, read another 4 bytes - return RECV_ERROR; - } - } - +std::unique_ptr TcpReceiver::receiveData(int32_t clientSocket, uint32_t fileSize, eLoadResults &err) { uint32_t bytesRead = 0; - - auto receivedData = make_unique_nothrow(fileSize); - if (!receivedData) { - return NOT_ENOUGH_MEMORY; + auto dataOut = make_unique_nothrow(fileSize); + if (!dataOut) { + err = NOT_ENOUGH_MEMORY; + return {}; } - // Copy rpl in memory + // Copy binary in memory while (bytesRead < fileSize) { uint32_t blockSize = 0x1000; if (blockSize > (fileSize - bytesRead)) blockSize = fileSize - bytesRead; - int32_t ret = recv(clientSocket, receivedData.get() + bytesRead, blockSize, 0); + int32_t ret = recv(clientSocket, dataOut.get() + bytesRead, blockSize, 0); if (ret <= 0) { DEBUG_FUNCTION_LINE_ERR("Failed to receive file"); break; @@ -358,15 +336,42 @@ TcpReceiver::eLoadResults TcpReceiver::loadToMemory(int32_t clientSocket, uint32 if (bytesRead != fileSize) { DEBUG_FUNCTION_LINE_ERR("File loading not finished, %i of %i bytes received", bytesRead, fileSize); + err = RECV_ERROR; + return {}; + } + err = SUCCESS; + return dataOut; +} + +TcpReceiver::eLoadResults TcpReceiver::loadToMemory(int32_t clientSocket, uint32_t ipAddress) { + DEBUG_FUNCTION_LINE("Loading file from ip %08X", ipAddress); + + uint32_t fileSize = 0; + uint32_t fileSizeUnc = 0; + uint8_t haxx[8] = {}; + // read header + if (recvwait(clientSocket, haxx, sizeof(haxx)) != 0) { return RECV_ERROR; } - - - std::unique_ptr finalData; - eLoadResults err; - if ((err = uncompressIfNeeded(haxx, fileSize, fileSizeUnc, std::move(receivedData), finalData, fileSize)) != SUCCESS) { + if (recvwait(clientSocket, (void *) &fileSize, sizeof(fileSize)) != 0) { + return RECV_ERROR; + } + bool compressedData = (haxx[4] > 0 || haxx[5] > 4); + if (compressedData) { + if (recvwait(clientSocket, (void *) &fileSizeUnc, sizeof(fileSizeUnc)) != 0) { // Compressed protocol, read another 4 bytes + return RECV_ERROR; + } + } + TcpReceiver::eLoadResults err = UNSUPPORTED_FORMAT; + auto receivedData = receiveData(clientSocket, fileSize, err); + if (err != SUCCESS) { return err; + } else if (compressedData) { + receivedData = uncompressData(fileSize, fileSizeUnc, std::move(receivedData), fileSize, err); + if (!receivedData || err != SUCCESS) { + return err; + } } - return loadBinary(finalData.get(), fileSize); + return loadBinary(receivedData.get(), fileSize); } diff --git a/src/utils/TcpReceiver.h b/src/utils/TcpReceiver.h index 29e42ae..df485fb 100644 --- a/src/utils/TcpReceiver.h +++ b/src/utils/TcpReceiver.h @@ -36,7 +36,8 @@ private: static TcpReceiver::eLoadResults tryLoadRPX(uint8_t *data, uint32_t fileSize, std::string &loadedPathOut); static TcpReceiver::eLoadResults tryLoadWPS(uint8_t *data, uint32_t fileSize); static TcpReceiver::eLoadResults loadBinary(void *data, uint32_t fileSize); - static TcpReceiver::eLoadResults uncompressIfNeeded(const uint8_t *haxx, uint32_t fileSize, uint32_t fileSizeUnc, std::unique_ptr &&in_data, std::unique_ptr &out_data, uint32_t &fileSizeOut); + static std::unique_ptr receiveData(int32_t clientSocket, uint32_t fileSize, eLoadResults &err); + static std::unique_ptr uncompressData(uint32_t fileSize, uint32_t fileSizeUnc, std::unique_ptr &&in_out_data, uint32_t &fileSizeOut, eLoadResults &err); bool exitRequested; int32_t serverPort;