Fix memory corruption in BCAT and FS Read methods when buffer is larger than needed (#3739)

* Fix memory corruption in FS Read methods when buffer is larger than needed

* PR feedback

* nit: Don't move this around
This commit is contained in:
gdkchan 2022-10-04 20:12:54 -03:00 committed by GitHub
parent 2068445939
commit 60e16c15b6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 69 additions and 79 deletions

View file

@ -38,19 +38,18 @@ namespace Ryujinx.HLE.HOS.Services.Bcat.ServiceCreator
// Read() -> (u32, buffer<nn::bcat::DeliveryCacheDirectoryEntry, 6>) // Read() -> (u32, buffer<nn::bcat::DeliveryCacheDirectoryEntry, 6>)
public ResultCode Read(ServiceCtx context) public ResultCode Read(ServiceCtx context)
{ {
ulong position = context.Request.ReceiveBuff[0].Position; ulong bufferAddress = context.Request.ReceiveBuff[0].Position;
ulong size = context.Request.ReceiveBuff[0].Size; ulong bufferLen = context.Request.ReceiveBuff[0].Size;
byte[] data = new byte[size]; using (var region = context.Memory.GetWritableRegion(bufferAddress, (int)bufferLen, true))
{
Result result = _base.Get.Read(out int entriesRead, MemoryMarshal.Cast<byte, DeliveryCacheDirectoryEntry>(data)); Result result = _base.Get.Read(out int entriesRead, MemoryMarshal.Cast<byte, DeliveryCacheDirectoryEntry>(region.Memory.Span));
context.Memory.Write(position, data);
context.ResponseData.Write(entriesRead); context.ResponseData.Write(entriesRead);
return (ResultCode)result.Value; return (ResultCode)result.Value;
} }
}
[CommandHipc(2)] [CommandHipc(2)]
// GetCount() -> u32 // GetCount() -> u32

View file

@ -38,21 +38,20 @@ namespace Ryujinx.HLE.HOS.Services.Bcat.ServiceCreator
// Read(u64) -> (u64, buffer<bytes, 6>) // Read(u64) -> (u64, buffer<bytes, 6>)
public ResultCode Read(ServiceCtx context) public ResultCode Read(ServiceCtx context)
{ {
ulong position = context.Request.ReceiveBuff[0].Position; ulong bufferAddress = context.Request.ReceiveBuff[0].Position;
ulong size = context.Request.ReceiveBuff[0].Size; ulong bufferLen = context.Request.ReceiveBuff[0].Size;
long offset = context.RequestData.ReadInt64(); long offset = context.RequestData.ReadInt64();
byte[] data = new byte[size]; using (var region = context.Memory.GetWritableRegion(bufferAddress, (int)bufferLen, true))
{
Result result = _base.Get.Read(out long bytesRead, offset, data); Result result = _base.Get.Read(out long bytesRead, offset, region.Memory.Span);
context.Memory.Write(position, data);
context.ResponseData.Write(bytesRead); context.ResponseData.Write(bytesRead);
return (ResultCode)result.Value; return (ResultCode)result.Value;
} }
}
[CommandHipc(2)] [CommandHipc(2)]
// GetSize() -> u64 // GetSize() -> u64

View file

@ -50,19 +50,18 @@ namespace Ryujinx.HLE.HOS.Services.Bcat.ServiceCreator
// EnumerateDeliveryCacheDirectory() -> (u32, buffer<nn::bcat::DirectoryName, 6>) // EnumerateDeliveryCacheDirectory() -> (u32, buffer<nn::bcat::DirectoryName, 6>)
public ResultCode EnumerateDeliveryCacheDirectory(ServiceCtx context) public ResultCode EnumerateDeliveryCacheDirectory(ServiceCtx context)
{ {
ulong position = context.Request.ReceiveBuff[0].Position; ulong bufferAddress = context.Request.ReceiveBuff[0].Position;
ulong size = context.Request.ReceiveBuff[0].Size; ulong bufferLen = context.Request.ReceiveBuff[0].Size;
byte[] data = new byte[size]; using (var region = context.Memory.GetWritableRegion(bufferAddress, (int)bufferLen, true))
{
Result result = _base.Get.EnumerateDeliveryCacheDirectory(out int count, MemoryMarshal.Cast<byte, DirectoryName>(data)); Result result = _base.Get.EnumerateDeliveryCacheDirectory(out int count, MemoryMarshal.Cast<byte, DirectoryName>(region.Memory.Span));
context.Memory.Write(position, data);
context.ResponseData.Write(count); context.ResponseData.Write(count);
return (ResultCode)result.Value; return (ResultCode)result.Value;
} }
}
protected override void Dispose(bool isDisposing) protected override void Dispose(bool isDisposing)
{ {

View file

@ -17,18 +17,18 @@ namespace Ryujinx.HLE.HOS.Services.Fs.FileSystemProxy
// Read() -> (u64 count, buffer<nn::fssrv::sf::IDirectoryEntry, 6, 0> entries) // Read() -> (u64 count, buffer<nn::fssrv::sf::IDirectoryEntry, 6, 0> entries)
public ResultCode Read(ServiceCtx context) public ResultCode Read(ServiceCtx context)
{ {
ulong bufferPosition = context.Request.ReceiveBuff[0].Position; ulong bufferAddress = context.Request.ReceiveBuff[0].Position;
ulong bufferLen = context.Request.ReceiveBuff[0].Size; ulong bufferLen = context.Request.ReceiveBuff[0].Size;
byte[] entryBuffer = new byte[bufferLen]; using (var region = context.Memory.GetWritableRegion(bufferAddress, (int)bufferLen, true))
{
Result result = _baseDirectory.Get.Read(out long entriesRead, new OutBuffer(region.Memory.Span));
Result result = _baseDirectory.Get.Read(out long entriesRead, new OutBuffer(entryBuffer));
context.Memory.Write(bufferPosition, entryBuffer);
context.ResponseData.Write(entriesRead); context.ResponseData.Write(entriesRead);
return (ResultCode)result.Value; return (ResultCode)result.Value;
} }
}
[CommandHipc(1)] [CommandHipc(1)]
// GetEntryCount() -> u64 // GetEntryCount() -> u64

View file

@ -19,7 +19,8 @@ namespace Ryujinx.HLE.HOS.Services.Fs.FileSystemProxy
// Read(u32 readOption, u64 offset, u64 size) -> (u64 out_size, buffer<u8, 0x46, 0> out_buf) // Read(u32 readOption, u64 offset, u64 size) -> (u64 out_size, buffer<u8, 0x46, 0> out_buf)
public ResultCode Read(ServiceCtx context) public ResultCode Read(ServiceCtx context)
{ {
ulong position = context.Request.ReceiveBuff[0].Position; ulong bufferAddress = context.Request.ReceiveBuff[0].Position;
ulong bufferLen = context.Request.ReceiveBuff[0].Size;
ReadOption readOption = context.RequestData.ReadStruct<ReadOption>(); ReadOption readOption = context.RequestData.ReadStruct<ReadOption>();
context.RequestData.BaseStream.Position += 4; context.RequestData.BaseStream.Position += 4;
@ -27,16 +28,15 @@ namespace Ryujinx.HLE.HOS.Services.Fs.FileSystemProxy
long offset = context.RequestData.ReadInt64(); long offset = context.RequestData.ReadInt64();
long size = context.RequestData.ReadInt64(); long size = context.RequestData.ReadInt64();
byte[] data = new byte[context.Request.ReceiveBuff[0].Size]; using (var region = context.Memory.GetWritableRegion(bufferAddress, (int)bufferLen, true))
{
Result result = _baseFile.Get.Read(out long bytesRead, offset, new OutBuffer(data), size, readOption); Result result = _baseFile.Get.Read(out long bytesRead, offset, new OutBuffer(region.Memory.Span), size, readOption);
context.Memory.Write(position, data);
context.ResponseData.Write(bytesRead); context.ResponseData.Write(bytesRead);
return (ResultCode)result.Value; return (ResultCode)result.Value;
} }
}
[CommandHipc(1)] [CommandHipc(1)]
// Write(u32 writeOption, u64 offset, u64 size, buffer<u8, 0x45, 0>) // Write(u32 writeOption, u64 offset, u64 size, buffer<u8, 0x45, 0>)

View file

@ -197,13 +197,7 @@ namespace Ryujinx.HLE.HOS.Services.Fs.FileSystemProxy
context.ResponseData.Write(timestamp.Created); context.ResponseData.Write(timestamp.Created);
context.ResponseData.Write(timestamp.Modified); context.ResponseData.Write(timestamp.Modified);
context.ResponseData.Write(timestamp.Accessed); context.ResponseData.Write(timestamp.Accessed);
context.ResponseData.Write(1L); // Is valid?
byte[] data = new byte[8];
// is valid?
data[0] = 1;
context.ResponseData.Write(data);
return (ResultCode)result.Value; return (ResultCode)result.Value;
} }

View file

@ -23,22 +23,22 @@ namespace Ryujinx.HLE.HOS.Services.Fs.FileSystemProxy
if (context.Request.ReceiveBuff.Count > 0) if (context.Request.ReceiveBuff.Count > 0)
{ {
IpcBuffDesc buffDesc = context.Request.ReceiveBuff[0]; ulong bufferAddress = context.Request.ReceiveBuff[0].Position;
ulong bufferLen = context.Request.ReceiveBuff[0].Size;
// Use smaller length to avoid overflows. // Use smaller length to avoid overflows.
if (size > buffDesc.Size) if (size > bufferLen)
{ {
size = buffDesc.Size; size = bufferLen;
} }
byte[] data = new byte[size]; using (var region = context.Memory.GetWritableRegion(bufferAddress, (int)bufferLen, true))
{
Result result = _baseStorage.Get.Read((long)offset, new OutBuffer(data), (long)size); Result result = _baseStorage.Get.Read((long)offset, new OutBuffer(region.Memory.Span), (long)size);
context.Memory.Write(buffDesc.Position, data);
return (ResultCode)result.Value; return (ResultCode)result.Value;
} }
}
return ResultCode.Success; return ResultCode.Success;
} }

View file

@ -500,16 +500,16 @@ namespace Ryujinx.HLE.HOS.Services.Fs
SaveDataSpaceId spaceId = (SaveDataSpaceId)context.RequestData.ReadInt64(); SaveDataSpaceId spaceId = (SaveDataSpaceId)context.RequestData.ReadInt64();
SaveDataFilter filter = context.RequestData.ReadStruct<SaveDataFilter>(); SaveDataFilter filter = context.RequestData.ReadStruct<SaveDataFilter>();
ulong bufferPosition = context.Request.ReceiveBuff[0].Position; ulong bufferAddress = context.Request.ReceiveBuff[0].Position;
ulong bufferLen = context.Request.ReceiveBuff[0].Size; ulong bufferLen = context.Request.ReceiveBuff[0].Size;
byte[] infoBuffer = new byte[bufferLen]; using (var region = context.Memory.GetWritableRegion(bufferAddress, (int)bufferLen, true))
{
Result result = _baseFileSystemProxy.Get.FindSaveDataWithFilter(out long count, new OutBuffer(infoBuffer), spaceId, in filter); Result result = _baseFileSystemProxy.Get.FindSaveDataWithFilter(out long count, new OutBuffer(region.Memory.Span), spaceId, in filter);
if (result.IsFailure()) return (ResultCode)result.Value; if (result.IsFailure()) return (ResultCode)result.Value;
context.Memory.Write(bufferPosition, infoBuffer);
context.ResponseData.Write(count); context.ResponseData.Write(count);
}
return ResultCode.Success; return ResultCode.Success;
} }

View file

@ -17,18 +17,18 @@ namespace Ryujinx.HLE.HOS.Services.Fs
// ReadSaveDataInfo() -> (u64, buffer<unknown, 6>) // ReadSaveDataInfo() -> (u64, buffer<unknown, 6>)
public ResultCode ReadSaveDataInfo(ServiceCtx context) public ResultCode ReadSaveDataInfo(ServiceCtx context)
{ {
ulong bufferPosition = context.Request.ReceiveBuff[0].Position; ulong bufferAddress = context.Request.ReceiveBuff[0].Position;
ulong bufferLen = context.Request.ReceiveBuff[0].Size; ulong bufferLen = context.Request.ReceiveBuff[0].Size;
byte[] infoBuffer = new byte[bufferLen]; using (var region = context.Memory.GetWritableRegion(bufferAddress, (int)bufferLen, true))
{
Result result = _baseReader.Get.Read(out long readCount, new OutBuffer(region.Memory.Span));
Result result = _baseReader.Get.Read(out long readCount, new OutBuffer(infoBuffer));
context.Memory.Write(bufferPosition, infoBuffer);
context.ResponseData.Write(readCount); context.ResponseData.Write(readCount);
return (ResultCode)result.Value; return (ResultCode)result.Value;
} }
}
protected override void Dispose(bool isDisposing) protected override void Dispose(bool isDisposing)
{ {

View file

@ -142,14 +142,13 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
// GetHostName(buffer<bytes, 6>) -> u32 // GetHostName(buffer<bytes, 6>) -> u32
public ResultCode GetHostName(ServiceCtx context) public ResultCode GetHostName(ServiceCtx context)
{ {
ulong hostNameDataPosition = context.Request.ReceiveBuff[0].Position; ulong bufferAddress = context.Request.ReceiveBuff[0].Position;
ulong hostNameDataSize = context.Request.ReceiveBuff[0].Size; ulong bufferLen = context.Request.ReceiveBuff[0].Size;
byte[] hostNameData = new byte[hostNameDataSize]; using (var region = context.Memory.GetWritableRegion(bufferAddress, (int)bufferLen, true))
{
Encoding.ASCII.GetBytes(_hostName, hostNameData); Encoding.ASCII.GetBytes(_hostName, region.Memory.Span);
}
context.Memory.Write(hostNameDataPosition, hostNameData);
context.ResponseData.Write((uint)_hostName.Length); context.ResponseData.Write((uint)_hostName.Length);