From 20774dab14ca8362e716ce87f975be7ea77beead Mon Sep 17 00:00:00 2001
From: gdkchan <gab.dark.100@gmail.com>
Date: Fri, 17 Jul 2020 01:22:13 -0300
Subject: [PATCH] Improve kernel WaitSynchronization syscall implementation
 (#1362)

---
 .../HOS/Kernel/Common/KernelTransfer.cs       |  3 +-
 .../HOS/Kernel/SupervisorCall/Syscall.cs      | 83 +++++++++++++++----
 .../HOS/Kernel/Threading/KSynchronization.cs  |  3 +-
 Ryujinx.HLE/HOS/Kernel/Threading/KThread.cs   |  8 ++
 4 files changed, 81 insertions(+), 16 deletions(-)

diff --git a/Ryujinx.HLE/HOS/Kernel/Common/KernelTransfer.cs b/Ryujinx.HLE/HOS/Kernel/Common/KernelTransfer.cs
index 8b739f66e4..d57ca481d2 100644
--- a/Ryujinx.HLE/HOS/Kernel/Common/KernelTransfer.cs
+++ b/Ryujinx.HLE/HOS/Kernel/Common/KernelTransfer.cs
@@ -1,5 +1,6 @@
 using Ryujinx.Cpu;
 using Ryujinx.HLE.HOS.Kernel.Process;
+using System;
 
 namespace Ryujinx.HLE.HOS.Kernel.Common
 {
@@ -22,7 +23,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Common
             return false;
         }
 
-        public static bool UserToKernelInt32Array(KernelContext context, ulong address, int[] values)
+        public static bool UserToKernelInt32Array(KernelContext context, ulong address, Span<int> values)
         {
             KProcess currentProcess = context.Scheduler.GetCurrentProcess();
 
diff --git a/Ryujinx.HLE/HOS/Kernel/SupervisorCall/Syscall.cs b/Ryujinx.HLE/HOS/Kernel/SupervisorCall/Syscall.cs
index fba22fc115..b6d2caf277 100644
--- a/Ryujinx.HLE/HOS/Kernel/SupervisorCall/Syscall.cs
+++ b/Ryujinx.HLE/HOS/Kernel/SupervisorCall/Syscall.cs
@@ -8,6 +8,7 @@ using Ryujinx.HLE.HOS.Kernel.Ipc;
 using Ryujinx.HLE.HOS.Kernel.Memory;
 using Ryujinx.HLE.HOS.Kernel.Process;
 using Ryujinx.HLE.HOS.Kernel.Threading;
+using System;
 using System.Collections.Generic;
 using System.Threading;
 
@@ -2139,30 +2140,84 @@ namespace Ryujinx.HLE.HOS.Kernel.SupervisorCall
         {
             handleIndex = 0;
 
-            if ((uint)handlesCount > 0x40)
+            if ((uint)handlesCount > KThread.MaxWaitSyncObjects)
             {
                 return KernelResult.MaximumExceeded;
             }
 
-            List<KSynchronizationObject> syncObjs = new List<KSynchronizationObject>();
+            KThread currentThread = _context.Scheduler.GetCurrentThread();
 
-            KProcess process = _context.Scheduler.GetCurrentProcess();
+            var syncObjs = new Span<KSynchronizationObject>(currentThread.WaitSyncObjects).Slice(0, handlesCount);
+
+            if (handlesCount != 0)
+            {
+                KProcess currentProcess = _context.Scheduler.GetCurrentProcess();
+
+                if (currentProcess.MemoryManager.AddrSpaceStart > handlesPtr)
+                {
+                    return KernelResult.UserCopyFailed;
+                }
+
+                long handlesSize = handlesCount * 4;
+
+                if (handlesPtr + (ulong)handlesSize <= handlesPtr)
+                {
+                    return KernelResult.UserCopyFailed;
+                }
+
+                if (handlesPtr + (ulong)handlesSize - 1 > currentProcess.MemoryManager.AddrSpaceEnd - 1)
+                {
+                    return KernelResult.UserCopyFailed;
+                }
+
+                Span<int> handles = new Span<int>(currentThread.WaitSyncHandles).Slice(0, handlesCount);
+
+                if (!KernelTransfer.UserToKernelInt32Array(_context, handlesPtr, handles))
+                {
+                    return KernelResult.UserCopyFailed;
+                }
+
+                int processedHandles = 0;
+
+                for (; processedHandles < handlesCount; processedHandles++)
+                {
+                    KSynchronizationObject syncObj = currentProcess.HandleTable.GetObject<KSynchronizationObject>(handles[processedHandles]);
+
+                    if (syncObj == null)
+                    {
+                        break;
+                    }
+
+                    syncObjs[processedHandles] = syncObj;
+
+                    syncObj.IncrementReferenceCount();
+                }
+
+                if (processedHandles != handlesCount)
+                {
+                    // One or more handles are invalid.
+                    for (int index = 0; index < processedHandles; index++)
+                    {
+                        currentThread.WaitSyncObjects[index].DecrementReferenceCount();
+                    }
+
+                    return KernelResult.InvalidHandle;
+                }
+            }
+
+            KernelResult result = _context.Synchronization.WaitFor(syncObjs, timeout, out handleIndex);
+
+            if (result == KernelResult.PortRemoteClosed)
+            {
+                result = KernelResult.Success;
+            }
 
             for (int index = 0; index < handlesCount; index++)
             {
-                int handle = process.CpuMemory.Read<int>(handlesPtr + (ulong)index * 4);
-
-                KSynchronizationObject syncObj = process.HandleTable.GetObject<KSynchronizationObject>(handle);
-
-                if (syncObj == null)
-                {
-                    break;
-                }
-
-                syncObjs.Add(syncObj);
+                currentThread.WaitSyncObjects[index].DecrementReferenceCount();
             }
 
-            return _context.Synchronization.WaitFor(syncObjs.ToArray(), timeout, out handleIndex);
+            return result;
         }
 
         public KernelResult CancelSynchronization(int handle)
diff --git a/Ryujinx.HLE/HOS/Kernel/Threading/KSynchronization.cs b/Ryujinx.HLE/HOS/Kernel/Threading/KSynchronization.cs
index fa9b669ea3..22610b2205 100644
--- a/Ryujinx.HLE/HOS/Kernel/Threading/KSynchronization.cs
+++ b/Ryujinx.HLE/HOS/Kernel/Threading/KSynchronization.cs
@@ -1,4 +1,5 @@
 using Ryujinx.HLE.HOS.Kernel.Common;
+using System;
 using System.Collections.Generic;
 
 namespace Ryujinx.HLE.HOS.Kernel.Threading
@@ -12,7 +13,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
             _context = context;
         }
 
-        public KernelResult WaitFor(KSynchronizationObject[] syncObjs, long timeout, out int handleIndex)
+        public KernelResult WaitFor(Span<KSynchronizationObject> syncObjs, long timeout, out int handleIndex)
         {
             handleIndex = 0;
 
diff --git a/Ryujinx.HLE/HOS/Kernel/Threading/KThread.cs b/Ryujinx.HLE/HOS/Kernel/Threading/KThread.cs
index 754a1e530e..d4603178f2 100644
--- a/Ryujinx.HLE/HOS/Kernel/Threading/KThread.cs
+++ b/Ryujinx.HLE/HOS/Kernel/Threading/KThread.cs
@@ -12,6 +12,8 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
 {
     class KThread : KSynchronizationObject, IKFutureSchedulerObject
     {
+        public const int MaxWaitSyncObjects = 64;
+
         private int _hostThreadRunning;
 
         public Thread HostThread { get; private set; }
@@ -39,6 +41,9 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
         public ulong TlsAddress => _tlsAddress;
         public ulong TlsDramAddress { get; private set; }
 
+        public KSynchronizationObject[] WaitSyncObjects { get; }
+        public int[] WaitSyncHandles { get; }
+
         public long LastScheduledTime { get; set; }
 
         public LinkedListNode<KThread>[] SiblingsPerCore { get; private set; }
@@ -96,6 +101,9 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
             _scheduler      = KernelContext.Scheduler;
             _schedulingData = KernelContext.Scheduler.SchedulingData;
 
+            WaitSyncObjects = new KSynchronizationObject[MaxWaitSyncObjects];
+            WaitSyncHandles = new int[MaxWaitSyncObjects];
+
             SiblingsPerCore = new LinkedListNode<KThread>[KScheduler.CpuCoresCount];
 
             _mutexWaiters = new LinkedList<KThread>();