From 0bcbe32367eeada2a5aa7e6bb2edccc22cababa3 Mon Sep 17 00:00:00 2001
From: gdkchan <gab.dark.100@gmail.com>
Date: Sun, 6 Mar 2022 16:42:13 -0300
Subject: [PATCH] Only initialize shader outputs that are actually used on the
 next stage (#3054)

* Only initialize shader outputs that are actually used on the next stage

* Shader cache version bump
---
 Ryujinx.Graphics.Gpu/Shader/ShaderCache.cs    |  2 +-
 Ryujinx.Graphics.Shader/Decoders/Decoder.cs   |  5 +-
 .../Translation/ShaderConfig.cs               | 11 ++-
 .../Translation/Translator.cs                 | 26 ++++---
 .../Translation/UInt128.cs                    | 74 +++++++++++++++++++
 5 files changed, 104 insertions(+), 14 deletions(-)
 create mode 100644 Ryujinx.Graphics.Shader/Translation/UInt128.cs

diff --git a/Ryujinx.Graphics.Gpu/Shader/ShaderCache.cs b/Ryujinx.Graphics.Gpu/Shader/ShaderCache.cs
index ae128ed432..e6d46884a7 100644
--- a/Ryujinx.Graphics.Gpu/Shader/ShaderCache.cs
+++ b/Ryujinx.Graphics.Gpu/Shader/ShaderCache.cs
@@ -40,7 +40,7 @@ namespace Ryujinx.Graphics.Gpu.Shader
         /// <summary>
         /// Version of the codegen (to be changed when codegen or guest format change).
         /// </summary>
-        private const ulong ShaderCodeGenVersion = 3132;
+        private const ulong ShaderCodeGenVersion = 3054;
 
         // Progress reporting helpers
         private volatile int _shaderCount;
diff --git a/Ryujinx.Graphics.Shader/Decoders/Decoder.cs b/Ryujinx.Graphics.Shader/Decoders/Decoder.cs
index 8820527f13..6fa4055aa3 100644
--- a/Ryujinx.Graphics.Shader/Decoders/Decoder.cs
+++ b/Ryujinx.Graphics.Shader/Decoders/Decoder.cs
@@ -308,7 +308,8 @@ namespace Ryujinx.Graphics.Shader.Decoders
                     int attr = offset + elemIndex * 4;
                     if (attr >= AttributeConsts.UserAttributeBase && attr < AttributeConsts.UserAttributeEnd)
                     {
-                        int index = (attr - AttributeConsts.UserAttributeBase) / 16;
+                        int userAttr = attr - AttributeConsts.UserAttributeBase;
+                        int index = userAttr / 16;
 
                         if (isStore)
                         {
@@ -316,7 +317,7 @@ namespace Ryujinx.Graphics.Shader.Decoders
                         }
                         else
                         {
-                            config.SetInputUserAttribute(index, perPatch);
+                            config.SetInputUserAttribute(index, (userAttr >> 2) & 3, perPatch);
                         }
                     }
 
diff --git a/Ryujinx.Graphics.Shader/Translation/ShaderConfig.cs b/Ryujinx.Graphics.Shader/Translation/ShaderConfig.cs
index 3b10ab2117..8be7ceaeaf 100644
--- a/Ryujinx.Graphics.Shader/Translation/ShaderConfig.cs
+++ b/Ryujinx.Graphics.Shader/Translation/ShaderConfig.cs
@@ -54,6 +54,11 @@ namespace Ryujinx.Graphics.Shader.Translation
         private int _nextUsedInputAttributes;
         private int _thisUsedInputAttributes;
 
+        public UInt128 NextInputAttributesComponents { get; private set; }
+        public UInt128 ThisInputAttributesComponents { get; private set; }
+        public UInt128 NextInputAttributesPerPatchComponents { get; private set; }
+        public UInt128 ThisInputAttributesPerPatchComponents { get; private set; }
+
         private int _usedConstantBuffers;
         private int _usedStorageBuffers;
         private int _usedStorageBuffersWrite;
@@ -227,11 +232,12 @@ namespace Ryujinx.Graphics.Shader.Translation
             UsedOutputAttributes |= 1 << index;
         }
 
-        public void SetInputUserAttribute(int index, bool perPatch)
+        public void SetInputUserAttribute(int index, int component, bool perPatch)
         {
             if (perPatch)
             {
                 UsedInputAttributesPerPatch |= 1 << index;
+                ThisInputAttributesPerPatchComponents |= UInt128.Pow2(index * 4 + component);
             }
             else
             {
@@ -239,6 +245,7 @@ namespace Ryujinx.Graphics.Shader.Translation
 
                 UsedInputAttributes |= mask;
                 _thisUsedInputAttributes |= mask;
+                ThisInputAttributesComponents |= UInt128.Pow2(index * 4 + component);
             }
         }
 
@@ -256,6 +263,8 @@ namespace Ryujinx.Graphics.Shader.Translation
 
         public void MergeFromtNextStage(ShaderConfig config)
         {
+            NextInputAttributesComponents = config.ThisInputAttributesComponents;
+            NextInputAttributesPerPatchComponents = config.ThisInputAttributesPerPatchComponents;
             NextUsesFixedFuncAttributes = config.UsedFeatures.HasFlag(FeatureFlags.FixedFuncAttr);
             MergeOutputUserAttributes(config.UsedInputAttributes, config.UsedInputAttributesPerPatch);
         }
diff --git a/Ryujinx.Graphics.Shader/Translation/Translator.cs b/Ryujinx.Graphics.Shader/Translation/Translator.cs
index 603b20d649..e594c818a0 100644
--- a/Ryujinx.Graphics.Shader/Translation/Translator.cs
+++ b/Ryujinx.Graphics.Shader/Translation/Translator.cs
@@ -214,24 +214,24 @@ namespace Ryujinx.Graphics.Shader.Translation
                 InitializeOutput(context, AttributeConsts.PositionX, perPatch: false);
             }
 
-            int usedAttributes = context.Config.UsedOutputAttributes;
-            while (usedAttributes != 0)
+            UInt128 usedAttributes = context.Config.NextInputAttributesComponents;
+            while (usedAttributes != UInt128.Zero)
             {
-                int index = BitOperations.TrailingZeroCount(usedAttributes);
+                int index = usedAttributes.TrailingZeroCount();
 
-                InitializeOutput(context, AttributeConsts.UserAttributeBase + index * 16, perPatch: false);
+                InitializeOutputComponent(context, AttributeConsts.UserAttributeBase + index * 4, perPatch: false);
 
-                usedAttributes &= ~(1 << index);
+                usedAttributes &= ~UInt128.Pow2(index);
             }
 
-            int usedAttributesPerPatch = context.Config.UsedOutputAttributesPerPatch;
-            while (usedAttributesPerPatch != 0)
+            UInt128 usedAttributesPerPatch = context.Config.NextInputAttributesPerPatchComponents;
+            while (usedAttributesPerPatch != UInt128.Zero)
             {
-                int index = BitOperations.TrailingZeroCount(usedAttributesPerPatch);
+                int index = usedAttributesPerPatch.TrailingZeroCount();
 
-                InitializeOutput(context, AttributeConsts.UserAttributeBase + index * 16, perPatch: true);
+                InitializeOutputComponent(context, AttributeConsts.UserAttributeBase + index * 4, perPatch: true);
 
-                usedAttributesPerPatch &= ~(1 << index);
+                usedAttributesPerPatch &= ~UInt128.Pow2(index);
             }
 
             if (config.NextUsesFixedFuncAttributes)
@@ -260,6 +260,12 @@ namespace Ryujinx.Graphics.Shader.Translation
             }
         }
 
+        private static void InitializeOutputComponent(EmitterContext context, int attrOffset, bool perPatch)
+        {
+            int c = (attrOffset >> 2) & 3;
+            context.Copy(perPatch ? AttributePerPatch(attrOffset) : Attribute(attrOffset), ConstF(c == 3 ? 1f : 0f));
+        }
+
         private static void EmitOps(EmitterContext context, Block block)
         {
             for (int opIndex = 0; opIndex < block.OpCodes.Count; opIndex++)
diff --git a/Ryujinx.Graphics.Shader/Translation/UInt128.cs b/Ryujinx.Graphics.Shader/Translation/UInt128.cs
new file mode 100644
index 0000000000..590f7690ff
--- /dev/null
+++ b/Ryujinx.Graphics.Shader/Translation/UInt128.cs
@@ -0,0 +1,74 @@
+using System;
+using System.Numerics;
+
+namespace Ryujinx.Graphics.Shader.Translation
+{
+    struct UInt128 : IEquatable<UInt128>
+    {
+        public static UInt128 Zero => new UInt128() { _v0 = 0, _v1 = 0 };
+
+        private ulong _v0;
+        private ulong _v1;
+
+        public int TrailingZeroCount()
+        {
+            int count = BitOperations.TrailingZeroCount(_v0);
+            if (count == 64)
+            {
+                count += BitOperations.TrailingZeroCount(_v1);
+            }
+
+            return count;
+        }
+
+        public static UInt128 Pow2(int x)
+        {
+            if (x >= 64)
+            {
+                return new UInt128() { _v0 = 0, _v1 = 1UL << (x - 64 ) };
+            }
+
+            return new UInt128() { _v0 = 1UL << x, _v1 = 0 };
+        }
+
+        public static UInt128 operator ~(UInt128 x)
+        {
+            return new UInt128() { _v0 = ~x._v0, _v1 = ~x._v1 };
+        }
+
+        public static UInt128 operator &(UInt128 x, UInt128 y)
+        {
+            return new UInt128() { _v0 = x._v0 & y._v0, _v1 = x._v1 & y._v1 };
+        }
+
+        public static UInt128 operator |(UInt128 x, UInt128 y)
+        {
+            return new UInt128() { _v0 = x._v0 | y._v0, _v1 = x._v1 | y._v1 };
+        }
+
+        public static bool operator ==(UInt128 x, UInt128 y)
+        {
+            return x.Equals(y);
+        }
+
+        public static bool operator !=(UInt128 x, UInt128 y)
+        {
+            return !x.Equals(y);
+        }
+
+        public override bool Equals(object obj)
+        {
+            return obj is UInt128 other && Equals(other);
+        }
+
+        public bool Equals(UInt128 other)
+        {
+            return _v0 == other._v0 && _v1 == other._v1;
+        }
+
+        public override int GetHashCode()
+        {
+            return HashCode.Combine(_v0, _v1);
+        }
+    }
+}
\ No newline at end of file