diff --git a/src/Cafe/HW/Latte/LegacyShaderDecompiler/LatteDecompilerEmitMSL.cpp b/src/Cafe/HW/Latte/LegacyShaderDecompiler/LatteDecompilerEmitMSL.cpp index c4ad46ad..734aa2e4 100644 --- a/src/Cafe/HW/Latte/LegacyShaderDecompiler/LatteDecompilerEmitMSL.cpp +++ b/src/Cafe/HW/Latte/LegacyShaderDecompiler/LatteDecompilerEmitMSL.cpp @@ -3849,7 +3849,7 @@ void LatteDecompiler_emitMSLShader(LatteDecompilerShaderContext* shaderContext, // Will also modify vid in case of an indexed draw src->add("VertexIn fetchInput(VERTEX_BUFFER_DEFINITIONS, thread uint& vid);" _CRLF); - functionType = "[[object, max_total_threads_per_threadgroup(MAX_THREADS_PER_THREADGROUP), max_total_threadgroups_per_mesh_grid(1)]]"; + functionType = "[[object, max_total_threads_per_threadgroup(VERTICES_PER_PRIMITIVE), max_total_threadgroups_per_mesh_grid(1)]]"; outputTypeName = "void"; } else @@ -3876,7 +3876,7 @@ void LatteDecompiler_emitMSLShader(LatteDecompilerShaderContext* shaderContext, if (shader->shaderType == LatteConst::ShaderType::Vertex) { // Calculate the imaginary vertex id - src->add("uint vid = tig * PRIMITIVE_VERTEX_COUNT + tid;" _CRLF); + src->add("uint vid = tig * VERTICES_PER_PRIMITIVE + tid;" _CRLF); // TODO: don't hardcode the instance index src->add("uint iid = 0;" _CRLF); // Fetch the input @@ -4145,6 +4145,10 @@ void LatteDecompiler_emitMSLShader(LatteDecompilerShaderContext* shaderContext, src->add("meshGridProperties.set_threadgroups_per_grid(uint3(1, 1, 1));" _CRLF); src->add("}" _CRLF); } + else if (shader->shaderType == LatteConst::ShaderType::Geometry) + { + src->add("mesh.set_primitive_count(vertexIndex / VERTICES_PER_PRIMITIVE);" _CRLF); + } } else { diff --git a/src/Cafe/HW/Latte/LegacyShaderDecompiler/LatteDecompilerEmitMSLHeader.hpp b/src/Cafe/HW/Latte/LegacyShaderDecompiler/LatteDecompilerEmitMSLHeader.hpp index 95fd4cef..9f8b62ae 100644 --- a/src/Cafe/HW/Latte/LegacyShaderDecompiler/LatteDecompilerEmitMSLHeader.hpp +++ b/src/Cafe/HW/Latte/LegacyShaderDecompiler/LatteDecompilerEmitMSLHeader.hpp @@ -303,7 +303,7 @@ namespace LatteDecompiler src->addFmt("int4 passParameterSem{};" _CRLF, f); src->add("};" _CRLF _CRLF); src->add("struct ObjectPayload {" _CRLF); - src->add("VertexOut vertexOut[PRIMITIVE_VERTEX_COUNT];" _CRLF); + src->add("VertexOut vertexOut[VERTICES_PER_PRIMITIVE];" _CRLF); src->add("};" _CRLF _CRLF); } if (decompilerContext->shaderType == LatteConst::ShaderType::Geometry) @@ -325,8 +325,10 @@ namespace LatteDecompiler } src->add("};" _CRLF _CRLF); + const uint32 MAX_PRIMITIVE_COUNT = 8; + // Define the mesh shader output type - src->add("using MeshType = mesh;" _CRLF); + src->addFmt("using MeshType = mesh;" _CRLF, MAX_PRIMITIVE_COUNT, MAX_PRIMITIVE_COUNT); } } }