Archived
1
0
Fork 0
forked from Mirror/Ryujinx

ssl: Implement SSL connectivity (#2961)

* implement certain servicessl functions

* ssl: Implement more of SSL connection and abstract it

This adds support to non blocking SSL operations and unlink the SSL
implementation from the IPC logic.

* Rename SslDefaultSocketConnection to SslManagedSocketConnection

* Fix regression on Pokemon TV

* Address gdkchan's comment

* Simplify value read from previous commit

* ssl: some changes

- Implement builtin certificates parsing and retrieving
- Fix issues with SSL version handling
- Improve managed SSL socket error handling
- Ensure to only return a certificate on DoHandshake when actually requested

* Add missing BuiltInCertificateManager initialization call

* Address gdkchan's comment

* Address Ack's comment

Co-authored-by: InvoxiPlayGames <webmaster@invoxiplaygames.uk>
This commit is contained in:
Mary 2022-01-13 23:29:04 +01:00 committed by GitHub
parent 366fe2dbb2
commit 3fa7ef21b4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 1138 additions and 34 deletions

View file

@ -10,6 +10,7 @@ using LibHac.Tools.FsSystem.NcaUtils;
using LibHac.Tools.Ncm; using LibHac.Tools.Ncm;
using Ryujinx.Common.Logging; using Ryujinx.Common.Logging;
using Ryujinx.HLE.Exceptions; using Ryujinx.HLE.Exceptions;
using Ryujinx.HLE.HOS.Services.Ssl;
using Ryujinx.HLE.HOS.Services.Time; using Ryujinx.HLE.HOS.Services.Time;
using Ryujinx.HLE.Utilities; using Ryujinx.HLE.Utilities;
using System; using System;
@ -195,6 +196,7 @@ namespace Ryujinx.HLE.FileSystem.Content
if (device != null) if (device != null)
{ {
TimeManager.Instance.InitializeTimeZone(device); TimeManager.Instance.InitializeTimeZone(device);
BuiltInCertificateManager.Instance.Initialize(device);
device.System.SharedFontManager.Initialize(); device.System.SharedFontManager.Initialize();
} }
} }

View file

@ -0,0 +1,237 @@
using LibHac;
using LibHac.Common;
using LibHac.Fs;
using LibHac.Fs.Fsa;
using LibHac.FsSystem;
using LibHac.Tools.FsSystem;
using LibHac.Tools.FsSystem.NcaUtils;
using Ryujinx.Common.Configuration;
using Ryujinx.Common.Logging;
using Ryujinx.HLE.Exceptions;
using Ryujinx.HLE.FileSystem;
using Ryujinx.HLE.FileSystem.Content;
using Ryujinx.HLE.HOS.Services.Ssl.Types;
using System;
using System.Collections.Generic;
using System.IO;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
namespace Ryujinx.HLE.HOS.Services.Ssl
{
class BuiltInCertificateManager
{
private const long CertStoreTitleId = 0x0100000000000800;
private readonly string CertStoreTitleMissingErrorMessage = "CertStore system title not found! SSL CA retrieving will not work, provide the system archive to fix this error. (See https://github.com/Ryujinx/Ryujinx/wiki/Ryujinx-Setup-&-Configuration-Guide#initial-setup-continued---installation-of-firmware for more information)";
private static BuiltInCertificateManager _instance;
public static BuiltInCertificateManager Instance
{
get
{
if (_instance == null)
{
_instance = new BuiltInCertificateManager();
}
return _instance;
}
}
private VirtualFileSystem _virtualFileSystem;
private IntegrityCheckLevel _fsIntegrityCheckLevel;
private ContentManager _contentManager;
private bool _initialized;
private Dictionary<CaCertificateId, CertStoreEntry> _certificates;
private object _lock = new object();
private struct CertStoreFileHeader
{
private const uint ValidMagic = 0x546C7373;
#pragma warning disable CS0649
public uint Magic;
public uint EntriesCount;
#pragma warning restore CS0649
public bool IsValid()
{
return Magic == ValidMagic;
}
}
private struct CertStoreFileEntry
{
#pragma warning disable CS0649
public CaCertificateId Id;
public TrustedCertStatus Status;
public uint DataSize;
public uint DataOffset;
#pragma warning restore CS0649
}
public class CertStoreEntry
{
public CaCertificateId Id;
public TrustedCertStatus Status;
public byte[] Data;
}
public string GetCertStoreTitleContentPath()
{
return _contentManager.GetInstalledContentPath(CertStoreTitleId, StorageId.NandSystem, NcaContentType.Data);
}
public bool HasCertStoreTitle()
{
return !string.IsNullOrEmpty(GetCertStoreTitleContentPath());
}
private CertStoreEntry ReadCertStoreEntry(ReadOnlySpan<byte> buffer, CertStoreFileEntry entry)
{
string customCertificatePath = System.IO.Path.Join(AppDataManager.BaseDirPath, "system", "ssl", $"{entry.Id}.der");
byte[] data;
if (File.Exists(customCertificatePath))
{
data = File.ReadAllBytes(customCertificatePath);
}
else
{
data = buffer.Slice((int)entry.DataOffset, (int)entry.DataSize).ToArray();
}
return new CertStoreEntry
{
Id = entry.Id,
Status = entry.Status,
Data = data
};
}
public void Initialize(Switch device)
{
lock (_lock)
{
_certificates = new Dictionary<CaCertificateId, CertStoreEntry>();
_initialized = false;
_contentManager = device.System.ContentManager;
_virtualFileSystem = device.FileSystem;
_fsIntegrityCheckLevel = device.System.FsIntegrityCheckLevel;
if (HasCertStoreTitle())
{
using LocalStorage ncaFile = new LocalStorage(_virtualFileSystem.SwitchPathToSystemPath(GetCertStoreTitleContentPath()), FileAccess.Read, FileMode.Open);
Nca nca = new Nca(_virtualFileSystem.KeySet, ncaFile);
IFileSystem romfs = nca.OpenFileSystem(NcaSectionType.Data, _fsIntegrityCheckLevel);
using var trustedCertsFileRef = new UniqueRef<IFile>();
Result result = romfs.OpenFile(ref trustedCertsFileRef.Ref(), "/ssl_TrustedCerts.bdf".ToU8Span(), OpenMode.Read);
if (!result.IsSuccess())
{
// [1.0.0 - 2.3.0]
if (ResultFs.PathNotFound.Includes(result))
{
result = romfs.OpenFile(ref trustedCertsFileRef.Ref(), "/ssl_TrustedCerts.tcf".ToU8Span(), OpenMode.Read);
}
if (result.IsFailure())
{
Logger.Error?.Print(LogClass.ServiceSsl, CertStoreTitleMissingErrorMessage);
return;
}
}
using IFile trustedCertsFile = trustedCertsFileRef.Release();
trustedCertsFile.GetSize(out long fileSize).ThrowIfFailure();
Span<byte> trustedCertsRaw = new byte[fileSize];
trustedCertsFile.Read(out _, 0, trustedCertsRaw).ThrowIfFailure();
CertStoreFileHeader header = MemoryMarshal.Read<CertStoreFileHeader>(trustedCertsRaw);
if (!header.IsValid())
{
Logger.Error?.Print(LogClass.ServiceSsl, "Invalid CertStore data found, skipping!");
return;
}
ReadOnlySpan<byte> trustedCertsData = trustedCertsRaw[Unsafe.SizeOf<CertStoreFileHeader>()..];
ReadOnlySpan<CertStoreFileEntry> trustedCertsEntries = MemoryMarshal.Cast<byte, CertStoreFileEntry>(trustedCertsData)[..(int)header.EntriesCount];
foreach (CertStoreFileEntry entry in trustedCertsEntries)
{
_certificates.Add(entry.Id, ReadCertStoreEntry(trustedCertsData, entry));
}
_initialized = true;
}
}
}
public bool TryGetCertificates(ReadOnlySpan<CaCertificateId> ids, out CertStoreEntry[] entries)
{
lock (_lock)
{
if (!_initialized)
{
throw new InvalidSystemResourceException(CertStoreTitleMissingErrorMessage);
}
bool hasAllCertificates = false;
foreach (CaCertificateId id in ids)
{
if (id == CaCertificateId.All)
{
hasAllCertificates = true;
break;
}
}
if (hasAllCertificates)
{
entries = new CertStoreEntry[_certificates.Count];
int i = 0;
foreach (CertStoreEntry entry in _certificates.Values)
{
entries[i++] = entry;
}
return true;
}
else
{
entries = new CertStoreEntry[ids.Length];
for (int i = 0; i < ids.Length; i++)
{
if (!_certificates.TryGetValue(ids[i], out CertStoreEntry entry))
{
return false;
}
entries[i] = entry;
}
return true;
}
}
}
}
}

View file

@ -1,6 +1,11 @@
using Ryujinx.Common.Logging; using Ryujinx.Common.Logging;
using Ryujinx.HLE.Exceptions;
using Ryujinx.HLE.HOS.Services.Ssl.SslService; using Ryujinx.HLE.HOS.Services.Ssl.SslService;
using Ryujinx.HLE.HOS.Services.Ssl.Types; using Ryujinx.HLE.HOS.Services.Ssl.Types;
using Ryujinx.Memory;
using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
namespace Ryujinx.HLE.HOS.Services.Ssl namespace Ryujinx.HLE.HOS.Services.Ssl
{ {
@ -18,13 +23,85 @@ namespace Ryujinx.HLE.HOS.Services.Ssl
SslVersion sslVersion = (SslVersion)context.RequestData.ReadUInt32(); SslVersion sslVersion = (SslVersion)context.RequestData.ReadUInt32();
ulong pidPlaceholder = context.RequestData.ReadUInt64(); ulong pidPlaceholder = context.RequestData.ReadUInt64();
MakeObject(context, new ISslContext(context)); MakeObject(context, new ISslContext(context.Request.HandleDesc.PId, sslVersion));
Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { sslVersion }); Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { sslVersion });
return ResultCode.Success; return ResultCode.Success;
} }
private uint ComputeCertificateBufferSizeRequired(ReadOnlySpan<BuiltInCertificateManager.CertStoreEntry> entries)
{
uint totalSize = 0;
for (int i = 0; i < entries.Length; i++)
{
totalSize += (uint)Unsafe.SizeOf<BuiltInCertificateInfo>();
totalSize += (uint)entries[i].Data.Length;
}
return totalSize;
}
[CommandHipc(2)]
// GetCertificates(buffer<CaCertificateId, 5> ids) -> (u32 certificates_count, buffer<bytes, 6> certificates)
public ResultCode GetCertificates(ServiceCtx context)
{
ReadOnlySpan<CaCertificateId> ids = MemoryMarshal.Cast<byte, CaCertificateId>(context.Memory.GetSpan(context.Request.SendBuff[0].Position, (int)context.Request.SendBuff[0].Size));
if (!BuiltInCertificateManager.Instance.TryGetCertificates(ids, out BuiltInCertificateManager.CertStoreEntry[] entries))
{
throw new InvalidOperationException();
}
if (ComputeCertificateBufferSizeRequired(entries) > context.Request.ReceiveBuff[0].Size)
{
return ResultCode.InvalidCertBufSize;
}
using (WritableRegion region = context.Memory.GetWritableRegion(context.Request.ReceiveBuff[0].Position, (int)context.Request.ReceiveBuff[0].Size))
{
Span<byte> rawData = region.Memory.Span;
Span<BuiltInCertificateInfo> infos = MemoryMarshal.Cast<byte, BuiltInCertificateInfo>(rawData)[..entries.Length];
Span<byte> certificatesData = rawData[(Unsafe.SizeOf<BuiltInCertificateInfo>() * entries.Length)..];
for (int i = 0; i < infos.Length; i++)
{
entries[i].Data.CopyTo(certificatesData);
infos[i] = new BuiltInCertificateInfo
{
Id = entries[i].Id,
Status = entries[i].Status,
CertificateDataSize = (ulong)entries[i].Data.Length,
CertificateDataOffset = (ulong)(rawData.Length - certificatesData.Length)
};
certificatesData = certificatesData[entries[i].Data.Length..];
}
}
context.ResponseData.Write(entries.Length);
return ResultCode.Success;
}
[CommandHipc(3)]
// GetCertificateBufSize(buffer<CaCertificateId, 5> ids) -> u32 buffer_size;
public ResultCode GetCertificateBufSize(ServiceCtx context)
{
ReadOnlySpan<CaCertificateId> ids = MemoryMarshal.Cast<byte, CaCertificateId>(context.Memory.GetSpan(context.Request.SendBuff[0].Position, (int)context.Request.SendBuff[0].Size));
if (!BuiltInCertificateManager.Instance.TryGetCertificates(ids, out BuiltInCertificateManager.CertStoreEntry[] entries))
{
throw new InvalidOperationException();
}
context.ResponseData.Write(ComputeCertificateBufferSizeRequired(entries));
return ResultCode.Success;
}
[CommandHipc(5)] [CommandHipc(5)]
// SetInterfaceVersion(u32) // SetInterfaceVersion(u32)
public ResultCode SetInterfaceVersion(ServiceCtx context) public ResultCode SetInterfaceVersion(ServiceCtx context)

View file

@ -0,0 +1,20 @@
namespace Ryujinx.HLE.HOS.Services.Ssl
{
public enum ResultCode
{
OsModuleId = 123,
ErrorCodeShift = 9,
Success = 0,
NoSocket = (103 << ErrorCodeShift) | OsModuleId,
InvalidSocket = (106 << ErrorCodeShift) | OsModuleId,
InvalidCertBufSize = (112 << ErrorCodeShift) | OsModuleId,
InvalidOption = (126 << ErrorCodeShift) | OsModuleId,
CertBufferTooSmall = (202 << ErrorCodeShift) | OsModuleId,
AlreadyInUse = (203 << ErrorCodeShift) | OsModuleId,
WouldBlock = (204 << ErrorCodeShift) | OsModuleId,
Timeout = (205 << ErrorCodeShift) | OsModuleId,
ConnectionReset = (209 << ErrorCodeShift) | OsModuleId,
ConnectionAbort = (210 << ErrorCodeShift) | OsModuleId
}
}

View file

@ -1,41 +1,101 @@
using Ryujinx.Common.Logging; using Ryujinx.Common.Logging;
using Ryujinx.HLE.Exceptions;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd;
using Ryujinx.HLE.HOS.Services.Ssl.Types; using Ryujinx.HLE.HOS.Services.Ssl.Types;
using Ryujinx.Memory;
using System;
using System.Text; using System.Text;
namespace Ryujinx.HLE.HOS.Services.Ssl.SslService namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
{ {
class ISslConnection : IpcService class ISslConnection : IpcService, IDisposable
{ {
public ISslConnection() { } private bool _doNotClockSocket;
private bool _getServerCertChain;
private bool _skipDefaultVerify;
private bool _enableAlpn;
private SslVersion _sslVersion;
private IoMode _ioMode;
private VerifyOption _verifyOption;
private SessionCacheMode _sessionCacheMode;
private string _hostName;
private ISslConnectionBase _connection;
private BsdContext _bsdContext;
private readonly long _processId;
private byte[] _nextAplnProto;
public ISslConnection(long processId, SslVersion sslVersion)
{
_processId = processId;
_sslVersion = sslVersion;
_ioMode = IoMode.Blocking;
_sessionCacheMode = SessionCacheMode.None;
_verifyOption = VerifyOption.PeerCa | VerifyOption.HostName;
}
[CommandHipc(0)] [CommandHipc(0)]
// SetSocketDescriptor(u32) -> u32 // SetSocketDescriptor(u32) -> u32
public ResultCode SetSocketDescriptor(ServiceCtx context) public ResultCode SetSocketDescriptor(ServiceCtx context)
{ {
uint socketFd = context.RequestData.ReadUInt32(); if (_connection != null)
uint duplicateSocketFd = 0; {
return ResultCode.AlreadyInUse;
}
context.ResponseData.Write(duplicateSocketFd); _bsdContext = BsdContext.GetContext(_processId);
Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { socketFd }); if (_bsdContext == null)
{
return ResultCode.InvalidSocket;
}
int inputFd = context.RequestData.ReadInt32();
int internalFd = _bsdContext.DuplicateFileDescriptor(inputFd);
if (internalFd == -1)
{
return ResultCode.InvalidSocket;
}
InitializeConnection(internalFd);
int outputFd = inputFd;
if (_doNotClockSocket)
{
outputFd = -1;
}
context.ResponseData.Write(outputFd);
return ResultCode.Success; return ResultCode.Success;
} }
private void InitializeConnection(int socketFd)
{
ISocket bsdSocket = _bsdContext.RetrieveSocket(socketFd);
_connection = new SslManagedSocketConnection(_bsdContext, _sslVersion, socketFd, bsdSocket);
}
[CommandHipc(1)] [CommandHipc(1)]
// SetHostName(buffer<bytes, 5>) // SetHostName(buffer<bytes, 5>)
public ResultCode SetHostName(ServiceCtx context) public ResultCode SetHostName(ServiceCtx context)
{ {
ulong hostNameDataPosition = context.Request.SendBuff[0].Position; ulong hostNameDataPosition = context.Request.SendBuff[0].Position;
ulong hostNameDataSize = context.Request.SendBuff[0].Size; ulong hostNameDataSize = context.Request.SendBuff[0].Size;
byte[] hostNameData = new byte[hostNameDataSize]; byte[] hostNameData = new byte[hostNameDataSize];
context.Memory.Read(hostNameDataPosition, hostNameData); context.Memory.Read(hostNameDataPosition, hostNameData);
string hostName = Encoding.ASCII.GetString(hostNameData).Trim('\0'); _hostName = Encoding.ASCII.GetString(hostNameData).Trim('\0');
Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { hostName }); Logger.Info?.Print(LogClass.ServiceSsl, _hostName);
return ResultCode.Success; return ResultCode.Success;
} }
@ -44,9 +104,9 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
// SetVerifyOption(nn::ssl::sf::VerifyOption) // SetVerifyOption(nn::ssl::sf::VerifyOption)
public ResultCode SetVerifyOption(ServiceCtx context) public ResultCode SetVerifyOption(ServiceCtx context)
{ {
VerifyOption verifyOption = (VerifyOption)context.RequestData.ReadUInt32(); _verifyOption = (VerifyOption)context.RequestData.ReadUInt32();
Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { verifyOption }); Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { _verifyOption });
return ResultCode.Success; return ResultCode.Success;
} }
@ -55,9 +115,67 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
// SetIoMode(nn::ssl::sf::IoMode) // SetIoMode(nn::ssl::sf::IoMode)
public ResultCode SetIoMode(ServiceCtx context) public ResultCode SetIoMode(ServiceCtx context)
{ {
IoMode ioMode = (IoMode)context.RequestData.ReadUInt32(); if (_connection == null)
{
return ResultCode.NoSocket;
}
Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { ioMode }); _ioMode = (IoMode)context.RequestData.ReadUInt32();
_connection.Socket.Blocking = _ioMode == IoMode.Blocking;
Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { _ioMode });
return ResultCode.Success;
}
[CommandHipc(4)]
// GetSocketDescriptor() -> u32
public ResultCode GetSocketDescriptor(ServiceCtx context)
{
context.ResponseData.Write(_connection.SocketFd);
return ResultCode.Success;
}
[CommandHipc(5)]
// GetHostName(buffer<bytes, 6>) -> u32
public ResultCode GetHostName(ServiceCtx context)
{
ulong hostNameDataPosition = context.Request.ReceiveBuff[0].Position;
ulong hostNameDataSize = context.Request.ReceiveBuff[0].Size;
byte[] hostNameData = new byte[hostNameDataSize];
Encoding.ASCII.GetBytes(_hostName, hostNameData);
context.Memory.Write(hostNameDataPosition, hostNameData);
context.ResponseData.Write((uint)_hostName.Length);
Logger.Info?.Print(LogClass.ServiceSsl, _hostName);
return ResultCode.Success;
}
[CommandHipc(6)]
// GetVerifyOption() -> nn::ssl::sf::VerifyOption
public ResultCode GetVerifyOption(ServiceCtx context)
{
context.ResponseData.Write((uint)_verifyOption);
Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { _verifyOption });
return ResultCode.Success;
}
[CommandHipc(7)]
// GetIoMode() -> nn::ssl::sf::IoMode
public ResultCode GetIoMode(ServiceCtx context)
{
context.ResponseData.Write((uint)_ioMode);
Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { _ioMode });
return ResultCode.Success; return ResultCode.Success;
} }
@ -66,30 +184,153 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
// DoHandshake() // DoHandshake()
public ResultCode DoHandshake(ServiceCtx context) public ResultCode DoHandshake(ServiceCtx context)
{ {
Logger.Stub?.PrintStub(LogClass.ServiceSsl); if (_connection == null)
{
return ResultCode.NoSocket;
}
return _connection.Handshake(_hostName);
}
[CommandHipc(9)]
// DoHandshakeGetServerCert() -> (u32, u32, buffer<bytes, 6>)
public ResultCode DoHandshakeGetServerCert(ServiceCtx context)
{
if (_connection == null)
{
return ResultCode.NoSocket;
}
ResultCode result = _connection.Handshake(_hostName);
if (result == ResultCode.Success)
{
if (_getServerCertChain)
{
using (WritableRegion region = context.Memory.GetWritableRegion(context.Request.ReceiveBuff[0].Position, (int)context.Request.ReceiveBuff[0].Size))
{
result = _connection.GetServerCertificate(_hostName, region.Memory.Span, out uint bufferSize, out uint certificateCount);
context.ResponseData.Write(bufferSize);
context.ResponseData.Write(certificateCount);
}
}
else
{
context.ResponseData.Write(0);
context.ResponseData.Write(0);
}
}
return result;
}
[CommandHipc(10)]
// Read() -> (u32, buffer<bytes, 6>)
public ResultCode Read(ServiceCtx context)
{
if (_connection == null)
{
return ResultCode.NoSocket;
}
ResultCode result;
using (WritableRegion region = context.Memory.GetWritableRegion(context.Request.ReceiveBuff[0].Position, (int)context.Request.ReceiveBuff[0].Size))
{
// TODO: Better error management.
result = _connection.Read(out int readCount, region.Memory);
if (result == ResultCode.Success)
{
context.ResponseData.Write(readCount);
}
}
return result;
}
[CommandHipc(11)]
// Write(buffer<bytes, 5>) -> s32
public ResultCode Write(ServiceCtx context)
{
if (_connection == null)
{
return ResultCode.NoSocket;
}
// We don't dispose as this isn't supposed to be modified
WritableRegion region = context.Memory.GetWritableRegion(context.Request.SendBuff[0].Position, (int)context.Request.SendBuff[0].Size);
// TODO: Better error management.
ResultCode result = _connection.Write(out int writtenCount, region.Memory);
if (result == ResultCode.Success)
{
context.ResponseData.Write(writtenCount);
}
return result;
}
[CommandHipc(12)]
// Pending() -> s32
public ResultCode Pending(ServiceCtx context)
{
if (_connection == null)
{
return ResultCode.NoSocket;
}
context.ResponseData.Write(_connection.Pending());
return ResultCode.Success; return ResultCode.Success;
} }
[CommandHipc(11)] [CommandHipc(13)]
// Write(buffer<bytes, 5>) -> u32 // Peek() -> (s32, buffer<bytes, 6>)
public ResultCode Write(ServiceCtx context) public ResultCode Peek(ServiceCtx context)
{ {
ulong inputDataPosition = context.Request.SendBuff[0].Position; if (_connection == null)
ulong inputDataSize = context.Request.SendBuff[0].Size; {
return ResultCode.NoSocket;
}
byte[] data = new byte[inputDataSize]; ResultCode result;
context.Memory.Read(inputDataPosition, data); using (WritableRegion region = context.Memory.GetWritableRegion(context.Request.ReceiveBuff[0].Position, (int)context.Request.ReceiveBuff[0].Size))
{
// TODO: Better error management.
result = _connection.Peek(out int peekCount, region.Memory);
// NOTE: Tell the guest everything is transferred. if (result == ResultCode.Success)
uint transferredSize = (uint)inputDataSize; {
context.ResponseData.Write(peekCount);
}
}
context.ResponseData.Write(transferredSize); return result;
}
Logger.Stub?.PrintStub(LogClass.ServiceSsl); [CommandHipc(14)]
// Poll(nn::ssl::sf::PollEvent poll_event, u32 timeout) -> nn::ssl::sf::PollEvent
public ResultCode Poll(ServiceCtx context)
{
throw new ServiceNotImplementedException(this, context);
}
return ResultCode.Success; [CommandHipc(15)]
// GetVerifyCertError()
public ResultCode GetVerifyCertError(ServiceCtx context)
{
throw new ServiceNotImplementedException(this, context);
}
[CommandHipc(16)]
// GetNeededServerCertBufferSize() -> u32
public ResultCode GetNeededServerCertBufferSize(ServiceCtx context)
{
throw new ServiceNotImplementedException(this, context);
} }
[CommandHipc(17)] [CommandHipc(17)]
@ -100,19 +341,176 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { sessionCacheMode }); Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { sessionCacheMode });
_sessionCacheMode = sessionCacheMode;
return ResultCode.Success; return ResultCode.Success;
} }
[CommandHipc(18)]
// GetSessionCacheMode() -> nn::ssl::sf::SessionCacheMode
public ResultCode GetSessionCacheMode(ServiceCtx context)
{
throw new ServiceNotImplementedException(this, context);
}
[CommandHipc(19)]
// FlushSessionCache()
public ResultCode FlushSessionCache(ServiceCtx context)
{
throw new ServiceNotImplementedException(this, context);
}
[CommandHipc(20)]
// SetRenegotiationMode(nn::ssl::sf::RenegotiationMode)
public ResultCode SetRenegotiationMode(ServiceCtx context)
{
throw new ServiceNotImplementedException(this, context);
}
[CommandHipc(21)]
// GetRenegotiationMode() -> nn::ssl::sf::RenegotiationMode
public ResultCode GetRenegotiationMode(ServiceCtx context)
{
throw new ServiceNotImplementedException(this, context);
}
[CommandHipc(22)] [CommandHipc(22)]
// SetOption(b8, nn::ssl::sf::OptionType) // SetOption(b8 value, nn::ssl::sf::OptionType option)
public ResultCode SetOption(ServiceCtx context) public ResultCode SetOption(ServiceCtx context)
{ {
bool optionEnabled = context.RequestData.ReadBoolean(); bool value = context.RequestData.ReadUInt32() != 0;
OptionType optionType = (OptionType)context.RequestData.ReadUInt32(); OptionType option = (OptionType)context.RequestData.ReadUInt32();
Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { optionType, optionEnabled }); Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { option, value });
return SetOption(option, value);
}
[CommandHipc(23)]
// GetOption(nn::ssl::sf::OptionType) -> b8
public ResultCode GetOption(ServiceCtx context)
{
OptionType option = (OptionType)context.RequestData.ReadUInt32();
Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { option });
ResultCode result = GetOption(option, out bool value);
if (result == ResultCode.Success)
{
context.ResponseData.Write(value);
}
return result;
}
[CommandHipc(24)]
// GetVerifyCertErrors() -> (u32, u32, buffer<bytes, 6>)
public ResultCode GetVerifyCertErrors(ServiceCtx context)
{
throw new ServiceNotImplementedException(this, context);
}
[CommandHipc(25)] // 4.0.0+
// GetCipherInfo(u32) -> buffer<bytes, 6>
public ResultCode GetCipherInfo(ServiceCtx context)
{
throw new ServiceNotImplementedException(this, context);
}
[CommandHipc(26)]
// SetNextAlpnProto(buffer<bytes, 5>) -> u32
public ResultCode SetNextAlpnProto(ServiceCtx context)
{
ulong inputDataPosition = context.Request.SendBuff[0].Position;
ulong inputDataSize = context.Request.SendBuff[0].Size;
_nextAplnProto = new byte[inputDataSize];
context.Memory.Read(inputDataPosition, _nextAplnProto);
Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { inputDataSize });
return ResultCode.Success; return ResultCode.Success;
} }
[CommandHipc(27)]
// GetNextAlpnProto(buffer<bytes, 6>) -> u32
public ResultCode GetNextAlpnProto(ServiceCtx context)
{
ulong outputDataPosition = context.Request.ReceiveBuff[0].Position;
ulong outputDataSize = context.Request.ReceiveBuff[0].Size;
context.Memory.Write(outputDataPosition, _nextAplnProto);
context.ResponseData.Write(_nextAplnProto.Length);
Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { outputDataSize });
return ResultCode.Success;
}
private ResultCode SetOption(OptionType option, bool value)
{
switch (option)
{
case OptionType.DoNotCloseSocket:
_doNotClockSocket = value;
break;
case OptionType.GetServerCertChain:
_getServerCertChain = value;
break;
case OptionType.SkipDefaultVerify:
_skipDefaultVerify = value;
break;
case OptionType.EnableAlpn:
_enableAlpn = value;
break;
default:
Logger.Warning?.Print(LogClass.ServiceSsl, $"Unsupported option {option}");
return ResultCode.InvalidOption;
}
return ResultCode.Success;
}
private ResultCode GetOption(OptionType option, out bool value)
{
switch (option)
{
case OptionType.DoNotCloseSocket:
value = _doNotClockSocket;
break;
case OptionType.GetServerCertChain:
value = _getServerCertChain;
break;
case OptionType.SkipDefaultVerify:
value = _skipDefaultVerify;
break;
case OptionType.EnableAlpn:
value = _enableAlpn;
break;
default:
Logger.Warning?.Print(LogClass.ServiceSsl, $"Unsupported option {option}");
value = false;
return ResultCode.InvalidOption;
}
return ResultCode.Success;
}
public void Dispose()
{
_connection?.Dispose();
}
} }
} }

View file

@ -0,0 +1,25 @@
using Ryujinx.HLE.HOS.Services.Sockets.Bsd;
using System;
using System.Net.Sockets;
namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
{
interface ISslConnectionBase: IDisposable
{
int SocketFd { get; }
ISocket Socket { get; }
ResultCode Handshake(string hostName);
ResultCode GetServerCertificate(string hostname, Span<byte> certificates, out uint storageSize, out uint certificateCount);
ResultCode Write(out int writtenCount, ReadOnlyMemory<byte> buffer);
ResultCode Read(out int readCount, Memory<byte> buffer);
ResultCode Peek(out int peekCount, Memory<byte> buffer);
int Pending();
}
}

View file

@ -1,4 +1,5 @@
using Ryujinx.Common.Logging; using Ryujinx.Common.Logging;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd;
using Ryujinx.HLE.HOS.Services.Ssl.Types; using Ryujinx.HLE.HOS.Services.Ssl.Types;
using System.Text; using System.Text;
@ -8,16 +9,22 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
{ {
private uint _connectionCount; private uint _connectionCount;
private readonly long _processId;
private readonly SslVersion _sslVersion;
private ulong _serverCertificateId; private ulong _serverCertificateId;
private ulong _clientCertificateId; private ulong _clientCertificateId;
public ISslContext(ServiceCtx context) { } public ISslContext(long processId, SslVersion sslVersion)
{
_processId = processId;
_sslVersion = sslVersion;
}
[CommandHipc(2)] [CommandHipc(2)]
// CreateConnection() -> object<nn::ssl::sf::ISslConnection> // CreateConnection() -> object<nn::ssl::sf::ISslConnection>
public ResultCode CreateConnection(ServiceCtx context) public ResultCode CreateConnection(ServiceCtx context)
{ {
MakeObject(context, new ISslConnection()); MakeObject(context, new ISslConnection(_processId, _sslVersion));
_connectionCount++; _connectionCount++;

View file

@ -0,0 +1,247 @@
using Ryujinx.HLE.HOS.Services.Sockets.Bsd;
using Ryujinx.HLE.HOS.Services.Ssl.Types;
using System;
using System.IO;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Authentication;
namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
{
class SslManagedSocketConnection : ISslConnectionBase
{
public int SocketFd { get; }
public ISocket Socket { get; }
private BsdContext _bsdContext;
private SslVersion _sslVersion;
private SslStream _stream;
private bool _isBlockingSocket;
private int _previousReadTimeout;
public SslManagedSocketConnection(BsdContext bsdContext, SslVersion sslVersion, int socketFd, ISocket socket)
{
_bsdContext = bsdContext;
_sslVersion = sslVersion;
SocketFd = socketFd;
Socket = socket;
}
private void StartSslOperation()
{
// Save blocking state
_isBlockingSocket = Socket.Blocking;
// Force blocking for SslStream
Socket.Blocking = true;
}
private void EndSslOperation()
{
// Restore blocking state
Socket.Blocking = _isBlockingSocket;
}
private void StartSslReadOperation()
{
StartSslOperation();
if (!_isBlockingSocket)
{
_previousReadTimeout = _stream.ReadTimeout;
_stream.ReadTimeout = 1;
}
}
private void EndSslReadOperation()
{
if (!_isBlockingSocket)
{
_stream.ReadTimeout = _previousReadTimeout;
}
EndSslOperation();
}
private static SslProtocols TranslateSslVersion(SslVersion version)
{
switch (version & SslVersion.VersionMask)
{
case SslVersion.Auto:
return SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12 | SslProtocols.Tls13;
case SslVersion.TlsV10:
return SslProtocols.Tls;
case SslVersion.TlsV11:
return SslProtocols.Tls11;
case SslVersion.TlsV12:
return SslProtocols.Tls12;
case SslVersion.TlsV13:
return SslProtocols.Tls13;
default:
throw new NotImplementedException(version.ToString());
}
}
public ResultCode Handshake(string hostName)
{
StartSslOperation();
_stream = new SslStream(new NetworkStream(((ManagedSocket)Socket).Socket, false), false, null, null);
_stream.AuthenticateAsClient(hostName, null, TranslateSslVersion(_sslVersion), false);
EndSslOperation();
return ResultCode.Success;
}
public ResultCode Peek(out int peekCount, Memory<byte> buffer)
{
// NOTE: We cannot support that on .NET SSL API.
// As Nintendo's curl implementation detail check if a connection is alive via Peek, we just return that it would block to let it know that it's alive.
peekCount = -1;
return ResultCode.WouldBlock;
}
public int Pending()
{
// Unsupported
return 0;
}
private static bool TryTranslateWinSockError(bool isBlocking, WsaError error, out ResultCode resultCode)
{
switch (error)
{
case WsaError.WSAETIMEDOUT:
resultCode = isBlocking ? ResultCode.Timeout : ResultCode.WouldBlock;
return true;
case WsaError.WSAECONNABORTED:
resultCode = ResultCode.ConnectionAbort;
return true;
case WsaError.WSAECONNRESET:
resultCode = ResultCode.ConnectionReset;
return true;
default:
resultCode = ResultCode.Success;
return false;
}
}
public ResultCode Read(out int readCount, Memory<byte> buffer)
{
if (!Socket.Poll(0, SelectMode.SelectRead))
{
readCount = -1;
return ResultCode.WouldBlock;
}
StartSslReadOperation();
try
{
readCount = _stream.Read(buffer.Span);
}
catch (IOException exception)
{
readCount = -1;
if (exception.InnerException is SocketException socketException)
{
WsaError socketErrorCode = (WsaError)socketException.SocketErrorCode;
if (TryTranslateWinSockError(_isBlockingSocket, socketErrorCode, out ResultCode result))
{
return result;
}
else
{
throw socketException;
}
}
else
{
throw exception;
}
}
finally
{
EndSslReadOperation();
}
return ResultCode.Success;
}
public ResultCode Write(out int writtenCount, ReadOnlyMemory<byte> buffer)
{
if (!Socket.Poll(0, SelectMode.SelectWrite))
{
writtenCount = 0;
return ResultCode.WouldBlock;
}
StartSslOperation();
try
{
_stream.Write(buffer.Span);
}
catch (IOException exception)
{
writtenCount = -1;
if (exception.InnerException is SocketException socketException)
{
WsaError socketErrorCode = (WsaError)socketException.SocketErrorCode;
if (TryTranslateWinSockError(_isBlockingSocket, socketErrorCode, out ResultCode result))
{
return result;
}
else
{
throw socketException;
}
}
else
{
throw exception;
}
}
finally
{
EndSslOperation();
}
// .NET API doesn't provide the size written, assume all written.
writtenCount = buffer.Length;
return ResultCode.Success;
}
public ResultCode GetServerCertificate(string hostname, Span<byte> certificates, out uint storageSize, out uint certificateCount)
{
byte[] rawCertData = _stream.RemoteCertificate.GetRawCertData();
storageSize = (uint)rawCertData.Length;
certificateCount = 1;
if (rawCertData.Length > certificates.Length)
{
return ResultCode.CertBufferTooSmall;
}
rawCertData.CopyTo(certificates);
return ResultCode.Success;
}
public void Dispose()
{
_bsdContext.CloseFileDescriptor(SocketFd);
}
}
}

View file

@ -0,0 +1,10 @@
namespace Ryujinx.HLE.HOS.Services.Ssl.Types
{
struct BuiltInCertificateInfo
{
public CaCertificateId Id;
public TrustedCertStatus Status;
public ulong CertificateDataSize;
public ulong CertificateDataOffset;
}
}

View file

@ -0,0 +1,68 @@
namespace Ryujinx.HLE.HOS.Services.Ssl.Types
{
enum CaCertificateId : uint
{
// Nintendo CAs
NintendoCAG3 = 1,
NintendoClass2CAG3,
// External CAs
AmazonRootCA1 = 1000,
StarfieldServicesRootCertificateAuthorityG2,
AddTrustExternalCARoot,
COMODOCertificationAuthority,
UTNDATACorpSGC,
UTNUSERFirstHardware,
BaltimoreCyberTrustRoot,
CybertrustGlobalRoot,
VerizonGlobalRootCA,
DigiCertAssuredIDRootCA,
DigiCertAssuredIDRootG2,
DigiCertGlobalRootCA,
DigiCertGlobalRootG2,
DigiCertHighAssuranceEVRootCA,
EntrustnetCertificationAuthority2048,
EntrustRootCertificationAuthority,
EntrustRootCertificationAuthorityG2,
GeoTrustGlobalCA2,
GeoTrustGlobalCA,
GeoTrustPrimaryCertificationAuthorityG3,
GeoTrustPrimaryCertificationAuthority,
GlobalSignRootCA,
GlobalSignRootCAR2,
GlobalSignRootCAR3,
GoDaddyClass2CertificationAuthority,
GoDaddyRootCertificateAuthorityG2,
StarfieldClass2CertificationAuthority,
StarfieldRootCertificateAuthorityG2,
ThawtePrimaryRootCAG3,
ThawtePrimaryRootCA,
VeriSignClass3PublicPrimaryCertificationAuthorityG3,
VeriSignClass3PublicPrimaryCertificationAuthorityG5,
VeriSignUniversalRootCertificationAuthority,
DSTRootCAX3,
USERTrustRSACertificationAuthority,
ISRGRootX10,
USERTrustECCCertificationAuthority,
COMODORSACertificationAuthority,
COMODOECCCertificationAuthority,
AmazonRootCA2,
AmazonRootCA3,
AmazonRootCA4,
DigiCertAssuredIDRootG3,
DigiCertGlobalRootG3,
DigiCertTrustedRootG4,
EntrustRootCertificationAuthorityEC1,
EntrustRootCertificationAuthorityG4,
GlobalSignECCRootCAR4,
GlobalSignECCRootCAR5,
GlobalSignECCRootCAR6,
GTSRootR1,
GTSRootR2,
GTSRootR3,
GTSRootR4,
SecurityCommunicationRootCA,
All = uint.MaxValue
}
}

View file

@ -10,6 +10,7 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.Types
TlsV11 = 1 << 4, TlsV11 = 1 << 4,
TlsV12 = 1 << 5, TlsV12 = 1 << 5,
TlsV13 = 1 << 6, // 11.0.0+ TlsV13 = 1 << 6, // 11.0.0+
Auto2 = 1 << 24 // 11.0.0+
VersionMask = 0xFFFFFF
} }
} }

View file

@ -0,0 +1,12 @@
namespace Ryujinx.HLE.HOS.Services.Ssl.Types
{
enum TrustedCertStatus : uint
{
Removed,
EnabledTrusted,
EnabledNotTrusted,
Revoked,
Invalid = uint.MaxValue
}
}