R/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslConnection.cs
gdkchan 60e16c15b6
Fix memory corruption in BCAT and FS Read methods when buffer is larger than needed (#3739)
* Fix memory corruption in FS Read methods when buffer is larger than needed

* PR feedback

* nit: Don't move this around
2022-10-04 20:12:54 -03:00

515 lines
No EOL
16 KiB
C#

using Ryujinx.Common.Logging;
using Ryujinx.HLE.Exceptions;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd;
using Ryujinx.HLE.HOS.Services.Ssl.Types;
using Ryujinx.Memory;
using System;
using System.Text;
namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
{
    /// <summary>
    /// HLE implementation of the nn::ssl::sf::ISslConnection service object.
    /// Holds the per-connection SSL options set by the guest and forwards
    /// handshake / read / write / peek requests to an <see cref="ISslConnectionBase"/>
    /// created once a socket descriptor has been attached.
    /// </summary>
    class ISslConnection : IpcService, IDisposable
    {
        // Value of OptionType.DoNotCloseSocket; when set, SetSocketDescriptor reports -1 back to the guest.
        private bool _doNotCloseSocket;
        private bool _getServerCertChain;
        private bool _skipDefaultVerify;
        private bool _enableAlpn;

        private readonly SslVersion _sslVersion;
        private IoMode _ioMode;
        private VerifyOption _verifyOption;
        private SessionCacheMode _sessionCacheMode;
        private string _hostName;

        // Null until SetSocketDescriptor succeeds; most commands return NoSocket while it is null.
        private ISslConnectionBase _connection;
        private BsdContext _bsdContext;
        private readonly ulong _processId;

        // Raw bytes last passed to SetNextAlpnProto; null until that command is called.
        private byte[] _nextAlpnProto;

        /// <summary>
        /// Creates a connection object with blocking I/O, no session cache and
        /// peer-CA + host-name verification enabled.
        /// </summary>
        /// <param name="processId">Process whose BSD socket table is used to resolve descriptors.</param>
        /// <param name="sslVersion">SSL/TLS version requested by the owning context.</param>
        public ISslConnection(ulong processId, SslVersion sslVersion)
        {
            _processId = processId;
            _sslVersion = sslVersion;
            _ioMode = IoMode.Blocking;
            _sessionCacheMode = SessionCacheMode.None;
            _verifyOption = VerifyOption.PeerCa | VerifyOption.HostName;
        }

        [CommandHipc(0)]
        // SetSocketDescriptor(u32) -> u32
        public ResultCode SetSocketDescriptor(ServiceCtx context)
        {
            if (_connection != null)
            {
                return ResultCode.AlreadyInUse;
            }

            _bsdContext = BsdContext.GetContext(_processId);

            if (_bsdContext == null)
            {
                return ResultCode.InvalidSocket;
            }

            int inputFd = context.RequestData.ReadInt32();

            // The SSL layer works on its own duplicate of the guest descriptor.
            int internalFd = _bsdContext.DuplicateFileDescriptor(inputFd);

            if (internalFd == -1)
            {
                return ResultCode.InvalidSocket;
            }

            InitializeConnection(internalFd);

            int outputFd = inputFd;

            if (_doNotCloseSocket)
            {
                outputFd = -1;
            }

            context.ResponseData.Write(outputFd);

            return ResultCode.Success;
        }

        /// <summary>
        /// Wraps the BSD socket behind <paramref name="socketFd"/> in a managed SSL connection.
        /// </summary>
        private void InitializeConnection(int socketFd)
        {
            ISocket bsdSocket = _bsdContext.RetrieveSocket(socketFd);

            _connection = new SslManagedSocketConnection(_bsdContext, _sslVersion, socketFd, bsdSocket);
        }

        [CommandHipc(1)]
        // SetHostName(buffer<bytes, 5>)
        public ResultCode SetHostName(ServiceCtx context)
        {
            ulong hostNameDataPosition = context.Request.SendBuff[0].Position;
            ulong hostNameDataSize = context.Request.SendBuff[0].Size;

            byte[] hostNameData = new byte[hostNameDataSize];

            context.Memory.Read(hostNameDataPosition, hostNameData);

            // The guest passes a C string: only the bytes before the first NUL form the name.
            // Trim('\0') only stripped NULs at the ends and kept any bytes after the terminator.
            int terminatorIndex = Array.IndexOf(hostNameData, (byte)0);
            int hostNameLength = terminatorIndex >= 0 ? terminatorIndex : hostNameData.Length;

            _hostName = Encoding.ASCII.GetString(hostNameData, 0, hostNameLength);

            Logger.Info?.Print(LogClass.ServiceSsl, _hostName);

            return ResultCode.Success;
        }

        [CommandHipc(2)]
        // SetVerifyOption(nn::ssl::sf::VerifyOption)
        public ResultCode SetVerifyOption(ServiceCtx context)
        {
            _verifyOption = (VerifyOption)context.RequestData.ReadUInt32();

            Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { _verifyOption });

            return ResultCode.Success;
        }

        [CommandHipc(3)]
        // SetIoMode(nn::ssl::sf::IoMode)
        public ResultCode SetIoMode(ServiceCtx context)
        {
            if (_connection == null)
            {
                return ResultCode.NoSocket;
            }

            _ioMode = (IoMode)context.RequestData.ReadUInt32();

            _connection.Socket.Blocking = _ioMode == IoMode.Blocking;

            Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { _ioMode });

            return ResultCode.Success;
        }

        [CommandHipc(4)]
        // GetSocketDescriptor() -> u32
        public ResultCode GetSocketDescriptor(ServiceCtx context)
        {
            // Match the other connection-dependent commands instead of throwing a NullReferenceException.
            if (_connection == null)
            {
                return ResultCode.NoSocket;
            }

            context.ResponseData.Write(_connection.SocketFd);

            return ResultCode.Success;
        }

        [CommandHipc(5)]
        // GetHostName(buffer<bytes, 6>) -> u32
        public ResultCode GetHostName(ServiceCtx context)
        {
            ulong bufferAddress = context.Request.ReceiveBuff[0].Position;
            ulong bufferLen = context.Request.ReceiveBuff[0].Size;

            // SetHostName may never have been called; report an empty name instead of throwing.
            string hostName = _hostName ?? string.Empty;

            // ASCII encodes exactly one byte per char, so clamping the char count clamps the
            // byte count and the write can never exceed the guest buffer (GetBytes would throw).
            int bytesToWrite = (int)Math.Min((ulong)hostName.Length, bufferLen);

            using (var region = context.Memory.GetWritableRegion(bufferAddress, (int)bufferLen, true))
            {
                Encoding.ASCII.GetBytes(hostName.AsSpan(0, bytesToWrite), region.Memory.Span);
            }

            // Report the number of bytes actually written to the guest buffer.
            context.ResponseData.Write((uint)bytesToWrite);

            Logger.Info?.Print(LogClass.ServiceSsl, hostName);

            return ResultCode.Success;
        }

        [CommandHipc(6)]
        // GetVerifyOption() -> nn::ssl::sf::VerifyOption
        public ResultCode GetVerifyOption(ServiceCtx context)
        {
            context.ResponseData.Write((uint)_verifyOption);

            Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { _verifyOption });

            return ResultCode.Success;
        }

        [CommandHipc(7)]
        // GetIoMode() -> nn::ssl::sf::IoMode
        public ResultCode GetIoMode(ServiceCtx context)
        {
            context.ResponseData.Write((uint)_ioMode);

            Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { _ioMode });

            return ResultCode.Success;
        }

        [CommandHipc(8)]
        // DoHandshake()
        public ResultCode DoHandshake(ServiceCtx context)
        {
            if (_connection == null)
            {
                return ResultCode.NoSocket;
            }

            return _connection.Handshake(_hostName);
        }

        [CommandHipc(9)]
        // DoHandshakeGetServerCert() -> (u32, u32, buffer<bytes, 6>)
        public ResultCode DoHandshakeGetServerCert(ServiceCtx context)
        {
            if (_connection == null)
            {
                return ResultCode.NoSocket;
            }

            ResultCode result = _connection.Handshake(_hostName);

            if (result == ResultCode.Success)
            {
                if (_getServerCertChain)
                {
                    using (WritableRegion region = context.Memory.GetWritableRegion(context.Request.ReceiveBuff[0].Position, (int)context.Request.ReceiveBuff[0].Size))
                    {
                        result = _connection.GetServerCertificate(_hostName, region.Memory.Span, out uint bufferSize, out uint certificateCount);

                        context.ResponseData.Write(bufferSize);
                        context.ResponseData.Write(certificateCount);
                    }
                }
                else
                {
                    // Certificate chain not requested: report zero size and zero certificates.
                    context.ResponseData.Write(0);
                    context.ResponseData.Write(0);
                }
            }

            return result;
        }

        [CommandHipc(10)]
        // Read() -> (u32, buffer<bytes, 6>)
        public ResultCode Read(ServiceCtx context)
        {
            if (_connection == null)
            {
                return ResultCode.NoSocket;
            }

            ResultCode result;

            using (WritableRegion region = context.Memory.GetWritableRegion(context.Request.ReceiveBuff[0].Position, (int)context.Request.ReceiveBuff[0].Size))
            {
                // TODO: Better error management.
                result = _connection.Read(out int readCount, region.Memory);

                if (result == ResultCode.Success)
                {
                    context.ResponseData.Write(readCount);
                }
            }

            return result;
        }

        [CommandHipc(11)]
        // Write(buffer<bytes, 5>) -> s32
        public ResultCode Write(ServiceCtx context)
        {
            if (_connection == null)
            {
                return ResultCode.NoSocket;
            }

            // We don't dispose as this isn't supposed to be modified
            WritableRegion region = context.Memory.GetWritableRegion(context.Request.SendBuff[0].Position, (int)context.Request.SendBuff[0].Size);

            // TODO: Better error management.
            ResultCode result = _connection.Write(out int writtenCount, region.Memory);

            if (result == ResultCode.Success)
            {
                context.ResponseData.Write(writtenCount);
            }

            return result;
        }

        [CommandHipc(12)]
        // Pending() -> s32
        public ResultCode Pending(ServiceCtx context)
        {
            if (_connection == null)
            {
                return ResultCode.NoSocket;
            }

            context.ResponseData.Write(_connection.Pending());

            return ResultCode.Success;
        }

        [CommandHipc(13)]
        // Peek() -> (s32, buffer<bytes, 6>)
        public ResultCode Peek(ServiceCtx context)
        {
            if (_connection == null)
            {
                return ResultCode.NoSocket;
            }

            ResultCode result;

            using (WritableRegion region = context.Memory.GetWritableRegion(context.Request.ReceiveBuff[0].Position, (int)context.Request.ReceiveBuff[0].Size))
            {
                // TODO: Better error management.
                result = _connection.Peek(out int peekCount, region.Memory);

                if (result == ResultCode.Success)
                {
                    context.ResponseData.Write(peekCount);
                }
            }

            return result;
        }

        [CommandHipc(14)]
        // Poll(nn::ssl::sf::PollEvent poll_event, u32 timeout) -> nn::ssl::sf::PollEvent
        public ResultCode Poll(ServiceCtx context)
        {
            throw new ServiceNotImplementedException(this, context);
        }

        [CommandHipc(15)]
        // GetVerifyCertError()
        public ResultCode GetVerifyCertError(ServiceCtx context)
        {
            throw new ServiceNotImplementedException(this, context);
        }

        [CommandHipc(16)]
        // GetNeededServerCertBufferSize() -> u32
        public ResultCode GetNeededServerCertBufferSize(ServiceCtx context)
        {
            throw new ServiceNotImplementedException(this, context);
        }

        [CommandHipc(17)]
        // SetSessionCacheMode(nn::ssl::sf::SessionCacheMode)
        public ResultCode SetSessionCacheMode(ServiceCtx context)
        {
            SessionCacheMode sessionCacheMode = (SessionCacheMode)context.RequestData.ReadUInt32();

            Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { sessionCacheMode });

            _sessionCacheMode = sessionCacheMode;

            return ResultCode.Success;
        }

        [CommandHipc(18)]
        // GetSessionCacheMode() -> nn::ssl::sf::SessionCacheMode
        public ResultCode GetSessionCacheMode(ServiceCtx context)
        {
            throw new ServiceNotImplementedException(this, context);
        }

        [CommandHipc(19)]
        // FlushSessionCache()
        public ResultCode FlushSessionCache(ServiceCtx context)
        {
            throw new ServiceNotImplementedException(this, context);
        }

        [CommandHipc(20)]
        // SetRenegotiationMode(nn::ssl::sf::RenegotiationMode)
        public ResultCode SetRenegotiationMode(ServiceCtx context)
        {
            throw new ServiceNotImplementedException(this, context);
        }

        [CommandHipc(21)]
        // GetRenegotiationMode() -> nn::ssl::sf::RenegotiationMode
        public ResultCode GetRenegotiationMode(ServiceCtx context)
        {
            throw new ServiceNotImplementedException(this, context);
        }

        [CommandHipc(22)]
        // SetOption(b8 value, nn::ssl::sf::OptionType option)
        public ResultCode SetOption(ServiceCtx context)
        {
            bool value = context.RequestData.ReadUInt32() != 0;
            OptionType option = (OptionType)context.RequestData.ReadUInt32();

            Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { option, value });

            return SetOption(option, value);
        }

        [CommandHipc(23)]
        // GetOption(nn::ssl::sf::OptionType) -> b8
        public ResultCode GetOption(ServiceCtx context)
        {
            OptionType option = (OptionType)context.RequestData.ReadUInt32();

            Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { option });

            ResultCode result = GetOption(option, out bool value);

            if (result == ResultCode.Success)
            {
                context.ResponseData.Write(value);
            }

            return result;
        }

        [CommandHipc(24)]
        // GetVerifyCertErrors() -> (u32, u32, buffer<bytes, 6>)
        public ResultCode GetVerifyCertErrors(ServiceCtx context)
        {
            throw new ServiceNotImplementedException(this, context);
        }

        [CommandHipc(25)] // 4.0.0+
        // GetCipherInfo(u32) -> buffer<bytes, 6>
        public ResultCode GetCipherInfo(ServiceCtx context)
        {
            throw new ServiceNotImplementedException(this, context);
        }

        [CommandHipc(26)]
        // SetNextAlpnProto(buffer<bytes, 5>) -> u32
        public ResultCode SetNextAlpnProto(ServiceCtx context)
        {
            ulong inputDataPosition = context.Request.SendBuff[0].Position;
            ulong inputDataSize = context.Request.SendBuff[0].Size;

            _nextAlpnProto = new byte[inputDataSize];

            context.Memory.Read(inputDataPosition, _nextAlpnProto);

            Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { inputDataSize });

            return ResultCode.Success;
        }

        [CommandHipc(27)]
        // GetNextAlpnProto(buffer<bytes, 6>) -> u32
        public ResultCode GetNextAlpnProto(ServiceCtx context)
        {
            ulong outputDataPosition = context.Request.ReceiveBuff[0].Position;
            ulong outputDataSize = context.Request.ReceiveBuff[0].Size;

            // SetNextAlpnProto may never have been called; treat that as an empty protocol list.
            byte[] alpnProto = _nextAlpnProto ?? Array.Empty<byte>();

            // Never write past the end of the guest receive buffer; doing so corrupts guest memory.
            int bytesToWrite = (int)Math.Min((ulong)alpnProto.Length, outputDataSize);

            if (bytesToWrite > 0)
            {
                context.Memory.Write(outputDataPosition, new ReadOnlySpan<byte>(alpnProto, 0, bytesToWrite));
            }

            // Report the number of bytes actually written to the guest buffer.
            context.ResponseData.Write(bytesToWrite);

            Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { outputDataSize });

            return ResultCode.Success;
        }

        /// <summary>
        /// Stores a boolean connection option.
        /// </summary>
        /// <returns><see cref="ResultCode.InvalidOption"/> for options this implementation does not track.</returns>
        private ResultCode SetOption(OptionType option, bool value)
        {
            switch (option)
            {
                case OptionType.DoNotCloseSocket:
                    _doNotCloseSocket = value;
                    break;

                case OptionType.GetServerCertChain:
                    _getServerCertChain = value;
                    break;

                case OptionType.SkipDefaultVerify:
                    _skipDefaultVerify = value;
                    break;

                case OptionType.EnableAlpn:
                    _enableAlpn = value;
                    break;

                default:
                    Logger.Warning?.Print(LogClass.ServiceSsl, $"Unsupported option {option}");
                    return ResultCode.InvalidOption;
            }

            return ResultCode.Success;
        }

        /// <summary>
        /// Reads back a boolean connection option.
        /// </summary>
        /// <param name="value">The stored value, or <c>false</c> for an unsupported option.</param>
        /// <returns><see cref="ResultCode.InvalidOption"/> for options this implementation does not track.</returns>
        private ResultCode GetOption(OptionType option, out bool value)
        {
            switch (option)
            {
                case OptionType.DoNotCloseSocket:
                    value = _doNotCloseSocket;
                    break;

                case OptionType.GetServerCertChain:
                    value = _getServerCertChain;
                    break;

                case OptionType.SkipDefaultVerify:
                    value = _skipDefaultVerify;
                    break;

                case OptionType.EnableAlpn:
                    value = _enableAlpn;
                    break;

                default:
                    Logger.Warning?.Print(LogClass.ServiceSsl, $"Unsupported option {option}");
                    value = false;
                    return ResultCode.InvalidOption;
            }

            return ResultCode.Success;
        }

        /// <summary>
        /// Disposes the underlying SSL connection, if one was created.
        /// </summary>
        public void Dispose()
        {
            _connection?.Dispose();
        }
    }
}