From b94dc01d4357c608d60383e85f709312294fd833 Mon Sep 17 00:00:00 2001
From: Mary <me@thog.eu>
Date: Wed, 5 May 2021 23:44:26 +0200
Subject: [PATCH] SM instance & TIPC fixes (#2241)

This PR addresses the following issues:
- SM was previously instancied once and reused on all sessions. This
  could cause inconsistency on the service initialization.
- TIPC replies were not matching what is generated on hardware.
---
 Ryujinx.HLE/HOS/Horizon.cs                    |  1 +
 Ryujinx.HLE/HOS/Ipc/IpcMessage.cs             | 48 +++++++++++++++++++
 Ryujinx.HLE/HOS/Services/ServerBase.cs        | 27 +++++++----
 Ryujinx.HLE/HOS/Services/Sm/IUserInterface.cs | 24 +++++++---
 4 files changed, 85 insertions(+), 15 deletions(-)

diff --git a/Ryujinx.HLE/HOS/Horizon.cs b/Ryujinx.HLE/HOS/Horizon.cs
index fa88d775e4..c240d1351c 100644
--- a/Ryujinx.HLE/HOS/Horizon.cs
+++ b/Ryujinx.HLE/HOS/Horizon.cs
@@ -274,6 +274,7 @@ namespace Ryujinx.HLE.HOS
         public void InitializeServices()
         {
             IUserInterface sm = new IUserInterface(KernelContext);
+            sm.TrySetServer(new ServerBase(KernelContext, "SmServer") { SmObjectFactory = () => new IUserInterface(KernelContext) });
 
             // Wait until SM server thread is done with initialization,
             // only then doing connections to SM is safe.
diff --git a/Ryujinx.HLE/HOS/Ipc/IpcMessage.cs b/Ryujinx.HLE/HOS/Ipc/IpcMessage.cs
index c99ad622b9..e5b9bf046b 100644
--- a/Ryujinx.HLE/HOS/Ipc/IpcMessage.cs
+++ b/Ryujinx.HLE/HOS/Ipc/IpcMessage.cs
@@ -1,4 +1,5 @@
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.IO;
 
 namespace Ryujinx.HLE.HOS.Ipc
@@ -185,6 +186,53 @@ namespace Ryujinx.HLE.HOS.Ipc
             }
         }
 
+        public byte[] GetBytesTipc()
+        {
+            Debug.Assert(PtrBuff.Count == 0);
+
+            using (MemoryStream ms = new MemoryStream())
+            {
+                BinaryWriter writer = new BinaryWriter(ms);
+
+                int word0;
+                int word1;
+
+                word0 = (int)Type;
+                word0 |= (SendBuff.Count & 0xf) << 20;
+                word0 |= (ReceiveBuff.Count & 0xf) << 24;
+                word0 |= (ExchangeBuff.Count & 0xf) << 28;
+
+                byte[] handleData = new byte[0];
+
+                if (HandleDesc != null)
+                {
+                    handleData = HandleDesc.GetBytes();
+                }
+
+                int dataLength = RawData?.Length ?? 0;
+
+                dataLength = ((dataLength + 3) & ~3) / 4;
+
+                word1 = (dataLength & 0x3ff);
+
+                if (HandleDesc != null)
+                {
+                    word1 |= 1 << 31;
+                }
+
+                writer.Write(word0);
+                writer.Write(word1);
+                writer.Write(handleData);
+
+                if (RawData != null)
+                {
+                    writer.Write(RawData);
+                }
+
+                return ms.ToArray();
+            }
+        }
+
         private long GetPadSize16(long position)
         {
             if ((position & 0xf) != 0)
diff --git a/Ryujinx.HLE/HOS/Services/ServerBase.cs b/Ryujinx.HLE/HOS/Services/ServerBase.cs
index f957a624b6..5b9834dcb1 100644
--- a/Ryujinx.HLE/HOS/Services/ServerBase.cs
+++ b/Ryujinx.HLE/HOS/Services/ServerBase.cs
@@ -35,10 +35,10 @@ namespace Ryujinx.HLE.HOS.Services
         private readonly List<int> _sessionHandles = new List<int>();
         private readonly List<int> _portHandles = new List<int>();
         private readonly Dictionary<int, IpcService> _sessions = new Dictionary<int, IpcService>();
-        private readonly Dictionary<int, IpcService> _ports = new Dictionary<int, IpcService>();
+        private readonly Dictionary<int, Func<IpcService>> _ports = new Dictionary<int, Func<IpcService>>();
 
         public ManualResetEvent InitDone { get; }
-        public IpcService SmObject { get; set; }
+        public Func<IpcService> SmObjectFactory { get; set; }
         public string Name { get; }
 
         public ServerBase(KernelContext context, string name)
@@ -58,10 +58,10 @@ namespace Ryujinx.HLE.HOS.Services
             KernelStatic.StartInitialProcess(context, creationInfo, DefaultCapabilities, 44, ServerLoop);
         }
 
-        private void AddPort(int serverPortHandle, IpcService obj)
+        private void AddPort(int serverPortHandle, Func<IpcService> objectFactory)
         {
             _portHandles.Add(serverPortHandle);
-            _ports.Add(serverPortHandle, obj);
+            _ports.Add(serverPortHandle, objectFactory);
         }
 
         public void AddSessionObj(KServerSession serverSession, IpcService obj)
@@ -80,11 +80,11 @@ namespace Ryujinx.HLE.HOS.Services
         {
             _selfProcess = KernelStatic.GetCurrentProcess();
 
-            if (SmObject != null)
+            if (SmObjectFactory != null)
             {
                 _context.Syscall.ManageNamedPort("sm:", 50, out int serverPortHandle);
 
-                AddPort(serverPortHandle, SmObject);
+                AddPort(serverPortHandle, SmObjectFactory);
 
                 InitDone.Set();
             }
@@ -141,7 +141,9 @@ namespace Ryujinx.HLE.HOS.Services
                         // We got a new connection, accept the session to allow servicing future requests.
                         if (_context.Syscall.AcceptSession(handles[signaledIndex], out int serverSessionHandle) == KernelResult.Success)
                         {
-                            AddSessionObj(serverSessionHandle, _ports[handles[signaledIndex]]);
+                            IpcService obj = _ports[handles[signaledIndex]].Invoke();
+
+                            AddSessionObj(serverSessionHandle, obj);
                         }
                     }
 
@@ -191,6 +193,7 @@ namespace Ryujinx.HLE.HOS.Services
             }
 
             bool shouldReply = true;
+            bool isTipcCommunication = false;
 
             using (MemoryStream raw = new MemoryStream(request.RawData))
             {
@@ -269,6 +272,8 @@ namespace Ryujinx.HLE.HOS.Services
                 // If the type is past 0xF, we are using TIPC
                 else if (request.Type > IpcMessageType.TipcCloseSession)
                 {
+                    isTipcCommunication = true;
+
                     // Response type is always the same as request on TIPC.
                     response.Type = request.Type;
 
@@ -290,13 +295,19 @@ namespace Ryujinx.HLE.HOS.Services
 
                         response.RawData = resMs.ToArray();
                     }
+
+                    process.CpuMemory.Write(messagePtr, response.GetBytesTipc());
                 }
                 else
                 {
                     throw new NotImplementedException(request.Type.ToString());
                 }
 
-                process.CpuMemory.Write(messagePtr, response.GetBytes((long)messagePtr, recvListAddr | ((ulong)PointerBufferSize << 48)));
+                if (!isTipcCommunication)
+                {
+                    process.CpuMemory.Write(messagePtr, response.GetBytes((long)messagePtr, recvListAddr | ((ulong)PointerBufferSize << 48)));
+                }
+
                 return shouldReply;
             }
         }
diff --git a/Ryujinx.HLE/HOS/Services/Sm/IUserInterface.cs b/Ryujinx.HLE/HOS/Services/Sm/IUserInterface.cs
index d5dabb2d01..9a0ccbc35e 100644
--- a/Ryujinx.HLE/HOS/Services/Sm/IUserInterface.cs
+++ b/Ryujinx.HLE/HOS/Services/Sm/IUserInterface.cs
@@ -15,15 +15,20 @@ namespace Ryujinx.HLE.HOS.Services.Sm
 {
     class IUserInterface : IpcService
     {
-        private Dictionary<string, Type> _services;
+        private static Dictionary<string, Type> _services;
 
-        private readonly ConcurrentDictionary<string, KPort> _registeredServices;
+        private static readonly ConcurrentDictionary<string, KPort> _registeredServices;
 
         private readonly ServerBase _commonServer;
 
         private bool _isInitialized;
 
         public IUserInterface(KernelContext context)
+        {
+            _commonServer = new ServerBase(context, "CommonServer");
+        }
+
+        static IUserInterface()
         {
             _registeredServices = new ConcurrentDictionary<string, KPort>();
 
@@ -31,10 +36,6 @@ namespace Ryujinx.HLE.HOS.Services.Sm
                 .SelectMany(type => type.GetCustomAttributes(typeof(ServiceAttribute), true)
                 .Select(service => (((ServiceAttribute)service).Name, type)))
                 .ToDictionary(service => service.Name, service => service.type);
-
-            TrySetServer(new ServerBase(context, "SmServer") { SmObject = this });
-
-            _commonServer = new ServerBase(context, "CommonServer");
         }
 
         [CommandHipc(0)]
@@ -47,9 +48,16 @@ namespace Ryujinx.HLE.HOS.Services.Sm
             return ResultCode.Success;
         }
 
-        [CommandHipc(1)]
         [CommandTipc(1)] // 12.0.0+
         // GetService(ServiceName name) -> handle<move, session>
+        public ResultCode GetServiceTipc(ServiceCtx context)
+        {
+            context.Response.HandleDesc = IpcHandleDesc.MakeMove(0);
+
+            return GetService(context);
+        }
+
+        [CommandHipc(1)]
         public ResultCode GetService(ServiceCtx context)
         {
             if (!_isInitialized)
@@ -142,6 +150,8 @@ namespace Ryujinx.HLE.HOS.Services.Sm
         {
             if (!_isInitialized)
             {
+                context.Response.HandleDesc = IpcHandleDesc.MakeMove(0);
+
                 return ResultCode.NotInitialized;
             }