From 4bd1ad16f93e8decf790191868690c3bd3875ee0 Mon Sep 17 00:00:00 2001
From: LDj3SNuD <35856442+LDj3SNuD@users.noreply.github.com>
Date: Thu, 25 Mar 2021 23:33:32 +0100
Subject: [PATCH] Add Sqdmulh_Ve & Sqrdmulh_Ve instructions with tests. (#2139)

---
 ARMeilleure/Decoders/OpCodeTable.cs           |  4 ++
 .../Instructions/InstEmitSimdArithmetic.cs    | 30 ++++++---
 .../Instructions/InstEmitSimdHelper.cs        | 64 +++++++++++--------
 ARMeilleure/Instructions/InstName.cs          |  2 +
 Ryujinx.Tests/Cpu/CpuTestSimdRegElem.cs       | 60 +++++++++--------
 5 files changed, 96 insertions(+), 64 deletions(-)

diff --git a/ARMeilleure/Decoders/OpCodeTable.cs b/ARMeilleure/Decoders/OpCodeTable.cs
index 2eabcf15bf..028a537720 100644
--- a/ARMeilleure/Decoders/OpCodeTable.cs
+++ b/ARMeilleure/Decoders/OpCodeTable.cs
@@ -485,12 +485,16 @@ namespace ARMeilleure.Decoders
             SetA64("01011110101xxxxx101101xxxxxxxxxx", InstName.Sqdmulh_S,       InstEmit.Sqdmulh_S,       OpCodeSimdReg.Create);
             SetA64("0x001110011xxxxx101101xxxxxxxxxx", InstName.Sqdmulh_V,       InstEmit.Sqdmulh_V,       OpCodeSimdReg.Create);
             SetA64("0x001110101xxxxx101101xxxxxxxxxx", InstName.Sqdmulh_V,       InstEmit.Sqdmulh_V,       OpCodeSimdReg.Create);
+            SetA64("0x00111101xxxxxx1100x0xxxxxxxxxx", InstName.Sqdmulh_Ve,      InstEmit.Sqdmulh_Ve,      OpCodeSimdRegElem.Create);
+            SetA64("0x00111110xxxxxx1100x0xxxxxxxxxx", InstName.Sqdmulh_Ve,      InstEmit.Sqdmulh_Ve,      OpCodeSimdRegElem.Create);
             SetA64("01111110xx100000011110xxxxxxxxxx", InstName.Sqneg_S,         InstEmit.Sqneg_S,         OpCodeSimd.Create);
             SetA64("0>101110<<100000011110xxxxxxxxxx", InstName.Sqneg_V,         InstEmit.Sqneg_V,         OpCodeSimd.Create);
             SetA64("01111110011xxxxx101101xxxxxxxxxx", InstName.Sqrdmulh_S,      InstEmit.Sqrdmulh_S,      OpCodeSimdReg.Create);
             SetA64("01111110101xxxxx101101xxxxxxxxxx", InstName.Sqrdmulh_S,      InstEmit.Sqrdmulh_S,      OpCodeSimdReg.Create);
             SetA64("0x101110011xxxxx101101xxxxxxxxxx", InstName.Sqrdmulh_V,      InstEmit.Sqrdmulh_V,      OpCodeSimdReg.Create);
             SetA64("0x101110101xxxxx101101xxxxxxxxxx", InstName.Sqrdmulh_V,      InstEmit.Sqrdmulh_V,      OpCodeSimdReg.Create);
+            SetA64("0x00111101xxxxxx1101x0xxxxxxxxxx", InstName.Sqrdmulh_Ve,     InstEmit.Sqrdmulh_Ve,     OpCodeSimdRegElem.Create);
+            SetA64("0x00111110xxxxxx1101x0xxxxxxxxxx", InstName.Sqrdmulh_Ve,     InstEmit.Sqrdmulh_Ve,     OpCodeSimdRegElem.Create);
             SetA64("0>001110<<1xxxxx010111xxxxxxxxxx", InstName.Sqrshl_V,        InstEmit.Sqrshl_V,        OpCodeSimdReg.Create);
             SetA64("0101111100>>>xxx100111xxxxxxxxxx", InstName.Sqrshrn_S,       InstEmit.Sqrshrn_S,       OpCodeSimdShImm.Create);
             SetA64("0x00111100>>>xxx100111xxxxxxxxxx", InstName.Sqrshrn_V,       InstEmit.Sqrshrn_V,       OpCodeSimdShImm.Create);
diff --git a/ARMeilleure/Instructions/InstEmitSimdArithmetic.cs b/ARMeilleure/Instructions/InstEmitSimdArithmetic.cs
index 9c35988294..eff6bf3599 100644
--- a/ARMeilleure/Instructions/InstEmitSimdArithmetic.cs
+++ b/ARMeilleure/Instructions/InstEmitSimdArithmetic.cs
@@ -2642,22 +2642,27 @@ namespace ARMeilleure.Instructions
 
         public static void Sqadd_S(ArmEmitterContext context)
         {
-            EmitScalarSaturatingBinaryOpSx(context, SaturatingFlags.Add);
+            EmitScalarSaturatingBinaryOpSx(context, flags: SaturatingFlags.Add);
         }
 
         public static void Sqadd_V(ArmEmitterContext context)
         {
-            EmitVectorSaturatingBinaryOpSx(context, SaturatingFlags.Add);
+            EmitVectorSaturatingBinaryOpSx(context, flags: SaturatingFlags.Add);
         }
 
         public static void Sqdmulh_S(ArmEmitterContext context)
         {
-            EmitSaturatingBinaryOp(context, (op1, op2) => EmitDoublingMultiplyHighHalf(context, op1, op2, round: false), SaturatingFlags.ScalarSx);
+            EmitScalarSaturatingBinaryOpSx(context, (op1, op2) => EmitDoublingMultiplyHighHalf(context, op1, op2, round: false));
         }
 
         public static void Sqdmulh_V(ArmEmitterContext context)
         {
-            EmitSaturatingBinaryOp(context, (op1, op2) => EmitDoublingMultiplyHighHalf(context, op1, op2, round: false), SaturatingFlags.VectorSx);
+            EmitVectorSaturatingBinaryOpSx(context, (op1, op2) => EmitDoublingMultiplyHighHalf(context, op1, op2, round: false));
+        }
+
+        public static void Sqdmulh_Ve(ArmEmitterContext context)
+        {
+            EmitVectorSaturatingBinaryOpByElemSx(context, (op1, op2) => EmitDoublingMultiplyHighHalf(context, op1, op2, round: false));
         }
 
         public static void Sqneg_S(ArmEmitterContext context)
@@ -2672,22 +2677,27 @@ namespace ARMeilleure.Instructions
 
         public static void Sqrdmulh_S(ArmEmitterContext context)
         {
-            EmitSaturatingBinaryOp(context, (op1, op2) => EmitDoublingMultiplyHighHalf(context, op1, op2, round: true), SaturatingFlags.ScalarSx);
+            EmitScalarSaturatingBinaryOpSx(context, (op1, op2) => EmitDoublingMultiplyHighHalf(context, op1, op2, round: true));
         }
 
         public static void Sqrdmulh_V(ArmEmitterContext context)
         {
-            EmitSaturatingBinaryOp(context, (op1, op2) => EmitDoublingMultiplyHighHalf(context, op1, op2, round: true), SaturatingFlags.VectorSx);
+            EmitVectorSaturatingBinaryOpSx(context, (op1, op2) => EmitDoublingMultiplyHighHalf(context, op1, op2, round: true));
+        }
+
+        public static void Sqrdmulh_Ve(ArmEmitterContext context)
+        {
+            EmitVectorSaturatingBinaryOpByElemSx(context, (op1, op2) => EmitDoublingMultiplyHighHalf(context, op1, op2, round: true));
         }
 
         public static void Sqsub_S(ArmEmitterContext context)
         {
-            EmitScalarSaturatingBinaryOpSx(context, SaturatingFlags.Sub);
+            EmitScalarSaturatingBinaryOpSx(context, flags: SaturatingFlags.Sub);
         }
 
         public static void Sqsub_V(ArmEmitterContext context)
         {
-            EmitVectorSaturatingBinaryOpSx(context, SaturatingFlags.Sub);
+            EmitVectorSaturatingBinaryOpSx(context, flags: SaturatingFlags.Sub);
         }
 
         public static void Sqxtn_S(ArmEmitterContext context)
@@ -2850,12 +2860,12 @@ namespace ARMeilleure.Instructions
 
         public static void Suqadd_S(ArmEmitterContext context)
         {
-            EmitScalarSaturatingBinaryOpSx(context, SaturatingFlags.Accumulate);
+            EmitScalarSaturatingBinaryOpSx(context, flags: SaturatingFlags.Accumulate);
         }
 
         public static void Suqadd_V(ArmEmitterContext context)
         {
-            EmitVectorSaturatingBinaryOpSx(context, SaturatingFlags.Accumulate);
+            EmitVectorSaturatingBinaryOpSx(context, flags: SaturatingFlags.Accumulate);
         }
 
         public static void Uaba_V(ArmEmitterContext context)
diff --git a/ARMeilleure/Instructions/InstEmitSimdHelper.cs b/ARMeilleure/Instructions/InstEmitSimdHelper.cs
index d362ad5e7a..576f4229e6 100644
--- a/ARMeilleure/Instructions/InstEmitSimdHelper.cs
+++ b/ARMeilleure/Instructions/InstEmitSimdHelper.cs
@@ -1302,32 +1302,29 @@ namespace ARMeilleure.Instructions
         [Flags]
         public enum SaturatingFlags
         {
-            Scalar = 1 << 0,
-            Signed = 1 << 1,
+            None = 0,
 
-            Add = 1 << 2,
-            Sub = 1 << 3,
+            ByElem = 1 << 0,
+            Scalar = 1 << 1,
+            Signed = 1 << 2,
 
-            Accumulate = 1 << 4,
+            Add = 1 << 3,
+            Sub = 1 << 4,
 
-            ScalarSx = Scalar | Signed,
-            ScalarZx = Scalar,
-
-            VectorSx = Signed,
-            VectorZx = 0
+            Accumulate = 1 << 5
         }
 
         public static void EmitScalarSaturatingUnaryOpSx(ArmEmitterContext context, Func1I emit)
         {
-            EmitSaturatingUnaryOpSx(context, emit, SaturatingFlags.ScalarSx);
+            EmitSaturatingUnaryOpSx(context, emit, SaturatingFlags.Scalar | SaturatingFlags.Signed);
         }
 
         public static void EmitVectorSaturatingUnaryOpSx(ArmEmitterContext context, Func1I emit)
         {
-            EmitSaturatingUnaryOpSx(context, emit, SaturatingFlags.VectorSx);
+            EmitSaturatingUnaryOpSx(context, emit, SaturatingFlags.Signed);
         }
 
-        private static void EmitSaturatingUnaryOpSx(ArmEmitterContext context, Func1I emit, SaturatingFlags flags)
+        public static void EmitSaturatingUnaryOpSx(ArmEmitterContext context, Func1I emit, SaturatingFlags flags)
         {
             OpCodeSimd op = (OpCodeSimd)context.CurrOp;
 
@@ -1357,24 +1354,29 @@ namespace ARMeilleure.Instructions
             context.Copy(GetVec(op.Rd), res);
         }
 
-        public static void EmitScalarSaturatingBinaryOpSx(ArmEmitterContext context, SaturatingFlags flags)
+        public static void EmitScalarSaturatingBinaryOpSx(ArmEmitterContext context, Func2I emit = null, SaturatingFlags flags = SaturatingFlags.None)
         {
-            EmitSaturatingBinaryOp(context, null, SaturatingFlags.ScalarSx | flags);
+            EmitSaturatingBinaryOp(context, emit, SaturatingFlags.Scalar | SaturatingFlags.Signed | flags);
         }
 
         public static void EmitScalarSaturatingBinaryOpZx(ArmEmitterContext context, SaturatingFlags flags)
         {
-            EmitSaturatingBinaryOp(context, null, SaturatingFlags.ScalarZx | flags);
+            EmitSaturatingBinaryOp(context, null, SaturatingFlags.Scalar | flags);
         }
 
-        public static void EmitVectorSaturatingBinaryOpSx(ArmEmitterContext context, SaturatingFlags flags)
+        public static void EmitVectorSaturatingBinaryOpSx(ArmEmitterContext context, Func2I emit = null, SaturatingFlags flags = SaturatingFlags.None)
         {
-            EmitSaturatingBinaryOp(context, null, SaturatingFlags.VectorSx | flags);
+            EmitSaturatingBinaryOp(context, emit, SaturatingFlags.Signed | flags);
         }
 
         public static void EmitVectorSaturatingBinaryOpZx(ArmEmitterContext context, SaturatingFlags flags)
         {
-            EmitSaturatingBinaryOp(context, null, SaturatingFlags.VectorZx | flags);
+            EmitSaturatingBinaryOp(context, null, flags);
+        }
+
+        public static void EmitVectorSaturatingBinaryOpByElemSx(ArmEmitterContext context, Func2I emit)
+        {
+            EmitSaturatingBinaryOp(context, emit, SaturatingFlags.ByElem | SaturatingFlags.Signed);
         }
 
         public static void EmitSaturatingBinaryOp(ArmEmitterContext context, Func2I emit, SaturatingFlags flags)
@@ -1383,6 +1385,7 @@ namespace ARMeilleure.Instructions
 
             Operand res = context.VectorZero();
 
+            bool byElem = (flags & SaturatingFlags.ByElem) != 0;
             bool scalar = (flags & SaturatingFlags.Scalar) != 0;
             bool signed = (flags & SaturatingFlags.Signed) != 0;
 
@@ -1395,13 +1398,11 @@ namespace ARMeilleure.Instructions
 
             if (add || sub)
             {
-                OpCodeSimdReg opReg = (OpCodeSimdReg)op;
-
                 for (int index = 0; index < elems; index++)
                 {
                     Operand de;
-                    Operand ne = EmitVectorExtract(context, opReg.Rn, index, op.Size, signed);
-                    Operand me = EmitVectorExtract(context, opReg.Rm, index, op.Size, signed);
+                    Operand ne = EmitVectorExtract(context, op.Rn, index, op.Size, signed);
+                    Operand me = EmitVectorExtract(context, ((OpCodeSimdReg)op).Rm, index, op.Size, signed);
 
                     if (op.Size <= 2)
                     {
@@ -1445,12 +1446,23 @@ namespace ARMeilleure.Instructions
             }
             else
             {
-                OpCodeSimdReg opReg = (OpCodeSimdReg)op;
+                Operand me = null;
+
+                if (byElem)
+                {
+                    OpCodeSimdRegElem opRegElem = (OpCodeSimdRegElem)op;
+
+                    me = EmitVectorExtract(context, opRegElem.Rm, opRegElem.Index, op.Size, signed);
+                }
 
                 for (int index = 0; index < elems; index++)
                 {
-                    Operand ne = EmitVectorExtract(context, opReg.Rn, index, op.Size, signed);
-                    Operand me = EmitVectorExtract(context, opReg.Rm, index, op.Size, signed);
+                    Operand ne = EmitVectorExtract(context, op.Rn, index, op.Size, signed);
+
+                    if (!byElem)
+                    {
+                        me = EmitVectorExtract(context, ((OpCodeSimdReg)op).Rm, index, op.Size, signed);
+                    }
 
                     Operand de = EmitSatQ(context, emit(ne, me), op.Size, true, signed);
 
diff --git a/ARMeilleure/Instructions/InstName.cs b/ARMeilleure/Instructions/InstName.cs
index 458ecf2ff6..fe7644a917 100644
--- a/ARMeilleure/Instructions/InstName.cs
+++ b/ARMeilleure/Instructions/InstName.cs
@@ -358,10 +358,12 @@ namespace ARMeilleure.Instructions
         Sqadd_V,
         Sqdmulh_S,
         Sqdmulh_V,
+        Sqdmulh_Ve,
         Sqneg_S,
         Sqneg_V,
         Sqrdmulh_S,
         Sqrdmulh_V,
+        Sqrdmulh_Ve,
         Sqrshl_V,
         Sqrshrn_S,
         Sqrshrn_V,
diff --git a/Ryujinx.Tests/Cpu/CpuTestSimdRegElem.cs b/Ryujinx.Tests/Cpu/CpuTestSimdRegElem.cs
index 23e0e36465..5d0a8f3f9d 100644
--- a/Ryujinx.Tests/Cpu/CpuTestSimdRegElem.cs
+++ b/Ryujinx.Tests/Cpu/CpuTestSimdRegElem.cs
@@ -26,23 +26,27 @@ namespace Ryujinx.Tests.Cpu
 #endregion
 
 #region "ValueSource (Opcodes)"
-        private static uint[] _Mla_Mls_Mul_Ve_4H_8H_()
+        private static uint[] _Mla_Mls_Mul_Sqdmulh_Sqrdmulh_Ve_4H_8H_()
         {
             return new uint[]
             {
-                0x2F400000u, // MLA V0.4H, V0.4H, V0.H[0]
-                0x2F404000u, // MLS V0.4H, V0.4H, V0.H[0]
-                0x0F408000u  // MUL V0.4H, V0.4H, V0.H[0]
+                0x2F400000u, // MLA      V0.4H, V0.4H, V0.H[0]
+                0x2F404000u, // MLS      V0.4H, V0.4H, V0.H[0]
+                0x0F408000u, // MUL      V0.4H, V0.4H, V0.H[0]
+                0x0F40C000u, // SQDMULH  V0.4H, V0.4H, V0.H[0]
+                0x0F40D000u  // SQRDMULH V0.4H, V0.4H, V0.H[0]
             };
         }
 
-        private static uint[] _Mla_Mls_Mul_Ve_2S_4S_()
+        private static uint[] _Mla_Mls_Mul_Sqdmulh_Sqrdmulh_Ve_2S_4S_()
         {
             return new uint[]
             {
-                0x2F800000u, // MLA V0.2S, V0.2S, V0.S[0]
-                0x2F804000u, // MLS V0.2S, V0.2S, V0.S[0]
-                0x0F808000u  // MUL V0.2S, V0.2S, V0.S[0]
+                0x2F800000u, // MLA      V0.2S, V0.2S, V0.S[0]
+                0x2F804000u, // MLS      V0.2S, V0.2S, V0.S[0]
+                0x0F808000u, // MUL      V0.2S, V0.2S, V0.S[0]
+                0x0F80C000u, // SQDMULH  V0.2S, V0.2S, V0.S[0]
+                0x0F80D000u  // SQRDMULH V0.2S, V0.2S, V0.S[0]
             };
         }
 
@@ -77,15 +81,15 @@ namespace Ryujinx.Tests.Cpu
         private const int RndCntIndex = 2;
 
         [Test, Pairwise]
-        public void Mla_Mls_Mul_Ve_4H_8H([ValueSource("_Mla_Mls_Mul_Ve_4H_8H_")] uint opcodes,
-                                         [Values(0u)]     uint rd,
-                                         [Values(1u, 0u)] uint rn,
-                                         [Values(2u, 0u)] uint rm,
-                                         [ValueSource("_4H_")] [Random(RndCnt)] ulong z,
-                                         [ValueSource("_4H_")] [Random(RndCnt)] ulong a,
-                                         [ValueSource("_4H_")] [Random(RndCnt)] ulong b,
-                                         [Values(0u, 7u)] [Random(1u, 6u, RndCntIndex)] uint index,
-                                         [Values(0b0u, 0b1u)] uint q) // <4H, 8H>
+        public void Mla_Mls_Mul_Sqdmulh_Sqrdmulh_Ve_4H_8H([ValueSource(nameof(_Mla_Mls_Mul_Sqdmulh_Sqrdmulh_Ve_4H_8H_))] uint opcodes,
+                                                          [Values(0u)]     uint rd,
+                                                          [Values(1u, 0u)] uint rn,
+                                                          [Values(2u, 0u)] uint rm,
+                                                          [ValueSource(nameof(_4H_))] [Random(RndCnt)] ulong z,
+                                                          [ValueSource(nameof(_4H_))] [Random(RndCnt)] ulong a,
+                                                          [ValueSource(nameof(_4H_))] [Random(RndCnt)] ulong b,
+                                                          [Values(0u, 7u)] [Random(1u, 6u, RndCntIndex)] uint index,
+                                                          [Values(0b0u, 0b1u)] uint q) // <4H, 8H>
         {
             uint h = (index >> 2) & 1;
             uint l = (index >> 1) & 1;
@@ -101,19 +105,19 @@ namespace Ryujinx.Tests.Cpu
 
             SingleOpcode(opcodes, v0: v0, v1: v1, v2: v2);
 
-            CompareAgainstUnicorn();
+            CompareAgainstUnicorn(fpsrMask: Fpsr.Qc);
         }
 
         [Test, Pairwise]
-        public void Mla_Mls_Mul_Ve_2S_4S([ValueSource("_Mla_Mls_Mul_Ve_2S_4S_")] uint opcodes,
-                                         [Values(0u)]     uint rd,
-                                         [Values(1u, 0u)] uint rn,
-                                         [Values(2u, 0u)] uint rm,
-                                         [ValueSource("_2S_")] [Random(RndCnt)] ulong z,
-                                         [ValueSource("_2S_")] [Random(RndCnt)] ulong a,
-                                         [ValueSource("_2S_")] [Random(RndCnt)] ulong b,
-                                         [Values(0u, 1u, 2u, 3u)] uint index,
-                                         [Values(0b0u, 0b1u)] uint q) // <2S, 4S>
+        public void Mla_Mls_Mul_Sqdmulh_Sqrdmulh_Ve_2S_4S([ValueSource(nameof(_Mla_Mls_Mul_Sqdmulh_Sqrdmulh_Ve_2S_4S_))] uint opcodes,
+                                                          [Values(0u)]     uint rd,
+                                                          [Values(1u, 0u)] uint rn,
+                                                          [Values(2u, 0u)] uint rm,
+                                                          [ValueSource(nameof(_2S_))] [Random(RndCnt)] ulong z,
+                                                          [ValueSource(nameof(_2S_))] [Random(RndCnt)] ulong a,
+                                                          [ValueSource(nameof(_2S_))] [Random(RndCnt)] ulong b,
+                                                          [Values(0u, 1u, 2u, 3u)] uint index,
+                                                          [Values(0b0u, 0b1u)] uint q) // <2S, 4S>
         {
             uint h = (index >> 1) & 1;
             uint l = index & 1;
@@ -128,7 +132,7 @@ namespace Ryujinx.Tests.Cpu
 
             SingleOpcode(opcodes, v0: v0, v1: v1, v2: v2);
 
-            CompareAgainstUnicorn();
+            CompareAgainstUnicorn(fpsrMask: Fpsr.Qc);
         }
 
         [Test, Pairwise]