From 715b605e9541cd5a7e4cce7609d96dbc41cd0326 Mon Sep 17 00:00:00 2001
From: gdkchan <gab.dark.100@gmail.com>
Date: Tue, 16 Feb 2021 15:04:19 -0300
Subject: [PATCH] Validate CPU virtual addresses on access (#1987)

* Enable PTE null checks again

* Do address validation on EmitPtPointerLoad, and make it branchless

* PTC version increment

* Mask of pointer tag for exclusive access

* Move mask to the correct place

Co-authored-by: LDj3SNuD <35856442+LDj3SNuD@users.noreply.github.com>
---
 .../Instructions/InstEmitMemoryExHelper.cs    |  31 +----
 .../Instructions/InstEmitMemoryHelper.cs      | 128 ++++++------------
 ARMeilleure/Translation/PTC/Ptc.cs            |   2 +-
 Ryujinx.Cpu/MemoryManager.cs                  |  13 +-
 4 files changed, 52 insertions(+), 122 deletions(-)

diff --git a/ARMeilleure/Instructions/InstEmitMemoryExHelper.cs b/ARMeilleure/Instructions/InstEmitMemoryExHelper.cs
index 317e4276ee..15f5e2abc9 100644
--- a/ARMeilleure/Instructions/InstEmitMemoryExHelper.cs
+++ b/ARMeilleure/Instructions/InstEmitMemoryExHelper.cs
@@ -19,19 +19,8 @@ namespace ARMeilleure.Instructions
 
                 if (size == 4)
                 {
-                    Operand isUnalignedAddr = InstEmitMemoryHelper.EmitAddressCheck(context, address, size);
-
-                    Operand lblFastPath = Label();
-
-                    context.BranchIfFalse(lblFastPath, isUnalignedAddr);
-
-                    // The call is not expected to return (it should throw).
-                    context.Call(typeof(NativeInterface).GetMethod(nameof(NativeInterface.ThrowInvalidMemoryAccess)), address);
-
-                    context.MarkLabel(lblFastPath);
-
                     // Only 128-bit CAS is guaranteed to have a atomic load.
-                    Operand physAddr = InstEmitMemoryHelper.EmitPtPointerLoad(context, address, null, write: false);
+                    Operand physAddr = InstEmitMemoryHelper.EmitPtPointerLoad(context, address, null, write: false, 4);
 
                     Operand zero = context.VectorZero();
 
@@ -119,20 +108,8 @@ namespace ARMeilleure.Instructions
 
                 context.BranchIfTrue(lblExit, exFailed);
 
-                // STEP 2: We have exclusive access, make sure that the address is valid.
-                Operand isUnalignedAddr = InstEmitMemoryHelper.EmitAddressCheck(context, address, size);
-
-                Operand lblFastPath = Label();
-
-                context.BranchIfFalse(lblFastPath, isUnalignedAddr);
-
-                // The call is not expected to return (it should throw).
-                context.Call(typeof(NativeInterface).GetMethod(nameof(NativeInterface.ThrowInvalidMemoryAccess)), address);
-
-                // STEP 3: We have exclusive access and the address is valid, attempt the store using CAS.
-                context.MarkLabel(lblFastPath);
-
-                Operand physAddr = InstEmitMemoryHelper.EmitPtPointerLoad(context, address, null, write: true);
+                // STEP 2: We have exclusive access and the address is valid, attempt the store using CAS.
+                Operand physAddr = InstEmitMemoryHelper.EmitPtPointerLoad(context, address, null, write: true, size);
 
                 Operand exValuePtr = context.Add(arg0, Const((long)NativeContext.GetExclusiveValueOffset()));
                 Operand exValue = size switch
@@ -151,7 +128,7 @@ namespace ARMeilleure.Instructions
                     _ => context.CompareAndSwap(physAddr, exValue, value)
                 };
 
-                // STEP 4: Check if we succeeded by comparing expected and in-memory values.
+                // STEP 3: Check if we succeeded by comparing expected and in-memory values.
                 Operand storeFailed;
 
                 if (size == 4)
diff --git a/ARMeilleure/Instructions/InstEmitMemoryHelper.cs b/ARMeilleure/Instructions/InstEmitMemoryHelper.cs
index fd5c5bca30..cb4fae8f9b 100644
--- a/ARMeilleure/Instructions/InstEmitMemoryHelper.cs
+++ b/ARMeilleure/Instructions/InstEmitMemoryHelper.cs
@@ -127,11 +127,7 @@ namespace ARMeilleure.Instructions
             Operand lblSlowPath = Label();
             Operand lblEnd      = Label();
 
-            Operand isUnalignedAddr = EmitAddressCheck(context, address, size);
-
-            context.BranchIfTrue(lblSlowPath, isUnalignedAddr);
-
-            Operand physAddr = EmitPtPointerLoad(context, address, lblSlowPath, write: false);
+            Operand physAddr = EmitPtPointerLoad(context, address, lblSlowPath, write: false, size);
 
             Operand value = null;
 
@@ -161,18 +157,7 @@ namespace ARMeilleure.Instructions
                 throw new ArgumentOutOfRangeException(nameof(size));
             }
 
-            Operand isUnalignedAddr = EmitAddressCheck(context, address, size);
-
-            Operand lblFastPath = Label();
-
-            context.BranchIfFalse(lblFastPath, isUnalignedAddr, BasicBlockFrequency.Cold);
-
-            // The call is not expected to return (it should throw).
-            context.Call(typeof(NativeInterface).GetMethod(nameof(NativeInterface.ThrowInvalidMemoryAccess)), address);
-
-            context.MarkLabel(lblFastPath);
-
-            Operand physAddr = EmitPtPointerLoad(context, address, null, write: false);
+            Operand physAddr = EmitPtPointerLoad(context, address, null, write: false, size);
 
             return size switch
             {
@@ -195,11 +180,7 @@ namespace ARMeilleure.Instructions
             Operand lblSlowPath = Label();
             Operand lblEnd      = Label();
 
-            Operand isUnalignedAddr = EmitAddressCheck(context, address, size);
-
-            context.BranchIfTrue(lblSlowPath, isUnalignedAddr);
-
-            Operand physAddr = EmitPtPointerLoad(context, address, lblSlowPath, write: false);
+            Operand physAddr = EmitPtPointerLoad(context, address, lblSlowPath, write: false, size);
 
             Operand value = null;
 
@@ -233,11 +214,7 @@ namespace ARMeilleure.Instructions
             Operand lblSlowPath = Label();
             Operand lblEnd      = Label();
 
-            Operand isUnalignedAddr = EmitAddressCheck(context, address, size);
-
-            context.BranchIfTrue(lblSlowPath, isUnalignedAddr);
-
-            Operand physAddr = EmitPtPointerLoad(context, address, lblSlowPath, write: true);
+            Operand physAddr = EmitPtPointerLoad(context, address, lblSlowPath, write: true, size);
 
             Operand value = GetInt(context, rt);
 
@@ -270,18 +247,7 @@ namespace ARMeilleure.Instructions
                 throw new ArgumentOutOfRangeException(nameof(size));
             }
 
-            Operand isUnalignedAddr = EmitAddressCheck(context, address, size);
-
-            Operand lblFastPath = Label();
-
-            context.BranchIfFalse(lblFastPath, isUnalignedAddr, BasicBlockFrequency.Cold);
-
-            // The call is not expected to return (it should throw).
-            context.Call(typeof(NativeInterface).GetMethod(nameof(NativeInterface.ThrowInvalidMemoryAccess)), address);
-
-            context.MarkLabel(lblFastPath);
-
-            Operand physAddr = EmitPtPointerLoad(context, address, null, write: true);
+            Operand physAddr = EmitPtPointerLoad(context, address, null, write: true, size);
 
             if (size < 3 && value.Type == OperandType.I64)
             {
@@ -312,11 +278,7 @@ namespace ARMeilleure.Instructions
             Operand lblSlowPath = Label();
             Operand lblEnd      = Label();
 
-            Operand isUnalignedAddr = EmitAddressCheck(context, address, size);
-
-            context.BranchIfTrue(lblSlowPath, isUnalignedAddr);
-
-            Operand physAddr = EmitPtPointerLoad(context, address, lblSlowPath, write: true);
+            Operand physAddr = EmitPtPointerLoad(context, address, lblSlowPath, write: true, size);
 
             Operand value = GetVec(rt);
 
@@ -338,61 +300,49 @@ namespace ARMeilleure.Instructions
             context.MarkLabel(lblEnd);
         }
 
-        public static Operand EmitAddressCheck(ArmEmitterContext context, Operand address, int size)
+        public static Operand EmitPtPointerLoad(ArmEmitterContext context, Operand address, Operand lblSlowPath, bool write, int size)
         {
-            ulong addressCheckMask = ~((1UL << context.Memory.AddressSpaceBits) - 1);
-
-            addressCheckMask |= (1u << size) - 1;
-
-            return context.BitwiseAnd(address, Const(address.Type, (long)addressCheckMask));
-        }
-
-        public static Operand EmitPtPointerLoad(ArmEmitterContext context, Operand address, Operand lblSlowPath, bool write)
-        {
-            int ptLevelBits = context.Memory.AddressSpaceBits - 12; // 12 = Number of page bits.
+            int ptLevelBits = context.Memory.AddressSpaceBits - PageBits;
             int ptLevelSize = 1 << ptLevelBits;
             int ptLevelMask = ptLevelSize - 1;
 
+            Operand addrRotated = size != 0 ? context.RotateRight(address, Const(size)) : address;
+            Operand addrShifted = context.ShiftRightUI(addrRotated, Const(PageBits - size));
+
             Operand pte = Ptc.State == PtcState.Disabled
                 ? Const(context.Memory.PageTablePointer.ToInt64())
                 : Const(context.Memory.PageTablePointer.ToInt64(), true, Ptc.PageTablePointerIndex);
 
-            int bit = PageBits;
+            Operand pteOffset = context.BitwiseAnd(addrShifted, Const(addrShifted.Type, ptLevelMask));
 
-            // Load page table entry from the page table.
-            // This was designed to support multi-level page tables of any size, however right
-            // now we only use flat page tables (so there's only one level).
-            // The page table entry contains the host address where the page is located.
-            // Additionally, the higher 16-bits of the host address may contain extra information
-            // used for write tracking, so this must be handled here aswell.
-            do
+            if (pteOffset.Type == OperandType.I32)
             {
-                Operand addrPart = context.ShiftRightUI(address, Const(bit));
-
-                bit += ptLevelBits;
-
-                if (bit < context.Memory.AddressSpaceBits)
-                {
-                    addrPart = context.BitwiseAnd(addrPart, Const(addrPart.Type, ptLevelMask));
-                }
-
-                Operand pteOffset = context.ShiftLeft(addrPart, Const(3));
-
-                if (pteOffset.Type == OperandType.I32)
-                {
-                    pteOffset = context.ZeroExtend32(OperandType.I64, pteOffset);
-                }
-
-                Operand pteAddress = context.Add(pte, pteOffset);
-
-                pte = context.Load(OperandType.I64, pteAddress);
+                pteOffset = context.ZeroExtend32(OperandType.I64, pteOffset);
             }
-            while (bit < context.Memory.AddressSpaceBits);
+
+            pte = context.Load(OperandType.I64, context.Add(pte, context.ShiftLeft(pteOffset, Const(3))));
+
+            if (addrShifted.Type == OperandType.I32)
+            {
+                addrShifted = context.ZeroExtend32(OperandType.I64, addrShifted);
+            }
+
+            // If the VA is out of range, or not aligned to the access size, force PTE to 0 by masking it.
+            pte = context.BitwiseAnd(pte, context.ShiftRightSI(context.Add(addrShifted, Const(-(long)ptLevelSize)), Const(63)));
 
             if (lblSlowPath != null)
             {
-                ulong protection = (write ? 3UL : 1UL) << 48;
-                context.BranchIfTrue(lblSlowPath, context.BitwiseAnd(pte, Const(protection)));
+                if (write)
+                {
+                    pte = context.ShiftLeft(pte, Const(1));
+                    context.BranchIf(lblSlowPath, pte, Const(0L), Comparison.LessOrEqual);
+                    pte = context.ShiftRightUI(pte, Const(1));
+                }
+                else
+                {
+                    context.BranchIf(lblSlowPath, pte, Const(0L), Comparison.LessOrEqual);
+                    pte = context.BitwiseAnd(pte, Const(0xffffffffffffUL)); // Ignore any software protection bits. (they are still used by C# memory access)
+                }
             }
             else
             {
@@ -401,13 +351,15 @@ namespace ARMeilleure.Instructions
 
                 Operand lblNotWatched = Label();
 
-                // Is the page currently being tracked for read/write? If so we need to call MarkRegionAsModified.
+                // Is the page currently being tracked for read/write? If so we need to call SignalMemoryTracking.
                 context.BranchIf(lblNotWatched, pte, Const(0L), Comparison.GreaterOrEqual, BasicBlockFrequency.Cold);
 
-                // Mark the region as modified. Size here doesn't matter as address is assumed to be size aligned here.
+                // Signal memory tracking. Size here doesn't matter as address is assumed to be size aligned here.
                 context.Call(typeof(NativeInterface).GetMethod(nameof(NativeInterface.SignalMemoryTracking)), address, Const(1UL), Const(write ? 1 : 0));
                 context.MarkLabel(lblNotWatched);
 
+                pte = context.BitwiseAnd(pte, Const(0xffffffffffffUL)); // Ignore any software protection bits. (they are still used by C# memory access)
+
                 Operand lblNonNull = Label();
 
                 // Skip exception if the PTE address is non-null (not zero).
@@ -418,8 +370,6 @@ namespace ARMeilleure.Instructions
                 context.MarkLabel(lblNonNull);
             }
 
-            pte = context.BitwiseAnd(pte, Const(0xffffffffffffUL)); // Ignore any software protection bits. (they are still used by c# memory access)
-
             Operand pageOffset = context.BitwiseAnd(address, Const(address.Type, PageMask));
 
             if (pageOffset.Type == OperandType.I32)
diff --git a/ARMeilleure/Translation/PTC/Ptc.cs b/ARMeilleure/Translation/PTC/Ptc.cs
index 846c01cde8..f3f209f0c3 100644
--- a/ARMeilleure/Translation/PTC/Ptc.cs
+++ b/ARMeilleure/Translation/PTC/Ptc.cs
@@ -22,7 +22,7 @@ namespace ARMeilleure.Translation.PTC
     {
         private const string HeaderMagic = "PTChd";
 
-        private const int InternalVersion = 1971; //! To be incremented manually for each change to the ARMeilleure project.
+        private const int InternalVersion = 1987; //! To be incremented manually for each change to the ARMeilleure project.
 
         private const string ActualDir = "0";
         private const string BackupDir = "1";
diff --git a/Ryujinx.Cpu/MemoryManager.cs b/Ryujinx.Cpu/MemoryManager.cs
index cef2012656..8c8bd3a4c2 100644
--- a/Ryujinx.Cpu/MemoryManager.cs
+++ b/Ryujinx.Cpu/MemoryManager.cs
@@ -21,6 +21,8 @@ namespace Ryujinx.Cpu
 
         private const int PteSize = 8;
 
+        private const int PointerTagBit = 62;
+
         private readonly InvalidAccessHandler _invalidAccessHandler;
 
         /// <summary>
@@ -556,11 +558,12 @@ namespace Ryujinx.Cpu
             // Protection is inverted on software pages, since the default value is 0.
             protection = (~protection) & MemoryPermission.ReadAndWrite;
 
-            long tag = (long)protection << 48;
-            if (tag > 0)
+            long tag = protection switch
             {
-                tag |= long.MinValue; // If any protection is present, the whole pte is negative.
-            }
+                MemoryPermission.None => 0L,
+                MemoryPermission.Read => 2L << PointerTagBit,
+                _ => 3L << PointerTagBit
+            };
 
             ulong endVa = (va + size + PageMask) & ~(ulong)PageMask;
             long invTagMask = ~(0xffffL << 48);
@@ -628,7 +631,7 @@ namespace Ryujinx.Cpu
             // tracking using host guard pages in future, but also supporting platforms where this is not possible.
 
             // Write tag includes read protection, since we don't have any read actions that aren't performed before write too.
-            long tag = (write ? 3L : 1L) << 48;
+            long tag = (write ? 3L : 2L) << PointerTagBit;
 
             ulong endVa = (va + size + PageMask) & ~(ulong)PageMask;