From fc26189fe1338ffcba12d83a922da9c706738902 Mon Sep 17 00:00:00 2001
From: gdkchan <gab.dark.100@gmail.com>
Date: Fri, 19 May 2023 11:52:31 -0300
Subject: [PATCH] Eliminate redundant multiplications by gl_FragCoord.w on the
 shader (#4578)

* Eliminate redundant multiplications by gl_FragCoord.w on the shader

* Shader cache version bump
---
 .../Shader/DiskCache/DiskCacheHostStorage.cs  |  2 +-
 .../Translation/Optimizations/Optimizer.cs    | 75 +++++++++++++++++++
 .../Translation/Optimizations/Utils.cs        | 29 +++++++
 3 files changed, 105 insertions(+), 1 deletion(-)

diff --git a/src/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs b/src/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs
index 71098efa81..ee77145218 100644
--- a/src/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs
+++ b/src/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs
@@ -22,7 +22,7 @@ namespace Ryujinx.Graphics.Gpu.Shader.DiskCache
         private const ushort FileFormatVersionMajor = 1;
         private const ushort FileFormatVersionMinor = 2;
         private const uint FileFormatVersionPacked = ((uint)FileFormatVersionMajor << 16) | FileFormatVersionMinor;
-        private const uint CodeGenVersion = 4892;
+        private const uint CodeGenVersion = 4578;
 
         private const string SharedTocFileName = "shared.toc";
         private const string SharedDataFileName = "shared.data";
diff --git a/src/Ryujinx.Graphics.Shader/Translation/Optimizations/Optimizer.cs b/src/Ryujinx.Graphics.Shader/Translation/Optimizations/Optimizer.cs
index 16848bdc84..b41e47e42d 100644
--- a/src/Ryujinx.Graphics.Shader/Translation/Optimizations/Optimizer.cs
+++ b/src/Ryujinx.Graphics.Shader/Translation/Optimizations/Optimizer.cs
@@ -20,6 +20,12 @@ namespace Ryujinx.Graphics.Shader.Translation.Optimizations
                 GlobalToStorage.RunPass(blocks[blkIndex], config, ref sbUseMask, ref ubeUseMask);
                 BindlessToIndexed.RunPass(blocks[blkIndex], config);
                 BindlessElimination.RunPass(blocks[blkIndex], config);
+
+                // FragmentCoord only exists on fragment shaders, so we don't need to check other stages.
+                if (config.Stage == ShaderStage.Fragment)
+                {
+                    EliminateMultiplyByFragmentCoordW(blocks[blkIndex]);
+                }
             }
 
             config.SetAccessibleBufferMasks(sbUseMask, ubeUseMask);
@@ -281,6 +287,75 @@ namespace Ryujinx.Graphics.Shader.Translation.Optimizations
             return modified;
         }
 
+        private static void EliminateMultiplyByFragmentCoordW(BasicBlock block)
+        {
+            foreach (INode node in block.Operations)
+            {
+                if (node is Operation operation)
+                {
+                    EliminateMultiplyByFragmentCoordW(operation);
+                }
+            }
+        }
+
+        private static void EliminateMultiplyByFragmentCoordW(Operation operation)
+        {
+            // We're looking for the pattern:
+            //  y = x * gl_FragCoord.w
+            //  v = y * (1.0 / gl_FragCoord.w)
+            // Then we transform it into:
+            //  v = x
+            // This pattern is common on fragment shaders due to the way how perspective correction is done.
+
+            // We are expecting a multiplication by the reciprocal of gl_FragCoord.w.
+            if (operation.Inst != (Instruction.FP32 | Instruction.Multiply))
+            {
+                return;
+            }
+
+            Operand lhs = operation.GetSource(0);
+            Operand rhs = operation.GetSource(1);
+
+            // Check LHS of the the main multiplication operation. We expect an input being multiplied by gl_FragCoord.w.
+            if (!(lhs.AsgOp is Operation attrMulOp) || attrMulOp.Inst != (Instruction.FP32 | Instruction.Multiply))
+            {
+                return;
+            }
+
+            Operand attrMulLhs = attrMulOp.GetSource(0);
+            Operand attrMulRhs = attrMulOp.GetSource(1);
+
+            // LHS should be any input, RHS should be exactly gl_FragCoord.w.
+            if (!Utils.IsInputLoad(attrMulLhs.AsgOp) || !Utils.IsInputLoad(attrMulRhs.AsgOp, IoVariable.FragmentCoord, 3))
+            {
+                return;
+            }
+
+            // RHS of the main multiplication should be a reciprocal operation (1.0 / x).
+            if (!(rhs.AsgOp is Operation reciprocalOp) || reciprocalOp.Inst != (Instruction.FP32 | Instruction.Divide))
+            {
+                return;
+            }
+
+            Operand reciprocalLhs = reciprocalOp.GetSource(0);
+            Operand reciprocalRhs = reciprocalOp.GetSource(1);
+
+            // Check if the divisor is a constant equal to 1.0.
+            if (reciprocalLhs.Type != OperandType.Constant || reciprocalLhs.AsFloat() != 1.0f)
+            {
+                return;
+            }
+
+            // Check if the dividend is gl_FragCoord.w.
+            if (!Utils.IsInputLoad(reciprocalRhs.AsgOp, IoVariable.FragmentCoord, 3))
+            {
+                return;
+            }
+
+            // If everything matches, we can replace the operation with the input load result.
+            operation.TurnIntoCopy(attrMulLhs);
+        }
+
         private static void RemoveNode(BasicBlock block, LinkedListNode<INode> llNode)
         {
             // Remove a node from the nodes list, and also remove itself
diff --git a/src/Ryujinx.Graphics.Shader/Translation/Optimizations/Utils.cs b/src/Ryujinx.Graphics.Shader/Translation/Optimizations/Utils.cs
index 4ca6d68778..a0d58d0793 100644
--- a/src/Ryujinx.Graphics.Shader/Translation/Optimizations/Utils.cs
+++ b/src/Ryujinx.Graphics.Shader/Translation/Optimizations/Utils.cs
@@ -4,6 +4,35 @@ namespace Ryujinx.Graphics.Shader.Translation.Optimizations
 {
     static class Utils
     {
+        public static bool IsInputLoad(INode node)
+        {
+            return (node is Operation operation) &&
+                   operation.Inst == Instruction.Load &&
+                   operation.StorageKind == StorageKind.Input;
+        }
+
+        public static bool IsInputLoad(INode node, IoVariable ioVariable, int elemIndex)
+        {
+            if (!(node is Operation operation) ||
+                operation.Inst != Instruction.Load ||
+                operation.StorageKind != StorageKind.Input ||
+                operation.SourcesCount != 2)
+            {
+                return false;
+            }
+
+            Operand ioVariableSrc = operation.GetSource(0);
+
+            if (ioVariableSrc.Type != OperandType.Constant || (IoVariable)ioVariableSrc.Value != ioVariable)
+            {
+                return false;
+            }
+
+            Operand elemIndexSrc = operation.GetSource(1);
+
+            return elemIndexSrc.Type == OperandType.Constant && elemIndexSrc.Value == elemIndex;
+        }
+
         private static Operation FindBranchSource(BasicBlock block)
         {
             foreach (BasicBlock sourceBlock in block.Predecessors)