From 15372dc835a2c387f8f566fa86d6cc60b2534a7d Mon Sep 17 00:00:00 2001
From: iwubcode <iwubcode@users.noreply.github.com>
Date: Wed, 12 Feb 2025 20:18:20 -0600
Subject: [PATCH] VideoCommon: move lighting shader logic to callable functions

---
 Source/Core/VideoCommon/LightingShaderGen.cpp | 37 ++++++++++---------
 Source/Core/VideoCommon/LightingShaderGen.h   |  3 +-
 Source/Core/VideoCommon/PixelShaderGen.cpp    | 14 ++++---
 Source/Core/VideoCommon/VertexShaderGen.cpp   |  8 +++-
 4 files changed, 37 insertions(+), 25 deletions(-)

diff --git a/Source/Core/VideoCommon/LightingShaderGen.cpp b/Source/Core/VideoCommon/LightingShaderGen.cpp
index 4fb2c98ebd..130fb6f46d 100644
--- a/Source/Core/VideoCommon/LightingShaderGen.cpp
+++ b/Source/Core/VideoCommon/LightingShaderGen.cpp
@@ -79,31 +79,34 @@ static void GenerateLightShader(ShaderCode& object, const LightingUidData& uid_d
 // vertex shader
 // lights/colors
 // materials name is I_MATERIALS in vs and I_PMATERIALS in ps
-// inColorName is color in vs and colors_ in ps
-// dest is o.colors_ in vs and colors_ in ps
-void GenerateLightingShaderCode(ShaderCode& object, const LightingUidData& uid_data,
-                                std::string_view in_color_name, std::string_view dest)
+void GenerateLightingShaderHeader(ShaderCode& object, const LightingUidData& uid_data)
 {
   for (u32 j = 0; j < NUM_XF_COLOR_CHANNELS; j++)
   {
+    object.Write("vec4 dolphin_calculate_lighting_chn{}(vec4 base_color, vec3 pos, vec3 _normal)\n",
+                 j);
     object.Write("{{\n");
 
+    object.Write("\tint4 lacc;\n"
+                 "\tvec3 ldir, h, cosAttn, distAttn;\n"
+                 "\tfloat dist, dist2, attn;\n");
+
     const bool colormatsource = !!(uid_data.matsource & (1 << j));
     if (colormatsource)  // from vertex
-      object.Write("int4 mat = int4(round({}{} * 255.0));\n", in_color_name, j);
+      object.Write("\tint4 mat = int4(round(base_color * 255.0));\n");
     else  // from color
-      object.Write("int4 mat = {}[{}];\n", I_MATERIALS, j + 2);
+      object.Write("\tint4 mat = {}[{}];\n", I_MATERIALS, j + 2);
 
     if ((uid_data.enablelighting & (1 << j)) != 0)
     {
       if ((uid_data.ambsource & (1 << j)) != 0)  // from vertex
-        object.Write("lacc = int4(round({}{} * 255.0));\n", in_color_name, j);
+        object.Write("\tlacc = int4(round(base_color * 255.0));\n");
       else  // from color
-        object.Write("lacc = {}[{}];\n", I_MATERIALS, j);
+        object.Write("\tlacc = {}[{}];\n", I_MATERIALS, j);
     }
     else
     {
-      object.Write("lacc = int4(255, 255, 255, 255);\n");
+      object.Write("\tlacc = int4(255, 255, 255, 255);\n");
     }
 
     // check if alpha is different
@@ -111,21 +114,21 @@ void GenerateLightingShaderCode(ShaderCode& object, const LightingUidData& uid_d
     if (alphamatsource != colormatsource)
     {
       if (alphamatsource)  // from vertex
-        object.Write("mat.w = int(round({}{}.w * 255.0));\n", in_color_name, j);
+        object.Write("\tmat.w = int(round(base_color.w * 255.0));\n");
       else  // from color
-        object.Write("mat.w = {}[{}].w;\n", I_MATERIALS, j + 2);
+        object.Write("\tmat.w = {}[{}].w;\n", I_MATERIALS, j + 2);
     }
 
     if ((uid_data.enablelighting & (1 << (j + 2))) != 0)
     {
       if ((uid_data.ambsource & (1 << (j + 2))) != 0)  // from vertex
-        object.Write("lacc.w = int(round({}{}.w * 255.0));\n", in_color_name, j);
+        object.Write("\tlacc.w = int(round(base_color.w * 255.0));\n");
       else  // from color
-        object.Write("lacc.w = {}[{}].w;\n", I_MATERIALS, j);
+        object.Write("\tlacc.w = {}[{}].w;\n", I_MATERIALS, j);
     }
     else
     {
-      object.Write("lacc.w = 255;\n");
+      object.Write("\tlacc.w = 255;\n");
     }
 
     if ((uid_data.enablelighting & (1 << j)) != 0)  // Color lights
@@ -144,9 +147,9 @@ void GenerateLightingShaderCode(ShaderCode& object, const LightingUidData& uid_d
           GenerateLightShader(object, uid_data, i, j + 2, true);
       }
     }
-    object.Write("lacc = clamp(lacc, 0, 255);\n");
-    object.Write("{}{} = float4((mat * (lacc + (lacc >> 7))) >> 8) / 255.0;\n", dest, j);
-    object.Write("}}\n");
+    object.Write("\tlacc = clamp(lacc, 0, 255);\n");
+    object.Write("\treturn vec4((mat * (lacc + (lacc >> 7))) >> 8) / 255.0;\n");
+    object.Write("}}\n\n");
   }
 }
 
diff --git a/Source/Core/VideoCommon/LightingShaderGen.h b/Source/Core/VideoCommon/LightingShaderGen.h
index b06ec40c4a..3e146cc07f 100644
--- a/Source/Core/VideoCommon/LightingShaderGen.h
+++ b/Source/Core/VideoCommon/LightingShaderGen.h
@@ -44,8 +44,7 @@ constexpr char s_lighting_struct[] = "struct Light {\n"
                                      "\tfloat4 dir;\n"
                                      "};\n";
 
-void GenerateLightingShaderCode(ShaderCode& object, const LightingUidData& uid_data,
-                                std::string_view in_color_name, std::string_view dest);
+void GenerateLightingShaderHeader(ShaderCode& object, const LightingUidData& uid_data);
 void GetLightingShaderUid(LightingUidData& uid_data);
 
 void GenerateCustomLightingHeaderDetails(ShaderCode* out, u32 enablelighting, u32 light_mask);
diff --git a/Source/Core/VideoCommon/PixelShaderGen.cpp b/Source/Core/VideoCommon/PixelShaderGen.cpp
index 851d127963..8a5e7d8ccd 100644
--- a/Source/Core/VideoCommon/PixelShaderGen.cpp
+++ b/Source/Core/VideoCommon/PixelShaderGen.cpp
@@ -1057,6 +1057,11 @@ ShaderCode GeneratePixelShaderCode(APIType api_type, const ShaderHostConfig& hos
     }
   }
 
+  if (per_pixel_lighting)
+  {
+    GenerateLightingShaderHeader(out, uid_data->lighting);
+  }
+
   out.Write("void main()\n{{\n");
   out.Write("\tfloat4 rawpos = gl_FragCoord;\n");
 
@@ -1124,16 +1129,15 @@ ShaderCode GeneratePixelShaderCode(APIType api_type, const ShaderHostConfig& hos
     out.Write("\tfloat3 _normal = normalize(Normal.xyz);\n\n"
               "\tfloat3 pos = WorldPos;\n");
 
-    out.Write("\tint4 lacc;\n"
-              "\tfloat3 ldir, h, cosAttn, distAttn;\n"
-              "\tfloat dist, dist2, attn;\n");
-
     // TODO: Our current constant usage code isn't able to handle more than one buffer.
     //       So we can't mark the VS constant as used here. But keep them here as reference.
     // out.SetConstantsUsed(C_PLIGHT_COLORS, C_PLIGHT_COLORS+7); // TODO: Can be optimized further
     // out.SetConstantsUsed(C_PLIGHTS, C_PLIGHTS+31); // TODO: Can be optimized further
     // out.SetConstantsUsed(C_PMATERIALS, C_PMATERIALS+3);
-    GenerateLightingShaderCode(out, uid_data->lighting, "colors_", "col");
+    for (u32 chan = 0; chan < uid_data->numColorChans; chan++)
+    {
+      out.Write("\tcol{0} = dolphin_calculate_lighting_chn{0}(colors_{0}, pos, _normal);\n", chan);
+    }
     // The number of colors available to TEV is determined by numColorChans.
     // Normally this is performed in the vertex shader after lighting, but with per-pixel lighting,
     // we need to perform it here.  (It needs to be done after lighting, as what was originally
diff --git a/Source/Core/VideoCommon/VertexShaderGen.cpp b/Source/Core/VideoCommon/VertexShaderGen.cpp
index 41512cd489..ace5cb514e 100644
--- a/Source/Core/VideoCommon/VertexShaderGen.cpp
+++ b/Source/Core/VideoCommon/VertexShaderGen.cpp
@@ -190,6 +190,7 @@ ShaderCode GenerateVertexShaderCode(APIType api_type, const ShaderHostConfig& ho
   out.Write("}};\n\n");
 
   WriteIsNanHeader(out, api_type);
+  GenerateLightingShaderHeader(out, uid_data->lighting);
 
   if (uid_data->vs_expand == VSExpand::None)
   {
@@ -434,7 +435,12 @@ ShaderCode GenerateVertexShaderCode(APIType api_type, const ShaderHostConfig& ho
             "float3 ldir, h, cosAttn, distAttn;\n"
             "float dist, dist2, attn;\n");
 
-  GenerateLightingShaderCode(out, uid_data->lighting, "vertex_color_", "o.colors_");
+  for (u32 chan = 0; chan < NUM_XF_COLOR_CHANNELS; chan++)
+  {
+    out.Write(
+        "\to.colors_{0} = dolphin_calculate_lighting_chn{0}(vertex_color_{0}, pos.xyz, _normal);\n",
+        chan);
+  }
 
   // transform texcoords
   out.Write("float4 coord = float4(0.0, 0.0, 1.0, 1.0);\n");