diff --git a/Source/Core/DiscIO/CMakeLists.txt b/Source/Core/DiscIO/CMakeLists.txt
index 46d66d217d..dcbab6cce4 100644
--- a/Source/Core/DiscIO/CMakeLists.txt
+++ b/Source/Core/DiscIO/CMakeLists.txt
@@ -54,6 +54,7 @@ target_link_libraries(discio
PUBLIC
BZip2::BZip2
LibLZMA::LibLZMA
+ zstd
PRIVATE
minizip
diff --git a/Source/Core/DiscIO/DiscIO.vcxproj b/Source/Core/DiscIO/DiscIO.vcxproj
index 7d94fcff27..2a60bfacee 100644
--- a/Source/Core/DiscIO/DiscIO.vcxproj
+++ b/Source/Core/DiscIO/DiscIO.vcxproj
@@ -118,6 +118,9 @@
{1d8c51d2-ffa4-418e-b183-9f42b6a6717e}
+
+ {1bea10f3-80ce-4bc4-9331-5769372cdf99}
+
diff --git a/Source/Core/DiscIO/WIABlob.cpp b/Source/Core/DiscIO/WIABlob.cpp
index 7414c48221..b3c9553d49 100644
--- a/Source/Core/DiscIO/WIABlob.cpp
+++ b/Source/Core/DiscIO/WIABlob.cpp
@@ -18,6 +18,7 @@
#include
#include
#include
+#include
#include "Common/Align.h"
#include "Common/Assert.h"
@@ -38,6 +39,24 @@
namespace DiscIO
{
+std::pair GetAllowedCompressionLevels(WIACompressionType compression_type)
+{
+ switch (compression_type)
+ {
+ case WIACompressionType::Bzip2:
+ case WIACompressionType::LZMA:
+ case WIACompressionType::LZMA2:
+ return {1, 9};
+ case WIACompressionType::Zstd:
+ // The actual minimum level can be gotten by calling ZSTD_minCLevel(). However, returning that
+ // would make the UI rather weird, because it is a negative number with very large magnitude.
+ // Note: Level 0 is a special number which means "default level" (level 3 as of this writing).
+ return {1, ZSTD_maxCLevel()};
+ default:
+ return {0, -1};
+ }
+}
+
WIAFileReader::WIAFileReader(File::IOFile file, const std::string& path)
: m_file(std::move(file)), m_encryption_cache(this)
{
@@ -110,9 +129,9 @@ bool WIAFileReader::Initialize(const std::string& path)
const u32 compression_type = Common::swap32(m_header_2.compression_type);
m_compression_type = static_cast(compression_type);
- if (m_compression_type > WIACompressionType::LZMA2)
+ if (m_compression_type > (m_rvz ? WIACompressionType::Zstd : WIACompressionType::LZMA2))
{
- ERROR_LOG(DISCIO, "Unsupported WIA compression type %u in %s", compression_type, path.c_str());
+ ERROR_LOG(DISCIO, "Unsupported compression type %u in %s", compression_type, path.c_str());
return false;
}
@@ -460,6 +479,9 @@ WIAFileReader::Chunk& WIAFileReader::ReadCompressedData(u64 offset_in_file, u64
decompressor = std::make_unique(true, m_header_2.compressor_data,
m_header_2.compressor_data_size);
break;
+ case WIACompressionType::Zstd:
+ decompressor = std::make_unique();
+ break;
}
const bool compressed_exception_lists = m_compression_type > WIACompressionType::Purge;
@@ -725,6 +747,34 @@ bool WIAFileReader::LZMADecompressor::Decompress(const DecompressionBuffer& in,
return result == LZMA_OK || result == LZMA_STREAM_END;
}
+WIAFileReader::ZstdDecompressor::ZstdDecompressor()
+{
+ m_stream = ZSTD_createDStream();
+}
+
+WIAFileReader::ZstdDecompressor::~ZstdDecompressor()
+{
+ ZSTD_freeDStream(m_stream);
+}
+
+bool WIAFileReader::ZstdDecompressor::Decompress(const DecompressionBuffer& in,
+ DecompressionBuffer* out, size_t* in_bytes_read)
+{
+ if (!m_stream)
+ return false;
+
+ ZSTD_inBuffer in_buffer{in.data.data(), in.bytes_written, *in_bytes_read};
+ ZSTD_outBuffer out_buffer{out->data.data(), out->data.size(), out->bytes_written};
+
+ const size_t result = ZSTD_decompressStream(m_stream, &out_buffer, &in_buffer);
+
+ *in_bytes_read = in_buffer.pos;
+ out->bytes_written = out_buffer.pos;
+
+ m_done = result == 0;
+ return !ZSTD_isError(result);
+}
+
WIAFileReader::Compressor::~Compressor() = default;
WIAFileReader::PurgeCompressor::PurgeCompressor()
@@ -1032,6 +1082,71 @@ size_t WIAFileReader::LZMACompressor::GetSize() const
return static_cast(m_stream.next_out - m_buffer.data());
}
+WIAFileReader::ZstdCompressor::ZstdCompressor(int compression_level)
+{
+ m_stream = ZSTD_createCStream();
+
+ if (ZSTD_isError(ZSTD_CCtx_setParameter(m_stream, ZSTD_c_compressionLevel, compression_level)))
+ m_stream = nullptr;
+}
+
+WIAFileReader::ZstdCompressor::~ZstdCompressor()
+{
+ ZSTD_freeCStream(m_stream);
+}
+
+bool WIAFileReader::ZstdCompressor::Start()
+{
+ if (!m_stream)
+ return false;
+
+ m_buffer.clear();
+ m_out_buffer = {};
+
+ return !ZSTD_isError(ZSTD_CCtx_reset(m_stream, ZSTD_reset_session_only));
+}
+
+bool WIAFileReader::ZstdCompressor::Compress(const u8* data, size_t size)
+{
+ ZSTD_inBuffer in_buffer{data, size, 0};
+
+ ExpandBuffer(size);
+
+ while (in_buffer.size != in_buffer.pos)
+ {
+ if (m_out_buffer.size == m_out_buffer.pos)
+ ExpandBuffer(0x100);
+
+ if (ZSTD_isError(ZSTD_compressStream(m_stream, &m_out_buffer, &in_buffer)))
+ return false;
+ }
+
+ return true;
+}
+
+bool WIAFileReader::ZstdCompressor::End()
+{
+ while (true)
+ {
+ if (m_out_buffer.size == m_out_buffer.pos)
+ ExpandBuffer(0x100);
+
+ const size_t result = ZSTD_endStream(m_stream, &m_out_buffer);
+ if (ZSTD_isError(result))
+ return false;
+ if (result == 0)
+ return true;
+ }
+}
+
+void WIAFileReader::ZstdCompressor::ExpandBuffer(size_t bytes_to_add)
+{
+ m_buffer.resize(m_buffer.size() + bytes_to_add);
+
+ m_out_buffer.dst = m_buffer.data();
+ m_out_buffer.size = m_buffer.size();
+}
+
WIAFileReader::Chunk::Chunk() = default;
WIAFileReader::Chunk::Chunk(File::IOFile* file, u64 offset_in_file, u64 compressed_size,
@@ -1138,8 +1253,14 @@ bool WIAFileReader::Chunk::Read(u64 offset, u64 size, u8* out_ptr)
if (m_out.bytes_written > expected_out_bytes)
return false; // Decompressed size is larger than expected
- if (m_out.bytes_written == expected_out_bytes && !m_decompressor->Done())
+ // The reason why we need the m_in.bytes_written == m_in.data.size() check as part of
+ // this conditional is because (for example) zstd can finish writing all data to m_out
+ // before becoming done if we've given it all input data except the checksum at the end.
+ if (m_out.bytes_written == expected_out_bytes && !m_decompressor->Done() &&
+ m_in.bytes_written == m_in.data.size())
+ {
return false; // Decompressed size is larger than expected
+ }
if (m_decompressor->Done() && m_in_bytes_read != m_in.data.size())
return false; // Compressed size is smaller than expected
@@ -1432,6 +1553,9 @@ void WIAFileReader::SetUpCompressor(std::unique_ptr* compressor,
compressor_data_size);
break;
}
+ case WIACompressionType::Zstd:
+ *compressor = std::make_unique(compression_level);
+ break;
}
}
diff --git a/Source/Core/DiscIO/WIABlob.h b/Source/Core/DiscIO/WIABlob.h
index 356d8a3edb..63580efe89 100644
--- a/Source/Core/DiscIO/WIABlob.h
+++ b/Source/Core/DiscIO/WIABlob.h
@@ -15,6 +15,7 @@
#include
#include
#include
+#include
#include "Common/CommonTypes.h"
#include "Common/File.h"
@@ -34,8 +35,11 @@ enum class WIACompressionType : u32
Bzip2 = 2,
LZMA = 3,
LZMA2 = 4,
+ Zstd = 5,
};
+std::pair GetAllowedCompressionLevels(WIACompressionType compression_type);
+
constexpr u32 WIA_MAGIC = 0x01414957; // "WIA\x1" (byteswapped to little endian)
constexpr u32 RVZ_MAGIC = 0x015A5652; // "RVZ\x1" (byteswapped to little endian)
@@ -250,6 +254,19 @@ private:
bool m_error_occurred = false;
};
+ class ZstdDecompressor final : public Decompressor
+ {
+ public:
+ ZstdDecompressor();
+ ~ZstdDecompressor();
+
+ bool Decompress(const DecompressionBuffer& in, DecompressionBuffer* out,
+ size_t* in_bytes_read) override;
+
+ private:
+ ZSTD_DStream* m_stream;
+ };
+
class Compressor
{
public:
@@ -332,6 +349,27 @@ private:
bool m_initialization_failed = false;
};
+ class ZstdCompressor final : public Compressor
+ {
+ public:
+ ZstdCompressor(int compression_level);
+ ~ZstdCompressor();
+
+ bool Start() override;
+ bool Compress(const u8* data, size_t size) override;
+ bool End() override;
+
+ const u8* GetData() const override { return m_buffer.data(); }
+ size_t GetSize() const override { return m_out_buffer.pos; }
+
+ private:
+ void ExpandBuffer(size_t bytes_to_add);
+
+ ZSTD_CStream* m_stream;
+ ZSTD_outBuffer m_out_buffer;
+ std::vector m_buffer;
+ };
+
class Chunk
{
public:
diff --git a/Source/Core/DolphinQt/ConvertDialog.cpp b/Source/Core/DolphinQt/ConvertDialog.cpp
index 120b557c6b..f739de801a 100644
--- a/Source/Core/DolphinQt/ConvertDialog.cpp
+++ b/Source/Core/DolphinQt/ConvertDialog.cpp
@@ -230,6 +230,11 @@ void ConvertDialog::OnFormatChanged()
AddToCompressionComboBox(slow.arg(QStringLiteral("bzip2")), DiscIO::WIACompressionType::Bzip2);
AddToCompressionComboBox(slow.arg(QStringLiteral("LZMA")), DiscIO::WIACompressionType::LZMA);
AddToCompressionComboBox(slow.arg(QStringLiteral("LZMA2")), DiscIO::WIACompressionType::LZMA2);
+ if (format == DiscIO::BlobType::RVZ)
+ {
+ AddToCompressionComboBox(QStringLiteral("Zstandard"), DiscIO::WIACompressionType::Zstd);
+ m_compression->setCurrentIndex(m_compression->count() - 1);
+ }
break;
}
@@ -246,19 +251,16 @@ void ConvertDialog::OnCompressionChanged()
{
m_compression_level->clear();
- switch (static_cast(m_compression->currentData().toInt()))
+ const auto compression_type =
+ static_cast(m_compression->currentData().toInt());
+
+ const std::pair range = DiscIO::GetAllowedCompressionLevels(compression_type);
+
+ for (int i = range.first; i <= range.second; ++i)
{
- case DiscIO::WIACompressionType::Bzip2:
- case DiscIO::WIACompressionType::LZMA:
- case DiscIO::WIACompressionType::LZMA2:
- for (int i = 1; i <= 9; ++i)
- AddToCompressionLevelComboBox(i);
-
- m_compression_level->setCurrentIndex(4);
-
- break;
- default:
- break;
+ AddToCompressionLevelComboBox(i);
+ if (i == 5)
+ m_compression_level->setCurrentIndex(m_compression_level->count() - 1);
}
m_compression_level->setEnabled(m_compression_level->count() > 1);