diff --git a/src/Cafe/HW/Latte/LegacyShaderDecompiler/LatteDecompilerEmitMSL.cpp b/src/Cafe/HW/Latte/LegacyShaderDecompiler/LatteDecompilerEmitMSL.cpp index 02581b03..1b0f3f71 100644 --- a/src/Cafe/HW/Latte/LegacyShaderDecompiler/LatteDecompilerEmitMSL.cpp +++ b/src/Cafe/HW/Latte/LegacyShaderDecompiler/LatteDecompilerEmitMSL.cpp @@ -3589,7 +3589,6 @@ void LatteDecompiler_emitClauseCodeMSL(LatteDecompilerShaderContext* shaderConte src->add("out.pointSize = supportBuffer.pointSize;" _CRLF); // Emit vertex (if the vertex index matches thread id) src->add("mesh.set_vertex(vertexIndex, out);" _CRLF); - src->add("mesh.set_index(vertexIndex, vertexIndex);" _CRLF); src->add("vertexIndex++;" _CRLF); // increment transform feedback pointer for (sint32 i = 0; i < LATTE_NUM_STREAMOUT_BUFFER; i++) @@ -3849,7 +3848,7 @@ void LatteDecompiler_emitMSLShader(LatteDecompilerShaderContext* shaderContext, // Will also modify vid in case of an indexed draw src->add("VertexIn fetchInput(thread uint& vid VERTEX_BUFFER_DEFINITIONS);" _CRLF); - functionType = "[[object, max_total_threads_per_threadgroup(VERTICES_PER_PRIMITIVE), max_total_threadgroups_per_mesh_grid(1)]]"; + functionType = "[[object, max_total_threads_per_threadgroup(VERTICES_PER_VERTEX_PRIMITIVE), max_total_threadgroups_per_mesh_grid(1)]]"; outputTypeName = "void"; } else @@ -3876,7 +3875,7 @@ void LatteDecompiler_emitMSLShader(LatteDecompilerShaderContext* shaderContext, if (shader->shaderType == LatteConst::ShaderType::Vertex) { // Calculate the imaginary vertex id - src->add("uint vid = tig * VERTICES_PER_PRIMITIVE + tid;" _CRLF); + src->add("uint vid = tig * VERTICES_PER_VERTEX_PRIMITIVE + tid;" _CRLF); // TODO: don't hardcode the instance index src->add("uint iid = 0;" _CRLF); // Fetch the input @@ -4145,7 +4144,27 @@ void LatteDecompiler_emitMSLShader(LatteDecompilerShaderContext* shaderContext, } else if (shader->shaderType == LatteConst::ShaderType::Geometry) { - src->add("mesh.set_primitive_count(vertexIndex / 
VERTICES_PER_PRIMITIVE);" _CRLF); + src->add("mesh.set_primitive_count(GET_PRIMITIVE_COUNT(vertexIndex));" _CRLF); + + // Set indices: expand the emitted strip into a primitive list, corner c of primitive p reads vertex p + c + if (shaderContext->contextRegisters[mmVGT_GS_OUT_PRIM_TYPE] == 1) // Line strip + { + src->add("for (uint8_t i = 0; i < GET_PRIMITIVE_COUNT(vertexIndex) * 2; i++) {" _CRLF); + src->add("mesh.set_index(i, (i / 2) + i % 2);" _CRLF); + src->add("}" _CRLF); + } + else if (shaderContext->contextRegisters[mmVGT_GS_OUT_PRIM_TYPE] == 2) // Triangle strip (NOTE(review): odd triangles are not re-wound, confirm culling is unaffected) + { + src->add("for (uint8_t i = 0; i < GET_PRIMITIVE_COUNT(vertexIndex) * 3; i++) {" _CRLF); + src->add("mesh.set_index(i, (i / 3) + i % 3);" _CRLF); + src->add("}" _CRLF); + } + else + { + src->add("for (uint8_t i = 0; i < vertexIndex; i++) {" _CRLF); + src->add("mesh.set_index(i, i);" _CRLF); + src->add("}" _CRLF); + } } } else diff --git a/src/Cafe/HW/Latte/LegacyShaderDecompiler/LatteDecompilerEmitMSLHeader.hpp b/src/Cafe/HW/Latte/LegacyShaderDecompiler/LatteDecompilerEmitMSLHeader.hpp index b5e16ec7..20f75c95 100644 --- a/src/Cafe/HW/Latte/LegacyShaderDecompiler/LatteDecompilerEmitMSLHeader.hpp +++ b/src/Cafe/HW/Latte/LegacyShaderDecompiler/LatteDecompilerEmitMSLHeader.hpp @@ -1,5 +1,6 @@ #pragma once +#include "Common/precompiled.h" #include "HW/Latte/Core/LatteConst.h" namespace LatteDecompiler { @@ -307,7 +308,7 @@ namespace LatteDecompiler src->addFmt("int4 passParameterSem{};" _CRLF, f); src->add("};" _CRLF _CRLF); src->add("struct ObjectPayload {" _CRLF); - src->add("VertexOut vertexOut[VERTICES_PER_PRIMITIVE];" _CRLF); + src->add("VertexOut vertexOut[VERTICES_PER_VERTEX_PRIMITIVE];" _CRLF); src->add("};" _CRLF _CRLF); } if (decompilerContext->shaderType == LatteConst::ShaderType::Geometry) @@ -329,10 +330,10 @@ namespace LatteDecompiler } src->add("};" _CRLF _CRLF); - const uint32 MAX_PRIMITIVE_COUNT = 8; + const uint32 MAX_VERTEX_COUNT = 32; // Define the mesh shader output type - src->addFmt("using MeshType = mesh;" _CRLF, MAX_PRIMITIVE_COUNT, MAX_PRIMITIVE_COUNT); + 
src->addFmt("using MeshType = mesh;" _CRLF, MAX_VERTEX_COUNT, MAX_VERTEX_COUNT); } } } @@ -343,15 +344,46 @@ namespace LatteDecompiler if (decompilerContext->options->usesGeometryShader && (decompilerContext->shaderType == LatteConst::ShaderType::Vertex || decompilerContext->shaderType == LatteConst::ShaderType::Geometry)) { - src->add("#if PRIMITIVE_TYPE == point" _CRLF); - src->add("#define VERTICES_PER_PRIMITIVE 1" _CRLF); - src->add("#elif PRIMITIVE_TYPE == line" _CRLF); - src->add("#define VERTICES_PER_PRIMITIVE 2" _CRLF); - src->add("#elif PRIMITIVE_TYPE == triangle" _CRLF); - src->add("#define VERTICES_PER_PRIMITIVE 3" _CRLF); - src->add("#else" _CRLF); - src->add("#error unsupported primitive type" _CRLF); - src->add("#endif" _CRLF); + // TODO: make vsOutPrimType part of the shader hash + LattePrimitiveMode vsOutPrimType = static_cast<LattePrimitiveMode>(decompilerContext->contextRegisters[mmVGT_PRIMITIVE_TYPE]); + uint32 gsOutPrimType = decompilerContext->contextRegisters[mmVGT_GS_OUT_PRIM_TYPE]; + + switch (vsOutPrimType) + { + case LattePrimitiveMode::POINTS: + src->add("#define VERTICES_PER_VERTEX_PRIMITIVE 1" _CRLF); + break; + case LattePrimitiveMode::LINES: + src->add("#define VERTICES_PER_VERTEX_PRIMITIVE 2" _CRLF); + break; + case LattePrimitiveMode::TRIANGLES: + src->add("#define VERTICES_PER_VERTEX_PRIMITIVE 3" _CRLF); + break; + default: + cemu_assert_suspicious(); + break; + } + if (decompilerContext->shaderType == LatteConst::ShaderType::Geometry) + { + switch (gsOutPrimType) + { + case 0: // Point + src->add("#define MTL_PRIMITIVE_TYPE point" _CRLF); + src->add("#define GET_PRIMITIVE_COUNT(vertexCount) (vertexCount / 1)" _CRLF); + break; + case 1: // Line strip (clamped to 0 so fewer than 2 emitted vertices cannot underflow the count) + src->add("#define MTL_PRIMITIVE_TYPE line" _CRLF); + src->add("#define GET_PRIMITIVE_COUNT(vertexCount) ((vertexCount) > 1 ? (vertexCount) - 1 : 0)" _CRLF); + break; + case 2: // Triangle strip (clamped to 0 so fewer than 3 emitted vertices cannot underflow the count) + src->add("#define MTL_PRIMITIVE_TYPE triangle" _CRLF); + src->add("#define GET_PRIMITIVE_COUNT(vertexCount) ((vertexCount) > 2 ? (vertexCount) - 2 : 0)" 
_CRLF); + break; + default: + cemu_assert_suspicious(); + break; + } + } } const bool dump_shaders_enabled = ActiveSettings::DumpShadersEnabled(); diff --git a/src/Cafe/HW/Latte/Renderer/Metal/MetalPipelineCache.cpp b/src/Cafe/HW/Latte/Renderer/Metal/MetalPipelineCache.cpp index d8d39b79..8e115e58 100644 --- a/src/Cafe/HW/Latte/Renderer/Metal/MetalPipelineCache.cpp +++ b/src/Cafe/HW/Latte/Renderer/Metal/MetalPipelineCache.cpp @@ -276,7 +276,6 @@ MTL::RenderPipelineState* MetalPipelineCache::GetMeshPipelineState(const LatteFe auto mtlMeshShader = static_cast(geometryShader->shader); auto mtlPixelShader = static_cast(pixelShader->shader); mtlObjectShader->CompileObjectFunction(lcr, fetchShader, vertexShader, hostIndexType); - mtlMeshShader->CompileMeshFunction(lcr, fetchShader); mtlPixelShader->CompileFragmentFunction(activeFBO); // Render pipeline state diff --git a/src/Cafe/HW/Latte/Renderer/Metal/MetalRenderer.cpp b/src/Cafe/HW/Latte/Renderer/Metal/MetalRenderer.cpp index 91df236c..89c9c2a3 100644 --- a/src/Cafe/HW/Latte/Renderer/Metal/MetalRenderer.cpp +++ b/src/Cafe/HW/Latte/Renderer/Metal/MetalRenderer.cpp @@ -1013,7 +1013,7 @@ void MetalRenderer::draw_execute(uint32 baseVertex, uint32 baseInstance, uint32 // Restride if (geometryShader) { - // Object shaders don't need restriding, since the attribute are fetched in the shader + // Object shaders don't need restriding, since the attributes are fetched in the shader buffer = m_memoryManager->GetBufferCache(); offset = m_state.m_vertexBuffers[i].offset; } diff --git a/src/Cafe/HW/Latte/Renderer/Metal/RendererShaderMtl.cpp b/src/Cafe/HW/Latte/Renderer/Metal/RendererShaderMtl.cpp index 883a85c6..dc2846ef 100644 --- a/src/Cafe/HW/Latte/Renderer/Metal/RendererShaderMtl.cpp +++ b/src/Cafe/HW/Latte/Renderer/Metal/RendererShaderMtl.cpp @@ -16,8 +16,15 @@ extern std::atomic_int g_compiled_shaders_async; RendererShaderMtl::RendererShaderMtl(MetalRenderer* mtlRenderer, ShaderType type, uint64 baseHash, uint64 auxHash, bool 
isGameShader, bool isGfxPackShader, const std::string& mslCode) : RendererShader(type, baseHash, auxHash, isGameShader, isGfxPackShader), m_mtlr{mtlRenderer} { - // TODO: don't compile just-in-time - m_mslCode = mslCode; + if (type == ShaderType::kGeometry) + { + Compile(mslCode); + } + else + { + // TODO: don't compile just-in-time + m_mslCode = mslCode; + } // Count shader compilation g_compiled_shaders_total++; @@ -35,25 +42,6 @@ void RendererShaderMtl::CompileObjectFunction(const LatteContextRegister& lcr, c std::string fullCode; - // Primitive type - const LattePrimitiveMode primitiveMode = static_cast(lcr.VGT_PRIMITIVE_TYPE.get_PRIMITIVE_MODE()); - fullCode += "#define PRIMITIVE_TYPE "; - switch (primitiveMode) - { - case LattePrimitiveMode::POINTS: - fullCode += "point"; - break; - case LattePrimitiveMode::LINES: - fullCode += "line"; - break; - case LattePrimitiveMode::TRIANGLES: - fullCode += "triangle"; - break; - default: - break; - } - fullCode += "\n"; - // Vertex buffers std::string vertexBufferDefinitions = "#define VERTEX_BUFFER_DEFINITIONS "; std::string vertexBuffers = "#define VERTEX_BUFFERS "; @@ -83,6 +71,10 @@ void RendererShaderMtl::CompileObjectFunction(const LatteContextRegister& lcr, c { std::optional fetchType; + uint32 bufferIndex = bufferGroup.attributeBufferIndex; + uint32 bufferBaseRegisterIndex = mmSQ_VTX_ATTRIBUTE_BLOCK_START + bufferIndex * 7; + uint32 bufferStride = (lcr.GetRawView()[bufferBaseRegisterIndex + 2] >> 11) & 0xFFFF; + for (sint32 j = 0; j < bufferGroup.attribCount; ++j) { auto& attr = bufferGroup.attrib[j]; @@ -149,7 +141,7 @@ void RendererShaderMtl::CompileObjectFunction(const LatteContextRegister& lcr, c inputFetchDefinition += "in.ATTRIBUTE_NAME" + std::to_string(semanticId) + " = "; inputFetchDefinition += "uint4(*(device " + formatName + "*)"; inputFetchDefinition += "(vertexBuffer" + std::to_string(attr.attributeBufferIndex); - inputFetchDefinition += " + vid + " + std::to_string(attr.offset) + ")"; + 
inputFetchDefinition += " + vid * " + std::to_string(bufferStride) + " + " + std::to_string(attr.offset) + ")"; for (uint8 i = 0; i < (4 - componentCount); i++) inputFetchDefinition += ", 0"; inputFetchDefinition += ");\n"; @@ -165,10 +157,6 @@ void RendererShaderMtl::CompileObjectFunction(const LatteContextRegister& lcr, c } } - uint32 bufferIndex = bufferGroup.attributeBufferIndex; - uint32 bufferBaseRegisterIndex = mmSQ_VTX_ATTRIBUTE_BLOCK_START + bufferIndex * 7; - uint32 bufferStride = (lcr.GetRawView()[bufferBaseRegisterIndex + 2] >> 11) & 0xFFFF; - vertexBufferDefinitions += ", device uchar* vertexBuffer" + std::to_string(bufferIndex) + " [[buffer(" + std::to_string(GET_MTL_VERTEX_BUFFER_INDEX(bufferIndex)) + ")]]"; vertexBuffers += ", vertexBuffer" + std::to_string(bufferIndex); } @@ -183,35 +171,6 @@ void RendererShaderMtl::CompileObjectFunction(const LatteContextRegister& lcr, c Compile(fullCode); } -void RendererShaderMtl::CompileMeshFunction(const LatteContextRegister& lcr, const LatteFetchShader* fetchShader) -{ - cemu_assert_debug(m_type == ShaderType::kGeometry); - - std::string fullCode; - - // Primitive type - const LattePrimitiveMode primitiveMode = static_cast(lcr.VGT_PRIMITIVE_TYPE.get_PRIMITIVE_MODE()); - fullCode += "#define PRIMITIVE_TYPE "; - switch (primitiveMode) - { - case LattePrimitiveMode::POINTS: - fullCode += "point"; - break; - case LattePrimitiveMode::LINES: - fullCode += "line"; - break; - case LattePrimitiveMode::TRIANGLES: - fullCode += "triangle"; - break; - default: - break; - } - fullCode += "\n"; - - fullCode += m_mslCode; - Compile(fullCode); -} - void RendererShaderMtl::CompileFragmentFunction(CachedFBOMtl* activeFBO) { cemu_assert_debug(m_type == ShaderType::kFragment); diff --git a/src/Cafe/HW/Latte/Renderer/Metal/RendererShaderMtl.h b/src/Cafe/HW/Latte/Renderer/Metal/RendererShaderMtl.h index 1a53313a..e21db55e 100644 --- a/src/Cafe/HW/Latte/Renderer/Metal/RendererShaderMtl.h +++ 
b/src/Cafe/HW/Latte/Renderer/Metal/RendererShaderMtl.h @@ -27,7 +27,6 @@ public: } void CompileObjectFunction(const LatteContextRegister& lcr, const LatteFetchShader* fetchShader, const LatteDecompilerShader* vertexShader, Renderer::INDEX_TYPE hostIndexType); - void CompileMeshFunction(const LatteContextRegister& lcr, const LatteFetchShader* fetchShader); void CompileFragmentFunction(CachedFBOMtl* activeFBO); MTL::Function* GetFunction() const