rjx-mirror/Ryujinx.HLE/HOS/Services/Ssl/SslService/SslManagedSocketConnection.cs
Mary 3fa7ef21b4
ssl: Implement SSL connectivity (#2961)
* implement certain servicessl functions

* ssl: Implement more of SSL connection and abstract it

This adds support for non-blocking SSL operations and unlinks the SSL
implementation from the IPC logic.

* Rename SslDefaultSocketConnection to SslManagedSocketConnection

* Fix regression on Pokemon TV

* Address gdkchan's comment

* Simplify value read from previous commit

* ssl: some changes

- Implement builtin certificates parsing and retrieving
- Fix issues with SSL version handling
- Improve managed SSL socket error handling
- Ensure to only return a certificate on DoHandshake when actually requested

* Add missing BuiltInCertificateManager initialization call

* Address gdkchan's comment

* Address Ack's comment

Co-authored-by: InvoxiPlayGames <webmaster@invoxiplaygames.uk>
2022-01-13 23:29:04 +01:00

247 lines
7.2 KiB
C#

using Ryujinx.HLE.HOS.Services.Sockets.Bsd;
using Ryujinx.HLE.HOS.Services.Ssl.Types;
using System;
using System.IO;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Authentication;
namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
{
/// <summary>
/// <see cref="ISslConnectionBase"/> implementation that layers a .NET <see cref="SslStream"/>
/// on top of a <see cref="ManagedSocket"/>.
/// </summary>
class SslManagedSocketConnection : ISslConnectionBase
{
    /// <summary>File descriptor of the underlying socket; closed through the BSD context on <see cref="Dispose"/>.</summary>
    public int SocketFd { get; }

    /// <summary>Socket the SSL stream is built on.</summary>
    public ISocket Socket { get; }

    private readonly BsdContext _bsdContext;
    private readonly SslVersion _sslVersion;

    // Created by Handshake; null until then.
    private SslStream _stream;

    // Blocking state of the socket captured at the start of the current SSL operation.
    private bool _isBlockingSocket;

    // Read timeout of the stream captured before it is shortened for a non-blocking read.
    private int _previousReadTimeout;

    /// <summary>
    /// Creates a connection wrapper; no TLS traffic happens until <see cref="Handshake"/> is called.
    /// </summary>
    /// <param name="bsdContext">Context used to close <paramref name="socketFd"/> on dispose.</param>
    /// <param name="sslVersion">Requested SSL/TLS version; translated to <see cref="SslProtocols"/> at handshake time.</param>
    /// <param name="socketFd">File descriptor of <paramref name="socket"/>.</param>
    /// <param name="socket">Socket to run the SSL stream over.</param>
    public SslManagedSocketConnection(BsdContext bsdContext, SslVersion sslVersion, int socketFd, ISocket socket)
    {
        _bsdContext = bsdContext;
        _sslVersion = sslVersion;

        SocketFd = socketFd;
        Socket = socket;
    }

    /// <summary>
    /// Remembers the socket's blocking state and switches it to blocking for the duration of an SSL operation.
    /// </summary>
    private void StartSslOperation()
    {
        // Save blocking state
        _isBlockingSocket = Socket.Blocking;

        // Force blocking for SslStream
        Socket.Blocking = true;
    }

    /// <summary>
    /// Restores the blocking state saved by <see cref="StartSslOperation"/>.
    /// </summary>
    private void EndSslOperation()
    {
        // Restore blocking state
        Socket.Blocking = _isBlockingSocket;
    }

    /// <summary>
    /// Starts an SSL operation and, when the socket was non-blocking, shortens the stream read timeout
    /// so that the forced-blocking read returns almost immediately.
    /// </summary>
    private void StartSslReadOperation()
    {
        StartSslOperation();

        if (!_isBlockingSocket)
        {
            _previousReadTimeout = _stream.ReadTimeout;

            // Smallest positive timeout (in milliseconds) to emulate a non-blocking read.
            _stream.ReadTimeout = 1;
        }
    }

    /// <summary>
    /// Undoes <see cref="StartSslReadOperation"/>: restores the read timeout (if it was changed) and the blocking state.
    /// </summary>
    private void EndSslReadOperation()
    {
        if (!_isBlockingSocket)
        {
            _stream.ReadTimeout = _previousReadTimeout;
        }

        EndSslOperation();
    }

    /// <summary>
    /// Maps the version bits of <paramref name="version"/> to the matching <see cref="SslProtocols"/> flags.
    /// </summary>
    /// <param name="version">Requested version; only the bits selected by <c>SslVersion.VersionMask</c> are considered.</param>
    /// <returns>The protocol set to pass to <see cref="SslStream"/>.</returns>
    /// <exception cref="NotImplementedException">The masked version is not one of the handled values.</exception>
    private static SslProtocols TranslateSslVersion(SslVersion version)
    {
        switch (version & SslVersion.VersionMask)
        {
            case SslVersion.Auto:
                return SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12 | SslProtocols.Tls13;
            case SslVersion.TlsV10:
                return SslProtocols.Tls;
            case SslVersion.TlsV11:
                return SslProtocols.Tls11;
            case SslVersion.TlsV12:
                return SslProtocols.Tls12;
            case SslVersion.TlsV13:
                return SslProtocols.Tls13;
            default:
                throw new NotImplementedException(version.ToString());
        }
    }

    /// <summary>
    /// Creates the SSL stream over the socket and performs the client-side handshake.
    /// </summary>
    /// <param name="hostName">Target host name passed to <see cref="SslStream.AuthenticateAsClient(string, System.Security.Cryptography.X509Certificates.X509CertificateCollection, SslProtocols, bool)"/>.</param>
    /// <returns><c>ResultCode.Success</c> when the handshake completes; failures surface as exceptions from <see cref="SslStream"/>.</returns>
    public ResultCode Handshake(string hostName)
    {
        StartSslOperation();

        try
        {
            // The NetworkStream is created with ownsSocket: false, so disposing the stream leaves the socket open.
            _stream = new SslStream(new NetworkStream(((ManagedSocket)Socket).Socket, false), false, null, null);
            _stream.AuthenticateAsClient(hostName, null, TranslateSslVersion(_sslVersion), false);
        }
        finally
        {
            // Always restore the caller's blocking mode, even when authentication throws.
            EndSslOperation();
        }

        return ResultCode.Success;
    }

    /// <summary>
    /// Peeking is not supported by this implementation.
    /// </summary>
    /// <param name="peekCount">Always set to -1.</param>
    /// <param name="buffer">Unused.</param>
    /// <returns>Always <c>ResultCode.WouldBlock</c>.</returns>
    public ResultCode Peek(out int peekCount, Memory<byte> buffer)
    {
        // NOTE: Peeking cannot be supported with the .NET SSL API.
        //       Nintendo's curl implementation checks whether a connection is alive via Peek,
        //       so we report that the call would block to let it know the connection is still alive.
        peekCount = -1;

        return ResultCode.WouldBlock;
    }

    /// <summary>
    /// Number of pending bytes; not supported by this implementation.
    /// </summary>
    /// <returns>Always 0.</returns>
    public int Pending()
    {
        // Unsupported
        return 0;
    }

    /// <summary>
    /// Attempts to map a WinSock error to an SSL service result code.
    /// </summary>
    /// <param name="isBlocking">Whether the socket was blocking; decides how a timeout is reported.</param>
    /// <param name="error">Socket error to translate.</param>
    /// <param name="resultCode">Translated code, or <c>ResultCode.Success</c> when no mapping exists.</param>
    /// <returns><c>true</c> when <paramref name="error"/> was translated; otherwise <c>false</c>.</returns>
    private static bool TryTranslateWinSockError(bool isBlocking, WsaError error, out ResultCode resultCode)
    {
        switch (error)
        {
            case WsaError.WSAETIMEDOUT:
                // A timeout on a non-blocking socket comes from the shortened read timeout: report it as "would block".
                resultCode = isBlocking ? ResultCode.Timeout : ResultCode.WouldBlock;
                return true;
            case WsaError.WSAECONNABORTED:
                resultCode = ResultCode.ConnectionAbort;
                return true;
            case WsaError.WSAECONNRESET:
                resultCode = ResultCode.ConnectionReset;
                return true;
            default:
                resultCode = ResultCode.Success;
                return false;
        }
    }

    /// <summary>
    /// Reads decrypted data from the SSL stream into <paramref name="buffer"/>.
    /// </summary>
    /// <param name="readCount">Number of bytes read, or -1 when nothing was read (would block or error).</param>
    /// <param name="buffer">Destination buffer.</param>
    /// <returns>
    /// <c>ResultCode.Success</c> on a completed read, <c>ResultCode.WouldBlock</c> when the socket has no readable data,
    /// or the translated socket error.
    /// </returns>
    /// <exception cref="SocketException">The read failed with a socket error that has no translation.</exception>
    /// <exception cref="IOException">The read failed for a reason other than a socket error.</exception>
    public ResultCode Read(out int readCount, Memory<byte> buffer)
    {
        if (!Socket.Poll(0, SelectMode.SelectRead))
        {
            readCount = -1;

            return ResultCode.WouldBlock;
        }

        StartSslReadOperation();

        try
        {
            readCount = _stream.Read(buffer.Span);
        }
        catch (IOException exception)
        {
            readCount = -1;

            if (exception.InnerException is SocketException socketException)
            {
                WsaError socketErrorCode = (WsaError)socketException.SocketErrorCode;

                if (TryTranslateWinSockError(_isBlockingSocket, socketErrorCode, out ResultCode result))
                {
                    return result;
                }

                // Surface the underlying socket error rather than the IOException wrapper.
                throw socketException;
            }

            // Rethrow without resetting the original stack trace.
            throw;
        }
        finally
        {
            EndSslReadOperation();
        }

        return ResultCode.Success;
    }

    /// <summary>
    /// Writes <paramref name="buffer"/> to the SSL stream.
    /// </summary>
    /// <param name="writtenCount">
    /// <paramref name="buffer"/>'s length on success, 0 when the socket is not writable, or -1 on error.
    /// </param>
    /// <param name="buffer">Data to send.</param>
    /// <returns>
    /// <c>ResultCode.Success</c> on a completed write, <c>ResultCode.WouldBlock</c> when the socket is not writable,
    /// or the translated socket error.
    /// </returns>
    /// <exception cref="SocketException">The write failed with a socket error that has no translation.</exception>
    /// <exception cref="IOException">The write failed for a reason other than a socket error.</exception>
    public ResultCode Write(out int writtenCount, ReadOnlyMemory<byte> buffer)
    {
        if (!Socket.Poll(0, SelectMode.SelectWrite))
        {
            writtenCount = 0;

            return ResultCode.WouldBlock;
        }

        StartSslOperation();

        try
        {
            _stream.Write(buffer.Span);
        }
        catch (IOException exception)
        {
            writtenCount = -1;

            if (exception.InnerException is SocketException socketException)
            {
                WsaError socketErrorCode = (WsaError)socketException.SocketErrorCode;

                if (TryTranslateWinSockError(_isBlockingSocket, socketErrorCode, out ResultCode result))
                {
                    return result;
                }

                // Surface the underlying socket error rather than the IOException wrapper.
                throw socketException;
            }

            // Rethrow without resetting the original stack trace.
            throw;
        }
        finally
        {
            EndSslOperation();
        }

        // .NET API doesn't provide the size written, assume all written.
        writtenCount = buffer.Length;

        return ResultCode.Success;
    }

    /// <summary>
    /// Copies the raw data of the remote certificate into <paramref name="certificates"/>.
    /// </summary>
    /// <param name="hostname">Unused by this implementation.</param>
    /// <param name="certificates">Destination for the raw certificate bytes.</param>
    /// <param name="storageSize">Size in bytes of the raw certificate data; set even when the buffer is too small.</param>
    /// <param name="certificateCount">Always 1.</param>
    /// <returns>
    /// <c>ResultCode.Success</c> when the data was copied, or <c>ResultCode.CertBufferTooSmall</c> when
    /// <paramref name="certificates"/> cannot hold it.
    /// </returns>
    public ResultCode GetServerCertificate(string hostname, Span<byte> certificates, out uint storageSize, out uint certificateCount)
    {
        // NOTE(review): this dereferences _stream and its RemoteCertificate without a null check,
        //               so it presumably must only be called after a successful Handshake - confirm with callers.
        byte[] rawCertData = _stream.RemoteCertificate.GetRawCertData();

        storageSize = (uint)rawCertData.Length;
        certificateCount = 1;

        if (rawCertData.Length > certificates.Length)
        {
            return ResultCode.CertBufferTooSmall;
        }

        rawCertData.CopyTo(certificates);

        return ResultCode.Success;
    }

    /// <summary>
    /// Releases the SSL stream (if a handshake created one) and closes the socket file descriptor.
    /// </summary>
    public void Dispose()
    {
        // Release the TLS state first; the inner NetworkStream does not own the socket, so this does not close it.
        _stream?.Dispose();

        _bsdContext.CloseFileDescriptor(SocketFd);
    }
}
}