From f3835dc78bfc845786a38c189929ac8838960018 Mon Sep 17 00:00:00 2001 From: Mary-nyan Date: Wed, 7 Sep 2022 22:37:15 +0200 Subject: [PATCH] bsd: implement SendMMsg and RecvMMsg (#3660) * bsd: implement sendmmsg and recvmmsg * Fix wrong increment of vlen --- .../HOS/Services/Sockets/Bsd/IClient.cs | 85 +++++++ .../HOS/Services/Sockets/Bsd/ISocket.cs | 5 + .../Sockets/Bsd/Impl/ManagedSocket.cs | 162 +++++++++++++ .../Services/Sockets/Bsd/Types/BsdMMsgHdr.cs | 56 +++++ .../Services/Sockets/Bsd/Types/BsdMsgHdr.cs | 212 ++++++++++++++++++ .../HOS/Services/Sockets/Bsd/Types/TimeVal.cs | 8 + 6 files changed, 528 insertions(+) create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMMsgHdr.cs create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMsgHdr.cs create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/TimeVal.cs diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs index 654844dc01..98a993119c 100644 --- a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs @@ -886,6 +886,91 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd return WriteBsdResult(context, newSockFd, errno); } + + [CommandHipc(29)] // 7.0.0+ + // RecvMMsg(u32 fd, u32 vlen, u32 flags, u32 reserved, nn::socket::TimeVal timeout) -> (i32 ret, u32 bsd_errno, buffer message); + public ResultCode RecvMMsg(ServiceCtx context) + { + int socketFd = context.RequestData.ReadInt32(); + int vlen = context.RequestData.ReadInt32(); + BsdSocketFlags socketFlags = (BsdSocketFlags)context.RequestData.ReadInt32(); + uint reserved = context.RequestData.ReadUInt32(); + TimeVal timeout = context.RequestData.ReadStruct(); + + ulong receivePosition = context.Request.ReceiveBuff[0].Position; + ulong receiveLength = context.Request.ReceiveBuff[0].Size; + + WritableRegion receiveRegion = context.Memory.GetWritableRegion(receivePosition, (int)receiveLength); + + LinuxError errno = LinuxError.EBADF; + ISocket socket = _context.RetrieveSocket(socketFd); + int result = -1; + + if (socket != null) + { + errno = BsdMMsgHdr.Deserialize(out BsdMMsgHdr message, receiveRegion.Memory.Span, vlen); + + if (errno == LinuxError.SUCCESS) + { + errno = socket.RecvMMsg(out result, message, socketFlags, timeout); + + if (errno == LinuxError.SUCCESS) + { + errno = BsdMMsgHdr.Serialize(receiveRegion.Memory.Span, message); + } + } + } + + if (errno == LinuxError.SUCCESS) + { + SetResultErrno(socket, result); + receiveRegion.Dispose(); + } + + return WriteBsdResult(context, result, errno); + } + + [CommandHipc(30)] // 7.0.0+ + // SendMMsg(u32 fd, u32 vlen, u32 flags) -> (i32 ret, u32 bsd_errno, buffer message); + public ResultCode SendMMsg(ServiceCtx context) + { + int socketFd = context.RequestData.ReadInt32(); + int vlen = context.RequestData.ReadInt32(); + BsdSocketFlags socketFlags = (BsdSocketFlags)context.RequestData.ReadInt32(); + + ulong receivePosition = context.Request.ReceiveBuff[0].Position; + ulong receiveLength = context.Request.ReceiveBuff[0].Size; + + WritableRegion receiveRegion = context.Memory.GetWritableRegion(receivePosition, (int)receiveLength); + + LinuxError errno = LinuxError.EBADF; + ISocket socket = _context.RetrieveSocket(socketFd); + int result = -1; + + if (socket != null) + { + errno = BsdMMsgHdr.Deserialize(out BsdMMsgHdr message, receiveRegion.Memory.Span, vlen); + + if (errno == LinuxError.SUCCESS) + { + errno = socket.SendMMsg(out result, message, socketFlags); + + if (errno == LinuxError.SUCCESS) + { + errno = BsdMMsgHdr.Serialize(receiveRegion.Memory.Span, message); + } + } + } + + if (errno == LinuxError.SUCCESS) + { + SetResultErrno(socket, result); + receiveRegion.Dispose(); + } + + return WriteBsdResult(context, result, errno); + } + [CommandHipc(31)] // 7.0.0+ // EventFd(u64 initval, nn::socket::EventFdFlags flags) -> (i32 ret, u32 bsd_errno) public ResultCode EventFd(ServiceCtx context) diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/ISocket.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/ISocket.cs index ee6bd9e8e5..b4f2bff196 100644 --- a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/ISocket.cs +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/ISocket.cs @@ -25,7 +25,12 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd LinuxError SendTo(out int sendSize, ReadOnlySpan buffer, int size, BsdSocketFlags flags, IPEndPoint remoteEndPoint); + LinuxError RecvMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags, TimeVal timeout); + + LinuxError SendMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags); + LinuxError GetSocketOption(BsdSocketOption option, SocketOptionLevel level, Span optionValue); + LinuxError SetSocketOption(BsdSocketOption option, SocketOptionLevel level, ReadOnlySpan optionValue); bool Poll(int microSeconds, SelectMode mode); diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs index d2a8345887..1b6ede86cb 100644 --- a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs @@ -1,5 +1,7 @@ using Ryujinx.Common.Logging; using System; +using System.Collections.Generic; +using System.Diagnostics; using System.Net; using System.Net.Sockets; using System.Runtime.InteropServices; @@ -356,5 +358,165 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd { return Send(out writeSize, buffer, BsdSocketFlags.None); } + + private bool CanSupportMMsgHdr(BsdMMsgHdr message) + { + for (int i = 0; i < message.Messages.Length; i++) + { + if (message.Messages[i].Name != null || + message.Messages[i].Control != null) + { + return false; + } + } + + return true; + } + + private static IList> ConvertMessagesToBuffer(BsdMMsgHdr message) + { + int segmentCount = 0; + int index = 0; + + foreach (BsdMsgHdr msgHeader in message.Messages) + { + segmentCount += msgHeader.Iov.Length; + } + + ArraySegment[] buffers = new ArraySegment[segmentCount]; + + foreach (BsdMsgHdr msgHeader in message.Messages) + { + foreach (byte[] iov in msgHeader.Iov) + { + buffers[index++] = new ArraySegment(iov); + } + + // Clear the length + msgHeader.Length = 0; + } + + return buffers; + } + + private static void UpdateMessages(out int vlen, BsdMMsgHdr message, int transferedSize) + { + int bytesLeft = transferedSize; + int index = 0; + + while (bytesLeft > 0) + { + // First ensure we haven't finished all buffers + if (index >= message.Messages.Length) + { + break; + } + + BsdMsgHdr msgHeader = message.Messages[index]; + + int possiblyTransferedBytes = 0; + + foreach (byte[] iov in msgHeader.Iov) + { + possiblyTransferedBytes += iov.Length; + } + + int storedBytes; + + if (bytesLeft > possiblyTransferedBytes) + { + storedBytes = possiblyTransferedBytes; + index++; + } + else + { + storedBytes = bytesLeft; + } + + msgHeader.Length = (uint)storedBytes; + bytesLeft -= storedBytes; + } + + Debug.Assert(bytesLeft == 0); + + vlen = index + 1; + } + + // TODO: Find a way to support passing the timeout somehow without changing the socket ReceiveTimeout. + public LinuxError RecvMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags, TimeVal timeout) + { + vlen = 0; + + if (message.Messages.Length == 0) + { + return LinuxError.SUCCESS; + } + + if (!CanSupportMMsgHdr(message)) + { + Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported BsdMMsgHdr"); + + return LinuxError.EOPNOTSUPP; + } + + if (message.Messages.Length == 0) + { + return LinuxError.SUCCESS; + } + + try + { + int receiveSize = Socket.Receive(ConvertMessagesToBuffer(message), ConvertBsdSocketFlags(flags), out SocketError socketError); + + if (receiveSize > 0) + { + UpdateMessages(out vlen, message, receiveSize); + } + + return WinSockHelper.ConvertError((WsaError)socketError); + } + catch (SocketException exception) + { + return WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + } + + public LinuxError SendMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags) + { + vlen = 0; + + if (message.Messages.Length == 0) + { + return LinuxError.SUCCESS; + } + + if (!CanSupportMMsgHdr(message)) + { + Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported BsdMMsgHdr"); + + return LinuxError.EOPNOTSUPP; + } + + if (message.Messages.Length == 0) + { + return LinuxError.SUCCESS; + } + + try + { + int sendSize = Socket.Send(ConvertMessagesToBuffer(message), ConvertBsdSocketFlags(flags), out SocketError socketError); + + if (sendSize > 0) + { + UpdateMessages(out vlen, message, sendSize); + } + + return WinSockHelper.ConvertError((WsaError)socketError); + } + catch (SocketException exception) + { + return WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + } } } diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMMsgHdr.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMMsgHdr.cs new file mode 100644 index 0000000000..bfcc92cd86 --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMMsgHdr.cs @@ -0,0 +1,56 @@ +using System; + +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd +{ + class BsdMMsgHdr + { + public BsdMsgHdr[] Messages { get; } + + private BsdMMsgHdr(BsdMsgHdr[] messages) + { + Messages = messages; + } + + public static LinuxError Serialize(Span rawData, BsdMMsgHdr message) + { + rawData[0] = 0x8; + rawData = rawData[1..]; + + for (int index = 0; index < message.Messages.Length; index++) + { + LinuxError res = BsdMsgHdr.Serialize(ref rawData, message.Messages[index]); + + if (res != LinuxError.SUCCESS) + { + return res; + } + } + + return LinuxError.SUCCESS; + } + + public static LinuxError Deserialize(out BsdMMsgHdr message, ReadOnlySpan rawData, int vlen) + { + message = null; + + BsdMsgHdr[] messages = new BsdMsgHdr[vlen]; + + // Skip "header" byte (Nintendo also ignore it) + rawData = rawData[1..]; + + for (int index = 0; index < messages.Length; index++) + { + LinuxError res = BsdMsgHdr.Deserialize(out messages[index], ref rawData); + + if (res != LinuxError.SUCCESS) + { + return res; + } + } + + message = new BsdMMsgHdr(messages); + + return LinuxError.SUCCESS; + } + } +} diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMsgHdr.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMsgHdr.cs new file mode 100644 index 0000000000..bb620375c7 --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMsgHdr.cs @@ -0,0 +1,212 @@ +using System; +using System.Runtime.InteropServices; + +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd +{ + class BsdMsgHdr + { + public byte[] Name { get; } + public byte[][] Iov { get; } + public byte[] Control { get; } + public BsdSocketFlags Flags { get; } + public uint Length; + + private BsdMsgHdr(byte[] name, byte[][] iov, byte[] control, BsdSocketFlags flags, uint length) + { + Name = name; + Iov = iov; + Control = control; + Flags = flags; + Length = length; + } + + public static LinuxError Serialize(ref Span rawData, BsdMsgHdr message) + { + int msgNameLength = message.Name == null ? 0 : message.Name.Length; + int iovCount = message.Iov == null ? 0 : message.Iov.Length; + int controlLength = message.Control == null ? 0 : message.Control.Length; + BsdSocketFlags flags = message.Flags; + + if (!MemoryMarshal.TryWrite(rawData, ref msgNameLength)) + { + return LinuxError.EFAULT; + } + + rawData = rawData[sizeof(uint)..]; + + if (msgNameLength > 0) + { + if (rawData.Length < msgNameLength) + { + return LinuxError.EFAULT; + } + + message.Name.CopyTo(rawData); + rawData = rawData[msgNameLength..]; + } + + if (!MemoryMarshal.TryWrite(rawData, ref iovCount)) + { + return LinuxError.EFAULT; + } + + rawData = rawData[sizeof(uint)..]; + + if (iovCount > 0) + { + for (int index = 0; index < iovCount; index++) + { + ulong iovLength = (ulong)message.Iov[index].Length; + + if (!MemoryMarshal.TryWrite(rawData, ref iovLength)) + { + return LinuxError.EFAULT; + } + + rawData = rawData[sizeof(ulong)..]; + + if (iovLength > 0) + { + if ((ulong)rawData.Length < iovLength) + { + return LinuxError.EFAULT; + } + + message.Iov[index].CopyTo(rawData); + rawData = rawData[(int)iovLength..]; + } + } + } + + if (!MemoryMarshal.TryWrite(rawData, ref controlLength)) + { + return LinuxError.EFAULT; + } + + rawData = rawData[sizeof(uint)..]; + + if (controlLength > 0) + { + if (rawData.Length < controlLength) + { + return LinuxError.EFAULT; + } + + message.Control.CopyTo(rawData); + rawData = rawData[controlLength..]; + } + + if (!MemoryMarshal.TryWrite(rawData, ref flags)) + { + return LinuxError.EFAULT; + } + + rawData = rawData[sizeof(BsdSocketFlags)..]; + + if (!MemoryMarshal.TryWrite(rawData, ref message.Length)) + { + return LinuxError.EFAULT; + } + + rawData = rawData[sizeof(uint)..]; + + return LinuxError.SUCCESS; + } + + public static LinuxError Deserialize(out BsdMsgHdr message, ref ReadOnlySpan rawData) + { + byte[] name = null; + byte[][] iov = null; + byte[] control = null; + + message = null; + + if (!MemoryMarshal.TryRead(rawData, out uint msgNameLength)) + { + return LinuxError.EFAULT; + } + + rawData = rawData[sizeof(uint)..]; + + if (msgNameLength > 0) + { + if (rawData.Length < msgNameLength) + { + return LinuxError.EFAULT; + } + + name = rawData[..(int)msgNameLength].ToArray(); + rawData = rawData[(int)msgNameLength..]; + } + + if (!MemoryMarshal.TryRead(rawData, out uint iovCount)) + { + return LinuxError.EFAULT; + } + + rawData = rawData[sizeof(uint)..]; + + if (iovCount > 0) + { + iov = new byte[iovCount][]; + + for (int index = 0; index < iov.Length; index++) + { + if (!MemoryMarshal.TryRead(rawData, out ulong iovLength)) + { + return LinuxError.EFAULT; + } + + rawData = rawData[sizeof(ulong)..]; + + if (iovLength > 0) + { + if ((ulong)rawData.Length < iovLength) + { + return LinuxError.EFAULT; + } + + iov[index] = rawData[..(int)iovLength].ToArray(); + rawData = rawData[(int)iovLength..]; + } + } + } + + if (!MemoryMarshal.TryRead(rawData, out uint controlLength)) + { + return LinuxError.EFAULT; + } + + rawData = rawData[sizeof(uint)..]; + + if (controlLength > 0) + { + if (rawData.Length < controlLength) + { + return LinuxError.EFAULT; + } + + control = rawData[..(int)controlLength].ToArray(); + rawData = rawData[(int)controlLength..]; + } + + if (!MemoryMarshal.TryRead(rawData, out BsdSocketFlags flags)) + { + return LinuxError.EFAULT; + } + + rawData = rawData[sizeof(BsdSocketFlags)..]; + + if (!MemoryMarshal.TryRead(rawData, out uint length)) + { + return LinuxError.EFAULT; + } + + rawData = rawData[sizeof(uint)..]; + + message = new BsdMsgHdr(name, iov, control, flags, length); + + return LinuxError.SUCCESS; + } + } +} diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/TimeVal.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/TimeVal.cs new file mode 100644 index 0000000000..c577660235 --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/TimeVal.cs @@ -0,0 +1,8 @@ +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd +{ + public struct TimeVal + { + public ulong TvSec; + public ulong TvUsec; + } +}