diff --git a/Source/Core/DiscIO/WIABlob.cpp b/Source/Core/DiscIO/WIABlob.cpp index 4d4b104242..999d3136ac 100644 --- a/Source/Core/DiscIO/WIABlob.cpp +++ b/Source/Core/DiscIO/WIABlob.cpp @@ -1040,7 +1040,7 @@ std::optional> WIARVZFileReader::Compress(Compressor* compr { if (compressor) { - if (!compressor->Start() || !compressor->Compress(data, size) || !compressor->End()) + if (!compressor->Start(size) || !compressor->Compress(data, size) || !compressor->End()) return std::nullopt; data = compressor->GetData(); @@ -1564,7 +1564,7 @@ WIARVZFileReader::ProcessAndCompress(CompressThreadState* state, CompressPa if (state->compressor) { - if (!state->compressor->Start()) + if (!state->compressor->Start(entry.exception_lists.size() + entry.main_data.size())) return ConversionResultCode::InternalError; } diff --git a/Source/Core/DiscIO/WIACompression.cpp b/Source/Core/DiscIO/WIACompression.cpp index 20d19c4877..4fdf6cd429 100644 --- a/Source/Core/DiscIO/WIACompression.cpp +++ b/Source/Core/DiscIO/WIACompression.cpp @@ -446,7 +446,7 @@ PurgeCompressor::PurgeCompressor() PurgeCompressor::~PurgeCompressor() = default; -bool PurgeCompressor::Start() +bool PurgeCompressor::Start(std::optional size) { m_buffer.clear(); m_bytes_written = 0; @@ -550,7 +550,7 @@ Bzip2Compressor::~Bzip2Compressor() BZ2_bzCompressEnd(&m_stream); } -bool Bzip2Compressor::Start() +bool Bzip2Compressor::Start(std::optional size) { ASSERT_MSG(DISCIO, m_stream.state == nullptr, "Called Bzip2Compressor::Start() twice without calling Bzip2Compressor::End()"); @@ -674,7 +674,7 @@ LZMACompressor::~LZMACompressor() lzma_end(&m_stream); } -bool LZMACompressor::Start() +bool LZMACompressor::Start(std::optional size) { if (m_initialization_failed) return false; @@ -745,8 +745,11 @@ ZstdCompressor::ZstdCompressor(int compression_level) { m_stream = ZSTD_createCStream(); - if (ZSTD_isError(ZSTD_CCtx_setParameter(m_stream, ZSTD_c_compressionLevel, compression_level))) + if (ZSTD_isError(ZSTD_CCtx_setParameter(m_stream, ZSTD_c_compressionLevel, compression_level)) || + ZSTD_isError(ZSTD_CCtx_setParameter(m_stream, ZSTD_c_contentSizeFlag, 0))) + { m_stream = nullptr; + } } ZstdCompressor::~ZstdCompressor() @@ -754,7 +757,7 @@ ZstdCompressor::~ZstdCompressor() ZSTD_freeCStream(m_stream); } -bool ZstdCompressor::Start() +bool ZstdCompressor::Start(std::optional size) { if (!m_stream) return false; @@ -762,7 +765,16 @@ bool ZstdCompressor::Start() m_buffer.clear(); m_out_buffer = {}; - return !ZSTD_isError(ZSTD_CCtx_reset(m_stream, ZSTD_reset_session_only)); + if (ZSTD_isError(ZSTD_CCtx_reset(m_stream, ZSTD_reset_session_only))) + return false; + + if (size) + { + if (ZSTD_isError(ZSTD_CCtx_setPledgedSrcSize(m_stream, *size))) + return false; + } + + return true; } bool ZstdCompressor::Compress(const u8* data, size_t size) diff --git a/Source/Core/DiscIO/WIACompression.h b/Source/Core/DiscIO/WIACompression.h index 37e8cf3dc3..2a6bbc9f8d 100644 --- a/Source/Core/DiscIO/WIACompression.h +++ b/Source/Core/DiscIO/WIACompression.h @@ -154,7 +154,7 @@ public: // First call Start, then AddDataOnlyForPurgeHashing/Compress any number of times, // then End, then GetData/GetSize any number of times. - virtual bool Start() = 0; + virtual bool Start(std::optional size) = 0; virtual bool AddPrecedingDataOnlyForPurgeHashing(const u8* data, size_t size) { return true; } virtual bool Compress(const u8* data, size_t size) = 0; virtual bool End() = 0; @@ -169,7 +169,7 @@ public: PurgeCompressor(); ~PurgeCompressor(); - bool Start() override; + bool Start(std::optional size) override; bool AddPrecedingDataOnlyForPurgeHashing(const u8* data, size_t size) override; bool Compress(const u8* data, size_t size) override; bool End() override; @@ -189,7 +189,7 @@ public: Bzip2Compressor(int compression_level); ~Bzip2Compressor(); - bool Start() override; + bool Start(std::optional size) override; bool Compress(const u8* data, size_t size) override; bool End() override; @@ -211,7 +211,7 @@ public: u8* compressor_data_size_out); ~LZMACompressor(); - bool Start() override; + bool Start(std::optional size) override; bool Compress(const u8* data, size_t size) override; bool End() override; @@ -234,7 +234,7 @@ public: ZstdCompressor(int compression_level); ~ZstdCompressor(); - bool Start() override; + bool Start(std::optional size) override; bool Compress(const u8* data, size_t size) override; bool End() override;