Some more cleanup

This commit is contained in:
Maschell 2024-03-16 19:52:37 +01:00
parent d4984dafbe
commit 38355c1c3e
2 changed files with 104 additions and 98 deletions

View File

@ -245,109 +245,87 @@ TcpReceiver::eLoadResults TcpReceiver::loadBinary(void *data, uint32_t fileSize)
return UNSUPPORTED_FORMAT; return UNSUPPORTED_FORMAT;
} }
TcpReceiver::eLoadResults TcpReceiver::uncompressIfNeeded(const uint8_t *haxx, uint32_t fileSize, uint32_t fileSizeUnc, std::unique_ptr<uint8_t> &&in_data, std::unique_ptr<uint8_t> &out_data, uint32_t &fileSizeOut) { std::unique_ptr<uint8_t[]> TcpReceiver::uncompressData(uint32_t fileSize, uint32_t fileSizeUnc, std::unique_ptr<uint8_t[]> &&inData, uint32_t &fileSizeOut, eLoadResults &err) {
// Do we need to unzip this thing? std::unique_ptr<uint8_t[]> inflatedData;
if (haxx[4] > 0 || haxx[5] > 4) { uint8_t *in_data_raw = inData.get();
std::unique_ptr<uint8_t> inflatedData; // We need to unzip...
uint8_t *in_data_raw = in_data.get(); if (in_data_raw[0] == 'P' && in_data_raw[1] == 'K' && in_data_raw[2] == 0x03 && in_data_raw[3] == 0x04) {
// We need to unzip... // Section is compressed, inflate
if (in_data_raw[0] == 'P' && in_data_raw[1] == 'K' && in_data_raw[2] == 0x03 && in_data_raw[3] == 0x04) { inflatedData = make_unique_nothrow<uint8_t[]>(fileSizeUnc);
// Section is compressed, inflate if (!inflatedData) {
inflatedData = make_unique_nothrow<uint8_t>(fileSizeUnc); DEBUG_FUNCTION_LINE_ERR("malloc failed");
if (!inflatedData) { err = NOT_ENOUGH_MEMORY;
DEBUG_FUNCTION_LINE_ERR("malloc failed"); return {};
return NOT_ENOUGH_MEMORY;
}
int32_t ret;
z_stream s = {};
s.zalloc = Z_NULL;
s.zfree = Z_NULL;
s.opaque = Z_NULL;
ret = inflateInit(&s);
if (ret != Z_OK) {
DEBUG_FUNCTION_LINE_ERR("inflateInit failed %i", ret);
return FILE_UNCOMPRESS_ERROR;
}
s.avail_in = fileSize;
s.next_in = (Bytef *) inflatedData.get();
s.avail_out = fileSizeUnc;
s.next_out = (Bytef *) inflatedData.get();
ret = inflate(&s, Z_FINISH);
if (ret != Z_OK && ret != Z_STREAM_END) {
DEBUG_FUNCTION_LINE_ERR("inflate failed %i", ret);
return FILE_UNCOMPRESS_ERROR;
}
inflateEnd(&s);
fileSizeOut = fileSizeUnc;
out_data = std::move(inflatedData);
return SUCCESS;
} else {
// Section is compressed, inflate
inflatedData = make_unique_nothrow<uint8_t>(fileSizeUnc);
if (!inflatedData) {
DEBUG_FUNCTION_LINE_ERR("malloc failed");
return NOT_ENOUGH_MEMORY;
}
uLongf f = fileSizeUnc;
int32_t result = uncompress((Bytef *) inflatedData.get(), &f, (Bytef *) in_data_raw, fileSize);
if (result != Z_OK) {
DEBUG_FUNCTION_LINE_ERR("uncompress failed %i", result);
return FILE_UNCOMPRESS_ERROR;
}
fileSizeUnc = f;
fileSizeOut = fileSizeUnc;
out_data = std::move(inflatedData);
return SUCCESS;
} }
int32_t ret;
z_stream s = {};
s.zalloc = Z_NULL;
s.zfree = Z_NULL;
s.opaque = Z_NULL;
ret = inflateInit(&s);
if (ret != Z_OK) {
DEBUG_FUNCTION_LINE_ERR("inflateInit failed %i", ret);
err = FILE_UNCOMPRESS_ERROR;
return {};
}
s.avail_in = fileSize;
s.next_in = (Bytef *) inflatedData.get();
s.avail_out = fileSizeUnc;
s.next_out = (Bytef *) inflatedData.get();
ret = inflate(&s, Z_FINISH);
if (ret != Z_OK && ret != Z_STREAM_END) {
DEBUG_FUNCTION_LINE_ERR("inflate failed %i", ret);
err = FILE_UNCOMPRESS_ERROR;
return {};
}
inflateEnd(&s);
} else {
// Section is compressed, inflate
inflatedData = make_unique_nothrow<uint8_t[]>(fileSizeUnc);
if (!inflatedData) {
DEBUG_FUNCTION_LINE_ERR("malloc failed");
err = NOT_ENOUGH_MEMORY;
return {};
}
uLongf f = fileSizeUnc;
int32_t result = uncompress((Bytef *) inflatedData.get(), &f, (Bytef *) in_data_raw, fileSize);
if (result != Z_OK) {
DEBUG_FUNCTION_LINE_ERR("uncompress failed %i", result);
err = FILE_UNCOMPRESS_ERROR;
return {};
}
fileSizeUnc = f;
} }
fileSizeOut = fileSize;
out_data = std::move(in_data); fileSizeOut = fileSizeUnc;
return SUCCESS; err = SUCCESS;
return inflatedData;
} }
TcpReceiver::eLoadResults TcpReceiver::loadToMemory(int32_t clientSocket, uint32_t ipAddress) { std::unique_ptr<uint8_t[]> TcpReceiver::receiveData(int32_t clientSocket, uint32_t fileSize, eLoadResults &err) {
DEBUG_FUNCTION_LINE("Loading file from ip %08X", ipAddress);
uint32_t fileSize = 0;
uint32_t fileSizeUnc = 0;
unsigned char haxx[8] = {};
//skip haxx
if (recvwait(clientSocket, haxx, sizeof(haxx)) != 0) {
return RECV_ERROR;
}
if (recvwait(clientSocket, (unsigned char *) &fileSize, sizeof(fileSize)) != 0) {
return RECV_ERROR;
}
if (haxx[4] > 0 || haxx[5] > 4) {
if (recvwait(clientSocket, (unsigned char *) &fileSizeUnc, sizeof(fileSizeUnc)) != 0) { // Compressed protocol, read another 4 bytes
return RECV_ERROR;
}
}
uint32_t bytesRead = 0; uint32_t bytesRead = 0;
auto dataOut = make_unique_nothrow<uint8_t[]>(fileSize);
auto receivedData = make_unique_nothrow<uint8_t>(fileSize); if (!dataOut) {
if (!receivedData) { err = NOT_ENOUGH_MEMORY;
return NOT_ENOUGH_MEMORY; return {};
} }
// Copy rpl in memory // Copy binary in memory
while (bytesRead < fileSize) { while (bytesRead < fileSize) {
uint32_t blockSize = 0x1000; uint32_t blockSize = 0x1000;
if (blockSize > (fileSize - bytesRead)) if (blockSize > (fileSize - bytesRead))
blockSize = fileSize - bytesRead; blockSize = fileSize - bytesRead;
int32_t ret = recv(clientSocket, receivedData.get() + bytesRead, blockSize, 0); int32_t ret = recv(clientSocket, dataOut.get() + bytesRead, blockSize, 0);
if (ret <= 0) { if (ret <= 0) {
DEBUG_FUNCTION_LINE_ERR("Failed to receive file"); DEBUG_FUNCTION_LINE_ERR("Failed to receive file");
break; break;
@ -358,15 +336,42 @@ TcpReceiver::eLoadResults TcpReceiver::loadToMemory(int32_t clientSocket, uint32
if (bytesRead != fileSize) { if (bytesRead != fileSize) {
DEBUG_FUNCTION_LINE_ERR("File loading not finished, %i of %i bytes received", bytesRead, fileSize); DEBUG_FUNCTION_LINE_ERR("File loading not finished, %i of %i bytes received", bytesRead, fileSize);
err = RECV_ERROR;
return {};
}
err = SUCCESS;
return dataOut;
}
TcpReceiver::eLoadResults TcpReceiver::loadToMemory(int32_t clientSocket, uint32_t ipAddress) {
DEBUG_FUNCTION_LINE("Loading file from ip %08X", ipAddress);
uint32_t fileSize = 0;
uint32_t fileSizeUnc = 0;
uint8_t haxx[8] = {};
// read header
if (recvwait(clientSocket, haxx, sizeof(haxx)) != 0) {
return RECV_ERROR; return RECV_ERROR;
} }
if (recvwait(clientSocket, (void *) &fileSize, sizeof(fileSize)) != 0) {
return RECV_ERROR;
std::unique_ptr<uint8_t> finalData; }
eLoadResults err; bool compressedData = (haxx[4] > 0 || haxx[5] > 4);
if ((err = uncompressIfNeeded(haxx, fileSize, fileSizeUnc, std::move(receivedData), finalData, fileSize)) != SUCCESS) { if (compressedData) {
if (recvwait(clientSocket, (void *) &fileSizeUnc, sizeof(fileSizeUnc)) != 0) { // Compressed protocol, read another 4 bytes
return RECV_ERROR;
}
}
TcpReceiver::eLoadResults err = UNSUPPORTED_FORMAT;
auto receivedData = receiveData(clientSocket, fileSize, err);
if (err != SUCCESS) {
return err; return err;
} else if (compressedData) {
receivedData = uncompressData(fileSize, fileSizeUnc, std::move(receivedData), fileSize, err);
if (!receivedData || err != SUCCESS) {
return err;
}
} }
return loadBinary(finalData.get(), fileSize); return loadBinary(receivedData.get(), fileSize);
} }

View File

@ -36,7 +36,8 @@ private:
static TcpReceiver::eLoadResults tryLoadRPX(uint8_t *data, uint32_t fileSize, std::string &loadedPathOut); static TcpReceiver::eLoadResults tryLoadRPX(uint8_t *data, uint32_t fileSize, std::string &loadedPathOut);
static TcpReceiver::eLoadResults tryLoadWPS(uint8_t *data, uint32_t fileSize); static TcpReceiver::eLoadResults tryLoadWPS(uint8_t *data, uint32_t fileSize);
static TcpReceiver::eLoadResults loadBinary(void *data, uint32_t fileSize); static TcpReceiver::eLoadResults loadBinary(void *data, uint32_t fileSize);
static TcpReceiver::eLoadResults uncompressIfNeeded(const uint8_t *haxx, uint32_t fileSize, uint32_t fileSizeUnc, std::unique_ptr<uint8_t> &&in_data, std::unique_ptr<uint8_t> &out_data, uint32_t &fileSizeOut); static std::unique_ptr<uint8_t[]> receiveData(int32_t clientSocket, uint32_t fileSize, eLoadResults &err);
static std::unique_ptr<uint8_t[]> uncompressData(uint32_t fileSize, uint32_t fileSizeUnc, std::unique_ptr<uint8_t[]> &&in_out_data, uint32_t &fileSizeOut, eLoadResults &err);
bool exitRequested; bool exitRequested;
int32_t serverPort; int32_t serverPort;