From bb2f9df0a1d5e7cbd333c39cd485a42a19a772dc Mon Sep 17 00:00:00 2001
From: merry <git@mary.rs>
Date: Fri, 11 Mar 2022 02:16:32 +0000
Subject: [PATCH] KThread: Fix GetPsr mask (#3180)

* ExecutionContext: GetPstate / SetPstate

* Put it in NativeContext

* KThread: Fix GetPsr mask

* ExecutionContext: Turn methods into Pstate property

* Address nit
---
 ARMeilleure/State/ExecutionContext.cs       |  6 ++++++
 ARMeilleure/State/NativeContext.cs          | 19 +++++++++++++++++++
 Ryujinx.HLE/HOS/Kernel/Threading/KThread.cs |  7 ++-----
 Ryujinx.Tests/Cpu/CpuTest32.cs              |  5 +----
 4 files changed, 28 insertions(+), 9 deletions(-)

diff --git a/ARMeilleure/State/ExecutionContext.cs b/ARMeilleure/State/ExecutionContext.cs
index a6f74cd0ed..8309864f4d 100644
--- a/ARMeilleure/State/ExecutionContext.cs
+++ b/ARMeilleure/State/ExecutionContext.cs
@@ -43,6 +43,12 @@ namespace ARMeilleure.State
         public long TpidrEl0 { get; set; }
         public long Tpidr    { get; set; }
 
+        public uint Pstate
+        {
+            get => _nativeContext.GetPstate();
+            set => _nativeContext.SetPstate(value);
+        }
+
         public FPCR Fpcr { get; set; }
         public FPSR Fpsr { get; set; }
         public FPCR StandardFpcrValue => (Fpcr & (FPCR.Ahp)) | FPCR.Dn | FPCR.Fz;
diff --git a/ARMeilleure/State/NativeContext.cs b/ARMeilleure/State/NativeContext.cs
index 962783f5e2..f911f76266 100644
--- a/ARMeilleure/State/NativeContext.cs
+++ b/ARMeilleure/State/NativeContext.cs
@@ -95,6 +95,25 @@ namespace ARMeilleure.State
             GetStorage().Flags[(int)flag] = value ? 1u : 0u;
         }
 
+        public unsafe uint GetPstate()
+        {
+            uint value = 0;
+            for (int flag = 0; flag < RegisterConsts.FlagsCount; flag++)
+            {
+                value |= GetStorage().Flags[flag] != 0 ? 1u << flag : 0u;
+            }
+            return value;
+        }
+
+        public unsafe void SetPstate(uint value)
+        {
+            for (int flag = 0; flag < RegisterConsts.FlagsCount; flag++)
+            {
+                uint bit = 1u << flag;
+                GetStorage().Flags[flag] = (value & bit) == bit ? 1u : 0u;
+            }
+        }
+
         public unsafe bool GetFPStateFlag(FPState flag)
         {
             if ((uint)flag >= RegisterConsts.FpFlagsCount)
diff --git a/Ryujinx.HLE/HOS/Kernel/Threading/KThread.cs b/Ryujinx.HLE/HOS/Kernel/Threading/KThread.cs
index 60f5e1a87c..ee701a6903 100644
--- a/Ryujinx.HLE/HOS/Kernel/Threading/KThread.cs
+++ b/Ryujinx.HLE/HOS/Kernel/Threading/KThread.cs
@@ -658,10 +658,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
 
         private static uint GetPsr(ARMeilleure.State.ExecutionContext context)
         {
-            return (context.GetPstateFlag(ARMeilleure.State.PState.NFlag) ? (1U << (int)ARMeilleure.State.PState.NFlag) : 0U) |
-                   (context.GetPstateFlag(ARMeilleure.State.PState.ZFlag) ? (1U << (int)ARMeilleure.State.PState.ZFlag) : 0U) |
-                   (context.GetPstateFlag(ARMeilleure.State.PState.CFlag) ? (1U << (int)ARMeilleure.State.PState.CFlag) : 0U) |
-                   (context.GetPstateFlag(ARMeilleure.State.PState.VFlag) ? (1U << (int)ARMeilleure.State.PState.VFlag) : 0U);
+            return context.Pstate & 0xFF0FFE20;
         }
 
         private ThreadContext GetCurrentContext()
@@ -1371,7 +1368,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
 
             PreferredCore = _originalPreferredCore;
             AffinityMask = _originalAffinityMask;
-            
+
             if (AffinityMask != affinityMask)
             {
                 if ((AffinityMask & 1UL << ActiveCore) != 0)
diff --git a/Ryujinx.Tests/Cpu/CpuTest32.cs b/Ryujinx.Tests/Cpu/CpuTest32.cs
index 384cd4b154..1ffea0b67d 100644
--- a/Ryujinx.Tests/Cpu/CpuTest32.cs
+++ b/Ryujinx.Tests/Cpu/CpuTest32.cs
@@ -283,10 +283,7 @@ namespace Ryujinx.Tests.Cpu
             }
 
             uint finalCpsr = test.FinalRegs[15];
-            for (int i = 0; i < 32; i++)
-            {
-                Assert.That(GetContext().GetPstateFlag((PState)i), Is.EqualTo((finalCpsr & (1u << i)) != 0));
-            }
+            Assert.That(GetContext().Pstate, Is.EqualTo(finalCpsr));
         }
 
         protected void SetWorkingMemory(uint offset, byte[] data)