From 36c6e67df2c06e1b71d2059461ed7fad17bfea34 Mon Sep 17 00:00:00 2001
From: gdkchan <gab.dark.100@gmail.com>
Date: Tue, 12 Jan 2021 18:52:13 -0300
Subject: [PATCH] Implement shader CC mode for ISCADD, X mode for ISETP and fix
 STL/STS/STG with RZ (#1901)

* Implement shader CC mode for ISCADD, X mode for ISETP and fix STS/STG with RZ

* Fix STG too and bump shader cache version

* Fix wrong name

* Fix Carry being inverted on comparison
---
 Ryujinx.Graphics.Gpu/Shader/ShaderCache.cs    |   2 +-
 .../Instructions/InstEmitAlu.cs               | 112 ++++++++++++++----
 .../Instructions/InstEmitMemory.cs            |  18 +--
 .../Translation/EmitterContextInsts.cs        |  30 +++++
 4 files changed, 128 insertions(+), 34 deletions(-)

diff --git a/Ryujinx.Graphics.Gpu/Shader/ShaderCache.cs b/Ryujinx.Graphics.Gpu/Shader/ShaderCache.cs
index 40f6d18c50..366c93df2b 100644
--- a/Ryujinx.Graphics.Gpu/Shader/ShaderCache.cs
+++ b/Ryujinx.Graphics.Gpu/Shader/ShaderCache.cs
@@ -34,7 +34,7 @@ namespace Ryujinx.Graphics.Gpu.Shader
         /// <summary>
         /// Version of the codegen (to be changed when codegen or guest format change).
         /// </summary>
-        private const ulong ShaderCodeGenVersion = 1878;
+        private const ulong ShaderCodeGenVersion = 1901;
 
         /// <summary>
         /// Creates a new instance of the shader cache.
diff --git a/Ryujinx.Graphics.Shader/Instructions/InstEmitAlu.cs b/Ryujinx.Graphics.Shader/Instructions/InstEmitAlu.cs
index 5d00adb564..64ac0eb780 100644
--- a/Ryujinx.Graphics.Shader/Instructions/InstEmitAlu.cs
+++ b/Ryujinx.Graphics.Shader/Instructions/InstEmitAlu.cs
@@ -119,17 +119,12 @@ namespace Ryujinx.Graphics.Shader.Instructions
 
             Operand res = context.IAdd(srcA, srcB);
 
-            bool isSubtraction = negateA || negateB;
-
             if (op.Extended)
             {
-                // Add carry, or subtract borrow.
-                res = context.IAdd(res, isSubtraction
-                    ? context.BitwiseNot(GetCF())
-                    : context.BitwiseAnd(GetCF(), Const(1)));
+                res = context.IAdd(res, context.BitwiseAnd(GetCF(), Const(1)));
             }
 
-            SetIaddFlags(context, res, srcA, srcB, op.SetCondCode, op.Extended, isSubtraction);
+            SetIaddFlags(context, res, srcA, srcB, op.SetCondCode, op.Extended);
 
             context.Copy(GetDest(context), res);
         }
@@ -317,9 +312,9 @@ namespace Ryujinx.Graphics.Shader.Instructions
 
             Operand res = context.IAdd(srcA, srcB);
 
-            context.Copy(GetDest(context), res);
+            SetIaddFlags(context, res, srcA, srcB, op.SetCondCode, false);
 
-            // TODO: CC, X
+            context.Copy(GetDest(context), res);
         }
 
         public static void Iset(EmitterContext context)
@@ -334,7 +329,7 @@ namespace Ryujinx.Graphics.Shader.Instructions
             Operand srcA = GetSrcA(context);
             Operand srcB = GetSrcB(context);
 
-            Operand res = GetIntComparison(context, cmpOp, srcA, srcB, isSigned);
+            Operand res = GetIntComparison(context, cmpOp, srcA, srcB, isSigned, op.Extended);
 
             Operand pred = GetPredicate39(context);
 
@@ -356,8 +351,6 @@ namespace Ryujinx.Graphics.Shader.Instructions
 
                 SetZnFlags(context, res, op.SetCondCode, op.Extended);
             }
-
-            // TODO: X
         }
 
         public static void Isetp(EmitterContext context)
@@ -371,7 +364,7 @@ namespace Ryujinx.Graphics.Shader.Instructions
             Operand srcA = GetSrcA(context);
             Operand srcB = GetSrcB(context);
 
-            Operand p0Res = GetIntComparison(context, cmpOp, srcA, srcB, isSigned);
+            Operand p0Res = GetIntComparison(context, cmpOp, srcA, srcB, isSigned, op.Extended);
 
             Operand p1Res = context.BitwiseNot(p0Res);
 
@@ -799,6 +792,84 @@ namespace Ryujinx.Graphics.Shader.Instructions
             context.Copy(GetDest(context), res);
         }
 
+        private static Operand GetIntComparison(
+            EmitterContext   context,
+            IntegerCondition cond,
+            Operand          srcA,
+            Operand          srcB,
+            bool             isSigned,
+            bool             extended)
+        {
+            return extended
+                ? GetIntComparisonExtended(context, cond, srcA, srcB, isSigned)
+                : GetIntComparison        (context, cond, srcA, srcB, isSigned);
+        }
+
+        private static Operand GetIntComparisonExtended(
+            EmitterContext   context,
+            IntegerCondition cond,
+            Operand          srcA,
+            Operand          srcB,
+            bool             isSigned)
+        {
+            Operand res;
+
+            if (cond == IntegerCondition.Always)
+            {
+                res = Const(IrConsts.True);
+            }
+            else if (cond == IntegerCondition.Never)
+            {
+                res = Const(IrConsts.False);
+            }
+            else
+            {
+                res = context.ISubtract(srcA, srcB);
+                res = context.IAdd(res, context.BitwiseNot(GetCF()));
+
+                switch (cond)
+                {
+                    case Decoders.IntegerCondition.Equal: // r = xh == yh && xl == yl
+                        res = context.BitwiseAnd(context.ICompareEqual(srcA, srcB), GetZF());
+                        break;
+                    case Decoders.IntegerCondition.Less: // r = xh < yh || (xh == yh && xl < yl)
+                        Operand notC = context.BitwiseNot(GetCF());
+                        Operand prevLt = context.BitwiseAnd(context.ICompareEqual(srcA, srcB), notC);
+                        res = isSigned
+                            ? context.BitwiseOr(context.ICompareLess(srcA, srcB), prevLt)
+                            : context.BitwiseOr(context.ICompareLessUnsigned(srcA, srcB), prevLt);
+                        break;
+                    case Decoders.IntegerCondition.LessOrEqual: // r = xh < yh || (xh == yh && xl <= yl)
+                        Operand zOrNotC = context.BitwiseOr(GetZF(), context.BitwiseNot(GetCF()));
+                        Operand prevLe = context.BitwiseAnd(context.ICompareEqual(srcA, srcB), zOrNotC);
+                        res = isSigned
+                            ? context.BitwiseOr(context.ICompareLess(srcA, srcB), prevLe)
+                            : context.BitwiseOr(context.ICompareLessUnsigned(srcA, srcB), prevLe);
+                        break;
+                    case Decoders.IntegerCondition.Greater: // r = xh > yh || (xh == yh && xl > yl)
+                        Operand notZAndC = context.BitwiseAnd(context.BitwiseNot(GetZF()), GetCF());
+                        Operand prevGt = context.BitwiseAnd(context.ICompareEqual(srcA, srcB), notZAndC);
+                        res = isSigned
+                            ? context.BitwiseOr(context.ICompareGreater(srcA, srcB), prevGt)
+                            : context.BitwiseOr(context.ICompareGreaterUnsigned(srcA, srcB), prevGt);
+                        break;
+                    case Decoders.IntegerCondition.GreaterOrEqual: // r = xh > yh || (xh == yh && xl >= yl)
+                        Operand prevGe = context.BitwiseAnd(context.ICompareEqual(srcA, srcB), GetCF());
+                        res = isSigned
+                            ? context.BitwiseOr(context.ICompareGreater(srcA, srcB), prevGe)
+                            : context.BitwiseOr(context.ICompareGreaterUnsigned(srcA, srcB), prevGe);
+                        break;
+                    case Decoders.IntegerCondition.NotEqual: // r = xh != yh || xl != yl
+                        context.BitwiseOr(context.ICompareNotEqual(srcA, srcB), context.BitwiseNot(GetZF()));
+                        break;
+                    default:
+                        throw new InvalidOperationException($"Unexpected condition \"{cond}\".");
+                }
+            }
+
+            return res;
+        }
+
         private static Operand GetIntComparison(
             EmitterContext   context,
             IntegerCondition cond,
@@ -879,20 +950,14 @@ namespace Ryujinx.Graphics.Shader.Instructions
             Operand        srcA,
             Operand        srcB,
             bool           setCC,
-            bool           extended,
-            bool           isSubtraction = false)
+            bool           extended)
         {
             if (!setCC)
             {
                 return;
             }
 
-            if (!extended || isSubtraction)
-            {
-                // C = d < a
-                context.Copy(GetCF(), context.ICompareLessUnsigned(res, srcA));
-            }
-            else
+            if (extended)
             {
                 // C = (d == a && CIn) || d < a
                 Operand tempC0 = context.ICompareEqual       (res, srcA);
@@ -902,6 +967,11 @@ namespace Ryujinx.Graphics.Shader.Instructions
 
                 context.Copy(GetCF(), context.BitwiseOr(tempC0, tempC1));
             }
+            else
+            {
+                // C = d < a
+                context.Copy(GetCF(), context.ICompareLessUnsigned(res, srcA));
+            }
 
             // V = (d ^ a) & ~(a ^ b) < 0
             Operand tempV0 = context.BitwiseExclusiveOr(res,  srcA);
diff --git a/Ryujinx.Graphics.Shader/Instructions/InstEmitMemory.cs b/Ryujinx.Graphics.Shader/Instructions/InstEmitMemory.cs
index 63f9cff787..81d5c7af2e 100644
--- a/Ryujinx.Graphics.Shader/Instructions/InstEmitMemory.cs
+++ b/Ryujinx.Graphics.Shader/Instructions/InstEmitMemory.cs
@@ -501,7 +501,9 @@ namespace Ryujinx.Graphics.Shader.Instructions
 
             for (int index = 0; index < count; index++)
             {
-                Register rd = new Register(op.Rd.Index + index, RegisterType.Gpr);
+                bool isRz = op.Rd.IsRZ;
+
+                Register rd = new Register(isRz ? op.Rd.Index : op.Rd.Index + index, RegisterType.Gpr);
 
                 Operand value = Register(rd);
 
@@ -525,11 +527,6 @@ namespace Ryujinx.Graphics.Shader.Instructions
                     case MemoryRegion.Local:  context.StoreLocal (offset, value); break;
                     case MemoryRegion.Shared: context.StoreShared(offset, value); break;
                 }
-
-                if (rd.IsRZ)
-                {
-                    break;
-                }
             }
         }
 
@@ -547,7 +544,9 @@ namespace Ryujinx.Graphics.Shader.Instructions
 
             for (int index = 0; index < count; index++)
             {
-                Register rd = new Register(op.Rd.Index + index, RegisterType.Gpr);
+                bool isRz = op.Rd.IsRZ;
+
+                Register rd = new Register(isRz ? op.Rd.Index : op.Rd.Index + index, RegisterType.Gpr);
 
                 Operand value = Register(rd);
 
@@ -559,11 +558,6 @@ namespace Ryujinx.Graphics.Shader.Instructions
                 }
 
                 context.StoreGlobal(context.IAdd(addrLow, Const(index * 4)), addrHigh, value);
-
-                if (rd.IsRZ)
-                {
-                    break;
-                }
             }
         }
 
diff --git a/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs b/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs
index 40f3370fa4..b2418c2e14 100644
--- a/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs
+++ b/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs
@@ -406,11 +406,41 @@ namespace Ryujinx.Graphics.Shader.Translation
             return context.Add(Instruction.CompareEqual, Local(), a, b);
         }
 
+        public static Operand ICompareGreater(this EmitterContext context, Operand a, Operand b)
+        {
+            return context.Add(Instruction.CompareGreater, Local(), a, b);
+        }
+
+        public static Operand ICompareGreaterOrEqual(this EmitterContext context, Operand a, Operand b)
+        {
+            return context.Add(Instruction.CompareGreaterOrEqual, Local(), a, b);
+        }
+
+        public static Operand ICompareGreaterOrEqualUnsigned(this EmitterContext context, Operand a, Operand b)
+        {
+            return context.Add(Instruction.CompareGreaterOrEqualU32, Local(), a, b);
+        }
+
+        public static Operand ICompareGreaterUnsigned(this EmitterContext context, Operand a, Operand b)
+        {
+            return context.Add(Instruction.CompareGreaterU32, Local(), a, b);
+        }
+
         public static Operand ICompareLess(this EmitterContext context, Operand a, Operand b)
         {
             return context.Add(Instruction.CompareLess, Local(), a, b);
         }
 
+        public static Operand ICompareLessOrEqual(this EmitterContext context, Operand a, Operand b)
+        {
+            return context.Add(Instruction.CompareLessOrEqual, Local(), a, b);
+        }
+
+        public static Operand ICompareLessOrEqualUnsigned(this EmitterContext context, Operand a, Operand b)
+        {
+            return context.Add(Instruction.CompareLessOrEqualU32, Local(), a, b);
+        }
+
         public static Operand ICompareLessUnsigned(this EmitterContext context, Operand a, Operand b)
         {
             return context.Add(Instruction.CompareLessU32, Local(), a, b);