From 45e520a27c2deb11d22a38c58962048d303460eb Mon Sep 17 00:00:00 2001 From: gdk Date: Thu, 23 Jun 2022 04:05:56 -0300 Subject: [PATCH] Rewrite PlaceholderManager4KB to use intrusive RBTree, and to coalesce free placeholders Also make the other placeholder manager use intrusive RBTree, allows the IntervalTree that was added just for this to be deleted --- Ryujinx.Memory/WindowsShared/IntervalTree.cs | 453 ------------------ Ryujinx.Memory/WindowsShared/MappingTree.cs | 69 +++ .../WindowsShared/PlaceholderManager.cs | 105 ++-- 3 files changed, 126 insertions(+), 501 deletions(-) delete mode 100644 Ryujinx.Memory/WindowsShared/IntervalTree.cs create mode 100644 Ryujinx.Memory/WindowsShared/MappingTree.cs diff --git a/Ryujinx.Memory/WindowsShared/IntervalTree.cs b/Ryujinx.Memory/WindowsShared/IntervalTree.cs deleted file mode 100644 index fe12e8b8e5..0000000000 --- a/Ryujinx.Memory/WindowsShared/IntervalTree.cs +++ /dev/null @@ -1,453 +0,0 @@ -using Ryujinx.Common.Collections; -using System; -using System.Collections.Generic; - -namespace Ryujinx.Memory.WindowsShared -{ - /// - /// An Augmented Interval Tree based off of the "TreeDictionary"'s Red-Black Tree. Allows fast overlap checking of ranges. - /// - /// Key - /// Value - class IntervalTree : IntrusiveRedBlackTreeImpl> where K : IComparable - { - private const int ArrayGrowthSize = 32; - - #region Public Methods - - /// - /// Gets the values of the interval whose key is . - /// - /// Key of the node value to get - /// Value with the given - /// True if the key is on the dictionary, false otherwise - public bool TryGet(K key, out V value) - { - IntervalTreeNode node = GetNode(key); - - if (node == null) - { - value = default; - return false; - } - - value = node.Value; - return true; - } - - /// - /// Returns the start addresses of the intervals whose start and end keys overlap the given range. - /// - /// Start of the range - /// End of the range - /// Overlaps array to place results in - /// Index to start writing results into the array. Defaults to 0 - /// Number of intervals found - public int Get(K start, K end, ref IntervalTreeNode[] overlaps, int overlapCount = 0) - { - GetNodes(Root, start, end, ref overlaps, ref overlapCount); - - return overlapCount; - } - - /// - /// Adds a new interval into the tree whose start is , end is and value is . - /// - /// Start of the range to add - /// End of the range to insert - /// Value to add - /// is null - public void Add(K start, K end, V value) - { - if (value == null) - { - throw new ArgumentNullException(nameof(value)); - } - - BSTInsert(start, end, value, null, out _); - } - - /// - /// Removes a value from the tree, searching for it with . - /// - /// Key of the node to remove - /// Number of deleted values - public int Remove(K key) - { - return Remove(GetNode(key)); - } - - /// - /// Removes a value from the tree, searching for it with . - /// - /// Node to be removed - /// Number of deleted values - public int Remove(IntervalTreeNode nodeToDelete) - { - if (nodeToDelete == null) - { - return 0; - } - - Delete(nodeToDelete); - - Count--; - - return 1; - } - - /// - /// Adds all the nodes in the dictionary into . - /// - /// A list of all values sorted by Key Order - public List AsList() - { - List list = new List(); - - AddToList(Root, list); - - return list; - } - - #endregion - - #region Private Methods (BST) - - /// - /// Adds all values that are children of or contained within into , in Key Order. - /// - /// The node to search for values within - /// The list to add values to - private void AddToList(IntervalTreeNode node, List list) - { - if (node == null) - { - return; - } - - AddToList(node.Left, list); - - list.Add(node.Value); - - AddToList(node.Right, list); - } - - /// - /// Retrieve the node reference whose key is , or null if no such node exists. - /// - /// Key of the node to get - /// is null - /// Node reference in the tree - private IntervalTreeNode GetNode(K key) - { - if (key == null) - { - throw new ArgumentNullException(nameof(key)); - } - - IntervalTreeNode node = Root; - while (node != null) - { - int cmp = key.CompareTo(node.Start); - if (cmp < 0) - { - node = node.Left; - } - else if (cmp > 0) - { - node = node.Right; - } - else - { - return node; - } - } - return null; - } - - /// - /// Retrieve all nodes that overlap the given start and end keys. - /// - /// Start of the range - /// End of the range - /// Overlaps array to place results in - /// Overlaps count to update - private void GetNodes(IntervalTreeNode node, K start, K end, ref IntervalTreeNode[] overlaps, ref int overlapCount) - { - if (node == null || start.CompareTo(node.Max) >= 0) - { - return; - } - - GetNodes(node.Left, start, end, ref overlaps, ref overlapCount); - - bool endsOnRight = end.CompareTo(node.Start) > 0; - if (endsOnRight) - { - if (start.CompareTo(node.End) < 0) - { - if (overlaps.Length >= overlapCount) - { - Array.Resize(ref overlaps, overlapCount + ArrayGrowthSize); - } - - overlaps[overlapCount++] = node; - } - - GetNodes(node.Right, start, end, ref overlaps, ref overlapCount); - } - } - - /// - /// Propagate an increase in max value starting at the given node, heading up the tree. - /// This should only be called if the max increases - not for rebalancing or removals. - /// - /// The node to start propagating from - private void PropagateIncrease(IntervalTreeNode node) - { - K max = node.Max; - IntervalTreeNode ptr = node; - - while ((ptr = ptr.Parent) != null) - { - if (max.CompareTo(ptr.Max) > 0) - { - ptr.Max = max; - } - else - { - break; - } - } - } - - /// - /// Propagate recalculating max value starting at the given node, heading up the tree. - /// This fully recalculates the max value from all children when there is potential for it to decrease. - /// - /// The node to start propagating from - private void PropagateFull(IntervalTreeNode node) - { - IntervalTreeNode ptr = node; - - do - { - K max = ptr.End; - - if (ptr.Left != null && ptr.Left.Max.CompareTo(max) > 0) - { - max = ptr.Left.Max; - } - - if (ptr.Right != null && ptr.Right.Max.CompareTo(max) > 0) - { - max = ptr.Right.Max; - } - - ptr.Max = max; - } while ((ptr = ptr.Parent) != null); - } - - /// - /// Insertion Mechanism for the interval tree. Similar to a BST insert, with the start of the range as the key. - /// Iterates the tree starting from the root and inserts a new node where all children in the left subtree are less than , and all children in the right subtree are greater than . - /// Each node can contain multiple values, and has an end address which is the maximum of all those values. - /// Post insertion, the "max" value of the node and all parents are updated. - /// - /// Start of the range to insert - /// End of the range to insert - /// Value to insert - /// Optional factory used to create a new value if is already on the tree - /// Node that was inserted or modified - /// True if was not yet on the tree, false otherwise - private bool BSTInsert(K start, K end, V value, Func updateFactoryCallback, out IntervalTreeNode outNode) - { - IntervalTreeNode parent = null; - IntervalTreeNode node = Root; - - while (node != null) - { - parent = node; - int cmp = start.CompareTo(node.Start); - if (cmp < 0) - { - node = node.Left; - } - else if (cmp > 0) - { - node = node.Right; - } - else - { - outNode = node; - - if (updateFactoryCallback != null) - { - // Replace - node.Value = updateFactoryCallback(start, node.Value); - - int endCmp = end.CompareTo(node.End); - - if (endCmp > 0) - { - node.End = end; - if (end.CompareTo(node.Max) > 0) - { - node.Max = end; - PropagateIncrease(node); - RestoreBalanceAfterInsertion(node); - } - } - else if (endCmp < 0) - { - node.End = end; - PropagateFull(node); - } - } - - return false; - } - } - IntervalTreeNode newNode = new IntervalTreeNode(start, end, value, parent); - if (newNode.Parent == null) - { - Root = newNode; - } - else if (start.CompareTo(parent.Start) < 0) - { - parent.Left = newNode; - } - else - { - parent.Right = newNode; - } - - PropagateIncrease(newNode); - Count++; - RestoreBalanceAfterInsertion(newNode); - outNode = newNode; - return true; - } - - /// - /// Removes the value from the dictionary after searching for it with . - /// - /// Tree node to be removed - private void Delete(IntervalTreeNode nodeToDelete) - { - IntervalTreeNode replacementNode; - - if (LeftOf(nodeToDelete) == null || RightOf(nodeToDelete) == null) - { - replacementNode = nodeToDelete; - } - else - { - replacementNode = nodeToDelete.Predecessor; - } - - IntervalTreeNode tmp = LeftOf(replacementNode) ?? RightOf(replacementNode); - - if (tmp != null) - { - tmp.Parent = ParentOf(replacementNode); - } - - if (ParentOf(replacementNode) == null) - { - Root = tmp; - } - else if (replacementNode == LeftOf(ParentOf(replacementNode))) - { - ParentOf(replacementNode).Left = tmp; - } - else - { - ParentOf(replacementNode).Right = tmp; - } - - if (replacementNode != nodeToDelete) - { - nodeToDelete.Start = replacementNode.Start; - nodeToDelete.Value = replacementNode.Value; - nodeToDelete.End = replacementNode.End; - nodeToDelete.Max = replacementNode.Max; - } - - PropagateFull(replacementNode); - - if (tmp != null && ColorOf(replacementNode) == Black) - { - RestoreBalanceAfterRemoval(tmp); - } - } - - #endregion - - #region Private Methods (RBL) - - protected override void RotateLeft(IntervalTreeNode node) - { - if (node != null) - { - base.RotateLeft(node); - - PropagateFull(node); - } - } - - protected override void RotateRight(IntervalTreeNode node) - { - if (node != null) - { - base.RotateRight(node); - - PropagateFull(node); - } - } - - #endregion - - public bool ContainsKey(K key) - { - return GetNode(key) != null; - } - } - - /// - /// Represents a node in the IntervalTree which contains start and end keys of type K, and a value of generic type V. - /// - /// Key type of the node - /// Value type of the node - class IntervalTreeNode : IntrusiveRedBlackTreeNode> - { - /// - /// The start of the range. - /// - public K Start; - - /// - /// The end of the range. - /// - public K End; - - /// - /// The maximum end value of this node and all its children. - /// - public K Max; - - /// - /// Value stored on this node. - /// - public V Value; - - public IntervalTreeNode(K start, K end, V value, IntervalTreeNode parent) - { - Start = start; - End = end; - Max = end; - Value = value; - Parent = parent; - } - } -} diff --git a/Ryujinx.Memory/WindowsShared/MappingTree.cs b/Ryujinx.Memory/WindowsShared/MappingTree.cs new file mode 100644 index 0000000000..8f880f0c84 --- /dev/null +++ b/Ryujinx.Memory/WindowsShared/MappingTree.cs @@ -0,0 +1,69 @@ +using Ryujinx.Common.Collections; +using System; + +namespace Ryujinx.Memory.WindowsShared +{ + /// + /// A intrusive Red-Black Tree that also supports getting nodes overlapping a given range. + /// + /// Type of the value stored on the node + class MappingTree : IntrusiveRedBlackTree> + { + public int GetNodes(ulong start, ulong end, ref RangeNode[] overlaps, int overlapCount = 0) + { + RangeNode node = GetNode(new RangeNode(start, start + 1UL, default)); + + for (; node != null; node = node.Successor) + { + if (overlaps.Length <= overlapCount) + { + Array.Resize(ref overlaps, overlapCount + 1); + } + + overlaps[overlapCount++] = node; + + if (node.End >= end) + { + break; + } + } + + return overlapCount; + } + } + + class RangeNode : IntrusiveRedBlackTreeNode>, IComparable> + { + public ulong Start { get; } + public ulong End { get; private set; } + public T Value { get; } + + public RangeNode(ulong start, ulong end, T value) + { + Start = start; + End = end; + Value = value; + } + + public void Extend(ulong sizeDelta) + { + End += sizeDelta; + } + + public int CompareTo(RangeNode other) + { + if (Start < other.Start) + { + return -1; + } + else if (Start <= other.End - 1UL) + { + return 0; + } + else + { + return 1; + } + } + } +} \ No newline at end of file diff --git a/Ryujinx.Memory/WindowsShared/PlaceholderManager.cs b/Ryujinx.Memory/WindowsShared/PlaceholderManager.cs index 0937d46230..6db8d7df1c 100644 --- a/Ryujinx.Memory/WindowsShared/PlaceholderManager.cs +++ b/Ryujinx.Memory/WindowsShared/PlaceholderManager.cs @@ -13,10 +13,10 @@ namespace Ryujinx.Memory.WindowsShared [SupportedOSPlatform("windows")] class PlaceholderManager { - private const ulong MinimumPageSize = 0x1000; + private const int InitialOverlapsSize = 10; - private readonly IntervalTree _mappings; - private readonly IntervalTree _protections; + private readonly MappingTree _mappings; + private readonly MappingTree _protections; private readonly IntPtr _partialUnmapStatePtr; private readonly Thread _partialUnmapTrimThread; @@ -25,8 +25,8 @@ namespace Ryujinx.Memory.WindowsShared /// public PlaceholderManager() { - _mappings = new IntervalTree(); - _protections = new IntervalTree(); + _mappings = new MappingTree(); + _protections = new MappingTree(); _partialUnmapStatePtr = PartialUnmapState.GlobalState; @@ -67,7 +67,7 @@ namespace Ryujinx.Memory.WindowsShared { lock (_mappings) { - _mappings.Add(address, address + size, ulong.MaxValue); + _mappings.Add(new RangeNode(address, address + size, ulong.MaxValue)); } } @@ -81,12 +81,12 @@ namespace Ryujinx.Memory.WindowsShared { ulong endAddress = address + size; - var overlaps = Array.Empty>(); + var overlaps = new RangeNode[InitialOverlapsSize]; int count; lock (_mappings) { - count = _mappings.Get(address, endAddress, ref overlaps); + count = _mappings.GetNodes(address, endAddress, ref overlaps); for (int index = 0; index < count; index++) { @@ -178,11 +178,11 @@ namespace Ryujinx.Memory.WindowsShared { ulong endAddress = address + size; - var overlaps = Array.Empty>(); + var overlaps = new RangeNode[InitialOverlapsSize]; lock (_mappings) { - int count = _mappings.Get(address, endAddress, ref overlaps); + int count = _mappings.GetNodes(address, endAddress, ref overlaps); Debug.Assert(count == 1); Debug.Assert(!IsMapped(overlaps[0].Value)); @@ -206,8 +206,8 @@ namespace Ryujinx.Memory.WindowsShared (IntPtr)size, AllocationType.Release | AllocationType.PreservePlaceholder)); - _mappings.Add(overlapStart, address, overlapValue); - _mappings.Add(endAddress, overlapEnd, AddBackingOffset(overlapValue, endAddress - overlapStart)); + _mappings.Add(new RangeNode(overlapStart, address, overlapValue)); + _mappings.Add(new RangeNode(endAddress, overlapEnd, AddBackingOffset(overlapValue, endAddress - overlapStart))); } else if (overlapStartsBefore) { @@ -218,7 +218,7 @@ namespace Ryujinx.Memory.WindowsShared (IntPtr)overlappedSize, AllocationType.Release | AllocationType.PreservePlaceholder)); - _mappings.Add(overlapStart, address, overlapValue); + _mappings.Add(new RangeNode(overlapStart, address, overlapValue)); } else if (overlapEndsAfter) { @@ -229,10 +229,10 @@ namespace Ryujinx.Memory.WindowsShared (IntPtr)overlappedSize, AllocationType.Release | AllocationType.PreservePlaceholder)); - _mappings.Add(endAddress, overlapEnd, AddBackingOffset(overlapValue, overlappedSize)); + _mappings.Add(new RangeNode(endAddress, overlapEnd, AddBackingOffset(overlapValue, overlappedSize))); } - _mappings.Add(address, endAddress, backingOffset); + _mappings.Add(new RangeNode(address, endAddress, backingOffset)); } } @@ -280,12 +280,12 @@ namespace Ryujinx.Memory.WindowsShared ulong unmapSize = (ulong)size; ulong endAddress = startAddress + unmapSize; - var overlaps = Array.Empty>(); + var overlaps = new RangeNode[InitialOverlapsSize]; int count; lock (_mappings) { - count = _mappings.Get(startAddress, endAddress, ref overlaps); + count = _mappings.GetNodes(startAddress, endAddress, ref overlaps); } for (int index = 0; index < count; index++) @@ -302,7 +302,7 @@ namespace Ryujinx.Memory.WindowsShared lock (_mappings) { _mappings.Remove(overlap); - _mappings.Add(overlapStart, overlapEnd, ulong.MaxValue); + _mappings.Add(new RangeNode(overlapStart, overlapEnd, ulong.MaxValue)); } bool overlapStartsBefore = overlapStart < startAddress; @@ -374,44 +374,53 @@ namespace Ryujinx.Memory.WindowsShared ulong endAddress = address + size; ulong blockAddress = (ulong)owner.Pointer; ulong blockEnd = blockAddress + owner.Size; - var overlaps = Array.Empty>(); + var overlaps = new RangeNode[InitialOverlapsSize]; int unmappedCount = 0; lock (_mappings) { - int count = _mappings.Get( - Math.Max(address - MinimumPageSize, blockAddress), - Math.Min(endAddress + MinimumPageSize, blockEnd), ref overlaps); + int count = _mappings.GetNodes(address, endAddress, ref overlaps); - if (count < 2) + if (count == 0) { - // Nothing to coalesce if we only have 1 or no overlaps. + // Nothing to coalesce if we no overlaps. return; } + RangeNode predecessor = overlaps[0].Predecessor; + RangeNode successor = overlaps[count - 1].Successor; + for (int index = 0; index < count; index++) { var overlap = overlaps[index]; if (!IsMapped(overlap.Value)) { - if (address > overlap.Start) - { - address = overlap.Start; - } - - if (endAddress < overlap.End) - { - endAddress = overlap.End; - } + address = Math.Min(address, overlap.Start); + endAddress = Math.Max(endAddress, overlap.End); _mappings.Remove(overlap); - unmappedCount++; } } - _mappings.Add(address, endAddress, ulong.MaxValue); + if (predecessor != null && !IsMapped(predecessor.Value) && predecessor.Start >= blockAddress) + { + address = Math.Min(address, predecessor.Start); + + _mappings.Remove(predecessor); + unmappedCount++; + } + + if (successor != null && !IsMapped(successor.Value) && successor.End <= blockEnd) + { + endAddress = Math.Max(endAddress, successor.End); + + _mappings.Remove(successor); + unmappedCount++; + } + + _mappings.Add(new RangeNode(address, endAddress, ulong.MaxValue)); } if (unmappedCount > 1) @@ -462,12 +471,12 @@ namespace Ryujinx.Memory.WindowsShared ulong reprotectSize = (ulong)size; ulong endAddress = reprotectAddress + reprotectSize; - var overlaps = Array.Empty>(); + var overlaps = new RangeNode[InitialOverlapsSize]; int count; lock (_mappings) { - count = _mappings.Get(reprotectAddress, endAddress, ref overlaps); + count = _mappings.GetNodes(reprotectAddress, endAddress, ref overlaps); } bool success = true; @@ -567,12 +576,12 @@ namespace Ryujinx.Memory.WindowsShared private void AddProtection(ulong address, ulong size, MemoryPermission permission) { ulong endAddress = address + size; - var overlaps = Array.Empty>(); + var overlaps = new RangeNode[InitialOverlapsSize]; int count; lock (_protections) { - count = _protections.Get(address, endAddress, ref overlaps); + count = _protections.GetNodes(address, endAddress, ref overlaps); if (count == 1 && overlaps[0].Start <= address && @@ -610,17 +619,17 @@ namespace Ryujinx.Memory.WindowsShared { if (startAddress > protAddress) { - _protections.Add(protAddress, startAddress, protPermission); + _protections.Add(new RangeNode(protAddress, startAddress, protPermission)); } if (endAddress < protEndAddress) { - _protections.Add(endAddress, protEndAddress, protPermission); + _protections.Add(new RangeNode(endAddress, protEndAddress, protPermission)); } } } - _protections.Add(startAddress, endAddress, permission); + _protections.Add(new RangeNode(startAddress, endAddress, permission)); } } @@ -632,12 +641,12 @@ namespace Ryujinx.Memory.WindowsShared private void RemoveProtection(ulong address, ulong size) { ulong endAddress = address + size; - var overlaps = Array.Empty>(); + var overlaps = new RangeNode[InitialOverlapsSize]; int count; lock (_protections) { - count = _protections.Get(address, endAddress, ref overlaps); + count = _protections.GetNodes(address, endAddress, ref overlaps); for (int index = 0; index < count; index++) { @@ -651,12 +660,12 @@ namespace Ryujinx.Memory.WindowsShared if (address > protAddress) { - _protections.Add(protAddress, address, protPermission); + _protections.Add(new RangeNode(protAddress, address, protPermission)); } if (endAddress < protEndAddress) { - _protections.Add(endAddress, protEndAddress, protPermission); + _protections.Add(new RangeNode(endAddress, protEndAddress, protPermission)); } } } @@ -670,12 +679,12 @@ namespace Ryujinx.Memory.WindowsShared private void RestoreRangeProtection(ulong address, ulong size) { ulong endAddress = address + size; - var overlaps = Array.Empty>(); + var overlaps = new RangeNode[InitialOverlapsSize]; int count; lock (_protections) { - count = _protections.Get(address, endAddress, ref overlaps); + count = _protections.GetNodes(address, endAddress, ref overlaps); } ulong startAddress = address;