From f3835dc78bfc845786a38c189929ac8838960018 Mon Sep 17 00:00:00 2001
From: Mary-nyan <mary@mary.zone>
Date: Wed, 7 Sep 2022 22:37:15 +0200
Subject: [PATCH] bsd: implement SendMMsg and RecvMMsg (#3660)

* bsd: implement sendmmsg and recvmmsg

* Fix wrong increment of vlen
---
 .../HOS/Services/Sockets/Bsd/IClient.cs       |  85 +++++++
 .../HOS/Services/Sockets/Bsd/ISocket.cs       |   5 +
 .../Sockets/Bsd/Impl/ManagedSocket.cs         | 162 +++++++++++++
 .../Services/Sockets/Bsd/Types/BsdMMsgHdr.cs  |  56 +++++
 .../Services/Sockets/Bsd/Types/BsdMsgHdr.cs   | 212 ++++++++++++++++++
 .../HOS/Services/Sockets/Bsd/Types/TimeVal.cs |   8 +
 6 files changed, 528 insertions(+)
 create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMMsgHdr.cs
 create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMsgHdr.cs
 create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/TimeVal.cs

diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs
index 654844dc01..98a993119c 100644
--- a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs
+++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs
@@ -886,6 +886,91 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
             return WriteBsdResult(context, newSockFd, errno);
         }
 
+
+        [CommandHipc(29)] // 7.0.0+
+        // RecvMMsg(u32 fd, u32 vlen, u32 flags, u32 reserved, nn::socket::TimeVal timeout) -> (i32 ret, u32 bsd_errno, buffer<bytes, 6> message);
+        public ResultCode RecvMMsg(ServiceCtx context)
+        {
+            int            socketFd    = context.RequestData.ReadInt32();
+            int            vlen        = context.RequestData.ReadInt32();
+            BsdSocketFlags socketFlags = (BsdSocketFlags)context.RequestData.ReadInt32();
+            uint           reserved    = context.RequestData.ReadUInt32();
+            TimeVal        timeout     = context.RequestData.ReadStruct<TimeVal>();
+
+            ulong receivePosition = context.Request.ReceiveBuff[0].Position;
+            ulong receiveLength = context.Request.ReceiveBuff[0].Size;
+
+            WritableRegion receiveRegion = context.Memory.GetWritableRegion(receivePosition, (int)receiveLength);
+
+            LinuxError errno  = LinuxError.EBADF;
+            ISocket    socket = _context.RetrieveSocket(socketFd);
+            int        result = -1;
+
+            if (socket != null)
+            {
+                errno = BsdMMsgHdr.Deserialize(out BsdMMsgHdr message, receiveRegion.Memory.Span, vlen);
+
+                if (errno == LinuxError.SUCCESS)
+                {
+                    errno = socket.RecvMMsg(out result, message, socketFlags, timeout);
+
+                    if (errno == LinuxError.SUCCESS)
+                    {
+                        errno = BsdMMsgHdr.Serialize(receiveRegion.Memory.Span, message);
+                    }
+                }
+            }
+
+            if (errno == LinuxError.SUCCESS)
+            {
+                SetResultErrno(socket, result);
+                receiveRegion.Dispose();
+            }
+
+            return WriteBsdResult(context, result, errno);
+        }
+
+        [CommandHipc(30)] // 7.0.0+
+        // SendMMsg(u32 fd, u32 vlen, u32 flags) -> (i32 ret, u32 bsd_errno, buffer<bytes, 6> message);
+        public ResultCode SendMMsg(ServiceCtx context)
+        {
+            int            socketFd    = context.RequestData.ReadInt32();
+            int            vlen        = context.RequestData.ReadInt32();
+            BsdSocketFlags socketFlags = (BsdSocketFlags)context.RequestData.ReadInt32();
+
+            ulong receivePosition = context.Request.ReceiveBuff[0].Position;
+            ulong receiveLength = context.Request.ReceiveBuff[0].Size;
+
+            WritableRegion receiveRegion = context.Memory.GetWritableRegion(receivePosition, (int)receiveLength);
+
+            LinuxError errno  = LinuxError.EBADF;
+            ISocket    socket = _context.RetrieveSocket(socketFd);
+            int        result = -1;
+
+            if (socket != null)
+            {
+                errno = BsdMMsgHdr.Deserialize(out BsdMMsgHdr message, receiveRegion.Memory.Span, vlen);
+
+                if (errno == LinuxError.SUCCESS)
+                {
+                    errno = socket.SendMMsg(out result, message, socketFlags);
+
+                    if (errno == LinuxError.SUCCESS)
+                    {
+                        errno = BsdMMsgHdr.Serialize(receiveRegion.Memory.Span, message);
+                    }
+                }
+            }
+
+            if (errno == LinuxError.SUCCESS)
+            {
+                SetResultErrno(socket, result);
+                receiveRegion.Dispose();
+            }
+
+            return WriteBsdResult(context, result, errno);
+        }
+
         [CommandHipc(31)] // 7.0.0+
         // EventFd(u64 initval, nn::socket::EventFdFlags flags) -> (i32 ret, u32 bsd_errno)
         public ResultCode EventFd(ServiceCtx context)
diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/ISocket.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/ISocket.cs
index ee6bd9e8e5..b4f2bff196 100644
--- a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/ISocket.cs
+++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/ISocket.cs
@@ -25,7 +25,12 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
 
         LinuxError SendTo(out int sendSize, ReadOnlySpan<byte> buffer, int size, BsdSocketFlags flags, IPEndPoint remoteEndPoint);
 
+        LinuxError RecvMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags, TimeVal timeout);
+
+        LinuxError SendMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags);
+
         LinuxError GetSocketOption(BsdSocketOption option, SocketOptionLevel level, Span<byte> optionValue);
+
         LinuxError SetSocketOption(BsdSocketOption option, SocketOptionLevel level, ReadOnlySpan<byte> optionValue);
 
         bool Poll(int microSeconds, SelectMode mode);
diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs
index d2a8345887..1b6ede86cb 100644
--- a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs
+++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs
@@ -1,5 +1,7 @@
 using Ryujinx.Common.Logging;
 using System;
+using System.Collections.Generic;
+using System.Diagnostics;
 using System.Net;
 using System.Net.Sockets;
 using System.Runtime.InteropServices;
@@ -356,5 +358,165 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
         {
             return Send(out writeSize, buffer, BsdSocketFlags.None);
         }
+
+        private bool CanSupportMMsgHdr(BsdMMsgHdr message)
+        {
+            for (int i = 0; i < message.Messages.Length; i++)
+            {
+                if (message.Messages[i].Name != null ||
+                    message.Messages[i].Control != null)
+                {
+                    return false;
+                }
+            }
+
+            return true;
+        }
+
+        private static IList<ArraySegment<byte>> ConvertMessagesToBuffer(BsdMMsgHdr message)
+        {
+            int segmentCount = 0;
+            int index = 0;
+
+            foreach (BsdMsgHdr msgHeader in message.Messages)
+            {
+                segmentCount += msgHeader.Iov.Length;
+            }
+
+            ArraySegment<byte>[] buffers = new ArraySegment<byte>[segmentCount];
+
+            foreach (BsdMsgHdr msgHeader in message.Messages)
+            {
+                foreach (byte[] iov in msgHeader.Iov)
+                {
+                    buffers[index++] = new ArraySegment<byte>(iov);
+                }
+
+                // Clear the length
+                msgHeader.Length = 0;
+            }
+
+            return buffers;
+        }
+
+        private static void UpdateMessages(out int vlen, BsdMMsgHdr message, int transferedSize)
+        {
+            int bytesLeft = transferedSize;
+            int index = 0;
+
+            while (bytesLeft > 0)
+            {
+                // First ensure we haven't finished all buffers
+                if (index >= message.Messages.Length)
+                {
+                    break;
+                }
+
+                BsdMsgHdr msgHeader = message.Messages[index];
+
+                int possiblyTransferedBytes = 0;
+
+                foreach (byte[] iov in msgHeader.Iov)
+                {
+                    possiblyTransferedBytes += iov.Length;
+                }
+
+                int storedBytes;
+
+                if (bytesLeft > possiblyTransferedBytes)
+                {
+                    storedBytes = possiblyTransferedBytes;
+                    index++;
+                }
+                else
+                {
+                    storedBytes = bytesLeft;
+                }
+
+                msgHeader.Length = (uint)storedBytes;
+                bytesLeft -= storedBytes;
+            }
+
+            Debug.Assert(bytesLeft == 0);
+
+            vlen = index + 1;
+        }
+
+        // TODO: Find a way to support passing the timeout somehow without changing the socket ReceiveTimeout.
+        public LinuxError RecvMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags, TimeVal timeout)
+        {
+            vlen = 0;
+
+            if (message.Messages.Length == 0)
+            {
+                return LinuxError.SUCCESS;
+            }
+
+            if (!CanSupportMMsgHdr(message))
+            {
+                Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported BsdMMsgHdr");
+
+                return LinuxError.EOPNOTSUPP;
+            }
+
+            if (message.Messages.Length == 0)
+            {
+                return LinuxError.SUCCESS;
+            }
+
+            try
+            {
+                int receiveSize = Socket.Receive(ConvertMessagesToBuffer(message), ConvertBsdSocketFlags(flags), out SocketError socketError);
+
+                if (receiveSize > 0)
+                {
+                    UpdateMessages(out vlen, message, receiveSize);
+                }
+
+                return WinSockHelper.ConvertError((WsaError)socketError);
+            }
+            catch (SocketException exception)
+            {
+                return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
+            }
+        }
+
+        public LinuxError SendMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags)
+        {
+            vlen = 0;
+
+            if (message.Messages.Length == 0)
+            {
+                return LinuxError.SUCCESS;
+            }
+
+            if (!CanSupportMMsgHdr(message))
+            {
+                Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported BsdMMsgHdr");
+
+                return LinuxError.EOPNOTSUPP;
+            }
+
+            if (message.Messages.Length == 0)
+            {
+                return LinuxError.SUCCESS;
+            }
+
+            try
+            {
+                int sendSize = Socket.Send(ConvertMessagesToBuffer(message), ConvertBsdSocketFlags(flags), out SocketError socketError);
+
+                if (sendSize > 0)
+                {
+                    UpdateMessages(out vlen, message, sendSize);
+                }
+
+                return WinSockHelper.ConvertError((WsaError)socketError);
+            }
+            catch (SocketException exception)
+            {
+                return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
+            }
+        }
     }
 }
diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMMsgHdr.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMMsgHdr.cs
new file mode 100644
index 0000000000..bfcc92cd86
--- /dev/null
+++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMMsgHdr.cs
@@ -0,0 +1,56 @@
+using System;
+
+namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
+{
+    class BsdMMsgHdr
+    {
+        public BsdMsgHdr[] Messages { get; }
+
+        private BsdMMsgHdr(BsdMsgHdr[] messages)
+        {
+            Messages = messages;
+        }
+
+        public static LinuxError Serialize(Span<byte> rawData, BsdMMsgHdr message)
+        {
+            rawData[0] = 0x8;
+            rawData = rawData[1..];
+
+            for (int index = 0; index < message.Messages.Length; index++)
+            {
+                LinuxError res = BsdMsgHdr.Serialize(ref rawData, message.Messages[index]);
+
+                if (res != LinuxError.SUCCESS)
+                {
+                    return res;
+                }
+            }
+
+            return LinuxError.SUCCESS;
+        }
+
+        public static LinuxError Deserialize(out BsdMMsgHdr message, ReadOnlySpan<byte> rawData, int vlen)
+        {
+            message = null;
+
+            BsdMsgHdr[] messages = new BsdMsgHdr[vlen];
+
+            // Skip "header" byte (Nintendo also ignore it)
+            rawData = rawData[1..];
+
+            for (int index = 0; index < messages.Length; index++)
+            {
+                LinuxError res = BsdMsgHdr.Deserialize(out messages[index], ref rawData);
+
+                if (res != LinuxError.SUCCESS)
+                {
+                    return res;
+                }
+            }
+
+            message = new BsdMMsgHdr(messages);
+
+            return LinuxError.SUCCESS;
+        }
+    }
+}
diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMsgHdr.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMsgHdr.cs
new file mode 100644
index 0000000000..bb620375c7
--- /dev/null
+++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMsgHdr.cs
@@ -0,0 +1,212 @@
+using System;
+using System.Runtime.InteropServices;
+
+namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
+{
+    class BsdMsgHdr
+    {
+        public byte[] Name { get; }
+        public byte[][] Iov { get; }
+        public byte[] Control { get; }
+        public BsdSocketFlags Flags { get; }
+        public uint Length;
+
+        private BsdMsgHdr(byte[] name, byte[][] iov, byte[] control, BsdSocketFlags flags, uint length)
+        {
+            Name = name;
+            Iov = iov;
+            Control = control;
+            Flags = flags;
+            Length = length;
+        }
+
+        public static LinuxError Serialize(ref Span<byte> rawData, BsdMsgHdr message)
+        {
+            int msgNameLength = message.Name == null ? 0 : message.Name.Length;
+            int iovCount = message.Iov == null ? 0 : message.Iov.Length;
+            int controlLength = message.Control == null ? 0 : message.Control.Length;
+            BsdSocketFlags flags = message.Flags;
+
+            if (!MemoryMarshal.TryWrite(rawData, ref msgNameLength))
+            {
+                return LinuxError.EFAULT;
+            }
+
+            rawData = rawData[sizeof(uint)..];
+
+            if (msgNameLength > 0)
+            {
+                if (rawData.Length < msgNameLength)
+                {
+                    return LinuxError.EFAULT;
+                }
+
+                message.Name.CopyTo(rawData);
+                rawData = rawData[msgNameLength..];
+            }
+
+            if (!MemoryMarshal.TryWrite(rawData, ref iovCount))
+            {
+                return LinuxError.EFAULT;
+            }
+
+            rawData = rawData[sizeof(uint)..];
+
+            if (iovCount > 0)
+            {
+                for (int index = 0; index < iovCount; index++)
+                {
+                    ulong iovLength = (ulong)message.Iov[index].Length;
+
+                    if (!MemoryMarshal.TryWrite(rawData, ref iovLength))
+                    {
+                        return LinuxError.EFAULT;
+                    }
+
+                    rawData = rawData[sizeof(ulong)..];
+
+                    if (iovLength > 0)
+                    {
+                        if ((ulong)rawData.Length < iovLength)
+                        {
+                            return LinuxError.EFAULT;
+                        }
+
+                        message.Iov[index].CopyTo(rawData);
+                        rawData = rawData[(int)iovLength..];
+                    }
+                }
+            }
+
+            if (!MemoryMarshal.TryWrite(rawData, ref controlLength))
+            {
+                return LinuxError.EFAULT;
+            }
+
+            rawData = rawData[sizeof(uint)..];
+
+            if (controlLength > 0)
+            {
+                if (rawData.Length < controlLength)
+                {
+                    return LinuxError.EFAULT;
+                }
+
+                message.Control.CopyTo(rawData);
+                rawData = rawData[controlLength..];
+            }
+
+            if (!MemoryMarshal.TryWrite(rawData, ref flags))
+            {
+                return LinuxError.EFAULT;
+            }
+
+            rawData = rawData[sizeof(BsdSocketFlags)..];
+
+            if (!MemoryMarshal.TryWrite(rawData, ref message.Length))
+            {
+                return LinuxError.EFAULT;
+            }
+
+            rawData = rawData[sizeof(uint)..];
+
+            return LinuxError.SUCCESS;
+        }
+
+        public static LinuxError Deserialize(out BsdMsgHdr message, ref ReadOnlySpan<byte> rawData)
+        {
+            byte[] name = null;
+            byte[][] iov = null;
+            byte[] control = null;
+
+            message = null;
+
+            if (!MemoryMarshal.TryRead(rawData, out uint msgNameLength))
+            {
+                return LinuxError.EFAULT;
+            }
+
+            rawData = rawData[sizeof(uint)..];
+
+            if (msgNameLength > 0)
+            {
+                if (rawData.Length < msgNameLength)
+                {
+                    return LinuxError.EFAULT;
+                }
+
+                name = rawData[..(int)msgNameLength].ToArray();
+                rawData = rawData[(int)msgNameLength..];
+            }
+
+            if (!MemoryMarshal.TryRead(rawData, out uint iovCount))
+            {
+                return LinuxError.EFAULT;
+            }
+
+            rawData = rawData[sizeof(uint)..];
+
+            if (iovCount > 0)
+            {
+                iov = new byte[iovCount][];
+
+                for (int index = 0; index < iov.Length; index++)
+                {
+                    if (!MemoryMarshal.TryRead(rawData, out ulong iovLength))
+                    {
+                        return LinuxError.EFAULT;
+                    }
+
+                    rawData = rawData[sizeof(ulong)..];
+
+                    if (iovLength > 0)
+                    {
+                        if ((ulong)rawData.Length < iovLength)
+                        {
+                            return LinuxError.EFAULT;
+                        }
+
+                        iov[index] = rawData[..(int)iovLength].ToArray();
+                        rawData = rawData[(int)iovLength..];
+                    }
+                }
+            }
+
+            if (!MemoryMarshal.TryRead(rawData, out uint controlLength))
+            {
+                return LinuxError.EFAULT;
+            }
+
+            rawData = rawData[sizeof(uint)..];
+
+            if (controlLength > 0)
+            {
+                if (rawData.Length < controlLength)
+                {
+                    return LinuxError.EFAULT;
+                }
+
+                control = rawData[..(int)controlLength].ToArray();
+                rawData = rawData[(int)controlLength..];
+            }
+
+            if (!MemoryMarshal.TryRead(rawData, out BsdSocketFlags flags))
+            {
+                return LinuxError.EFAULT;
+            }
+
+            rawData = rawData[sizeof(BsdSocketFlags)..];
+
+            if (!MemoryMarshal.TryRead(rawData, out uint length))
+            {
+                return LinuxError.EFAULT;
+            }
+
+            rawData = rawData[sizeof(uint)..];
+
+            message = new BsdMsgHdr(name, iov, control, flags, length);
+
+            return LinuxError.SUCCESS;
+        }
+    }
+}
diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/TimeVal.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/TimeVal.cs
new file mode 100644
index 0000000000..c577660235
--- /dev/null
+++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/TimeVal.cs
@@ -0,0 +1,8 @@
+namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
+{
+    public struct TimeVal
+    {
+        public ulong TvSec;
+        public ulong TvUsec;
+    }
+}