From 8b3eba7e1333ca69e55e3ca85a77d3dd4205e991 Mon Sep 17 00:00:00 2001
From: FICTURE7 <FICTURE7@gmail.com>
Date: Fri, 2 Apr 2021 21:26:16 +0400
Subject: [PATCH] Reduce allocation during SSA construction (#2162)

* Reduce allocation during SSA construction

* Re-trigger CI
---
 .../IntermediateRepresentation/Operand.cs     |   2 +
 ARMeilleure/Translation/SsaConstruction.cs    | 236 ++++++++----------
 2 files changed, 111 insertions(+), 127 deletions(-)

diff --git a/ARMeilleure/IntermediateRepresentation/Operand.cs b/ARMeilleure/IntermediateRepresentation/Operand.cs
index b8650d5a9d..7b486c55d5 100644
--- a/ARMeilleure/IntermediateRepresentation/Operand.cs
+++ b/ARMeilleure/IntermediateRepresentation/Operand.cs
@@ -1,5 +1,6 @@
 using System;
 using System.Collections.Generic;
+using System.Runtime.CompilerServices;
 
 namespace ARMeilleure.IntermediateRepresentation
 {
@@ -84,6 +85,7 @@ namespace ARMeilleure.IntermediateRepresentation
             return With(OperandKind.Register, type, (ulong)((int)regType << 24 | index));
         }
 
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
         public Register GetRegister()
         {
             return new Register((int)Value & 0xffffff, (RegisterType)(Value >> 24));
diff --git a/ARMeilleure/Translation/SsaConstruction.cs b/ARMeilleure/Translation/SsaConstruction.cs
index 46435f4446..1c6e83c9ce 100644
--- a/ARMeilleure/Translation/SsaConstruction.cs
+++ b/ARMeilleure/Translation/SsaConstruction.cs
@@ -1,8 +1,9 @@
 using ARMeilleure.Common;
 using ARMeilleure.IntermediateRepresentation;
 using ARMeilleure.State;
+using System;
 using System.Collections.Generic;
-
+using System.Diagnostics;
 using static ARMeilleure.IntermediateRepresentation.OperandHelper;
 
 namespace ARMeilleure.Translation
@@ -11,104 +12,92 @@ namespace ARMeilleure.Translation
     {
         private class DefMap
         {
-            private Dictionary<Register, Operand> _map;
-
-            private BitMap _phiMasks;
+            private readonly Dictionary<int, Operand> _map;
+            private readonly BitMap _phiMasks;
 
             public DefMap()
             {
-                _map = new Dictionary<Register, Operand>();
-
+                _map = new Dictionary<int, Operand>();
                 _phiMasks = new BitMap(RegisterConsts.TotalCount);
             }
 
-            public bool TryAddOperand(Register reg, Operand operand)
+            public bool TryAddOperand(int key, Operand operand)
             {
-                return _map.TryAdd(reg, operand);
+                return _map.TryAdd(key, operand);
             }
 
-            public bool TryGetOperand(Register reg, out Operand operand)
+            public bool TryGetOperand(int key, out Operand operand)
             {
-                return _map.TryGetValue(reg, out operand);
+                return _map.TryGetValue(key, out operand);
             }
 
-            public bool AddPhi(Register reg)
+            public bool AddPhi(int key)
             {
-                return _phiMasks.Set(GetIdFromRegister(reg));
+                return _phiMasks.Set(key);
             }
 
-            public bool HasPhi(Register reg)
+            public bool HasPhi(int key)
             {
-                return _phiMasks.IsSet(GetIdFromRegister(reg));
+                return _phiMasks.IsSet(key);
             }
         }
 
         public static void Construct(ControlFlowGraph cfg)
         {
-            DefMap[] globalDefs = new DefMap[cfg.Blocks.Count];
+            var globalDefs = new DefMap[cfg.Blocks.Count];
+            var localDefs = new Operand[RegisterConsts.TotalCount];
+
+            var dfPhiBlocks = new Queue<BasicBlock>();
 
             for (BasicBlock block = cfg.Blocks.First; block != null; block = block.ListNext)
             {
                 globalDefs[block.Index] = new DefMap();
             }
 
-            Queue<BasicBlock> dfPhiBlocks = new Queue<BasicBlock>();
-
             // First pass, get all defs and locals uses.
             for (BasicBlock block = cfg.Blocks.First; block != null; block = block.ListNext)
             {
-                Operand[] localDefs = new Operand[RegisterConsts.TotalCount];
-
-                Node node = block.Operations.First;
-
-                Operand RenameLocal(Operand operand)
+                for (Node node = block.Operations.First; node != null; node = node.ListNext)
                 {
-                    if (operand != null && operand.Kind == OperandKind.Register)
-                    {
-                        Operand local = localDefs[GetIdFromRegister(operand.GetRegister())];
-
-                        operand = local ?? operand;
-                    }
-
-                    return operand;
-                }
-
-                while (node != null)
-                {
-                    if (node is Operation operation)
-                    {
-                        for (int index = 0; index < operation.SourcesCount; index++)
-                        {
-                            operation.SetSource(index, RenameLocal(operation.GetSource(index)));
-                        }
-
-                        Operand dest = operation.Destination;
-
-                        if (dest != null && dest.Kind == OperandKind.Register)
-                        {
-                            Operand local = Local(dest.Type);
-
-                            localDefs[GetIdFromRegister(dest.GetRegister())] = local;
-
-                            operation.Destination = local;
-                        }
-                    }
-
-                    node = node.ListNext;
-                }
-
-                for (int index = 0; index < RegisterConsts.TotalCount; index++)
-                {
-                    Operand local = localDefs[index];
-
-                    if (local == null)
+                    if (node is not Operation operation)
                     {
                         continue;
                     }
 
-                    Register reg = GetRegisterFromId(index);
+                    for (int index = 0; index < operation.SourcesCount; index++)
+                    {
+                        Operand src = operation.GetSource(index);
 
-                    globalDefs[block.Index].TryAddOperand(reg, local);
+                        if (TryGetId(src, out int srcKey))
+                        {
+                            Operand local = localDefs[srcKey] ?? src;
+
+                            operation.SetSource(index, local);
+                        }
+                    }
+
+                    Operand dest = operation.Destination;
+
+                    if (TryGetId(dest, out int destKey))
+                    {
+                        Operand local = Local(dest.Type);
+
+                        localDefs[destKey] = local;
+
+                        operation.Destination = local;
+                    }
+                }
+
+                for (int key = 0; key < localDefs.Length; key++)
+                {
+                    Operand local = localDefs[key];
+
+                    if (local is null)
+                    {
+                        continue;
+                    }
+
+                    globalDefs[block.Index].TryAddOperand(key, local);
 
                     dfPhiBlocks.Enqueue(block);
 
@@ -116,61 +105,53 @@ namespace ARMeilleure.Translation
                     {
                         foreach (BasicBlock domFrontier in dfPhiBlock.DominanceFrontiers)
                         {
-                            if (globalDefs[domFrontier.Index].AddPhi(reg))
+                            if (globalDefs[domFrontier.Index].AddPhi(key))
                             {
                                 dfPhiBlocks.Enqueue(domFrontier);
                             }
                         }
                     }
                 }
+
+                Array.Clear(localDefs, 0, localDefs.Length);
             }
 
             // Second pass, rename variables with definitions on different blocks.
             for (BasicBlock block = cfg.Blocks.First; block != null; block = block.ListNext)
             {
-                Operand[] localDefs = new Operand[RegisterConsts.TotalCount];
-
-                Node node = block.Operations.First;
-
-                Operand RenameGlobal(Operand operand)
+                for (Node node = block.Operations.First; node != null; node = node.ListNext)
                 {
-                    if (operand != null && operand.Kind == OperandKind.Register)
+                    if (node is not Operation operation)
                     {
-                        int key = GetIdFromRegister(operand.GetRegister());
-
-                        Operand local = localDefs[key];
-
-                        if (local == null)
-                        {
-                            local = FindDef(globalDefs, block, operand);
-
-                            localDefs[key] = local;
-                        }
-
-                        operand = local;
+                        continue;
                     }
 
-                    return operand;
-                }
-
-                while (node != null)
-                {
-                    if (node is Operation operation)
+                    for (int index = 0; index < operation.SourcesCount; index++)
                     {
-                        for (int index = 0; index < operation.SourcesCount; index++)
+                        Operand src = operation.GetSource(index);
+
+                        if (TryGetId(src, out int key))
                         {
-                            operation.SetSource(index, RenameGlobal(operation.GetSource(index)));
+                            Operand local = localDefs[key];
+
+                            if (local is null)
+                            {
+                                local = FindDef(globalDefs, block, src);
+                                localDefs[key] = local;
+                            }
+
+                            operation.SetSource(index, local);
                         }
                     }
-
-                    node = node.ListNext;
                 }
+
+                Array.Clear(localDefs, 0, localDefs.Length);
             }
         }
 
         private static Operand FindDef(DefMap[] globalDefs, BasicBlock current, Operand operand)
         {
-            if (globalDefs[current.Index].HasPhi(operand.GetRegister()))
+            if (globalDefs[current.Index].HasPhi(GetId(operand)))
             {
                 return InsertPhi(globalDefs, current, operand);
             }
@@ -191,14 +172,14 @@ namespace ARMeilleure.Translation
             {
                 DefMap defMap = globalDefs[current.Index];
 
-                Register reg = operand.GetRegister();
+                int key = GetId(operand);
 
-                if (defMap.TryGetOperand(reg, out Operand lastDef))
+                if (defMap.TryGetOperand(key, out Operand lastDef))
                 {
                     return lastDef;
                 }
 
-                if (defMap.HasPhi(reg))
+                if (defMap.HasPhi(key))
                 {
                     return InsertPhi(globalDefs, current, operand);
                 }
@@ -223,7 +204,7 @@ namespace ARMeilleure.Translation
 
             AddPhi(block, phi);
 
-            globalDefs[block.Index].TryAddOperand(operand.GetRegister(), local);
+            globalDefs[block.Index].TryAddOperand(GetId(operand), local);
 
             for (int index = 0; index < block.Predecessors.Count; index++)
             {
@@ -258,44 +239,45 @@ namespace ARMeilleure.Translation
             }
         }
 
-        private static int GetIdFromRegister(Register reg)
+        private static bool TryGetId(Operand operand, out int result)
         {
-            if (reg.Type == RegisterType.Integer)
+            if (operand is { Kind: OperandKind.Register })
             {
-                return reg.Index;
-            }
-            else if (reg.Type == RegisterType.Vector)
-            {
-                return RegisterConsts.IntRegsCount + reg.Index;
-            }
-            else if (reg.Type == RegisterType.Flag)
-            {
-                return RegisterConsts.IntAndVecRegsCount + reg.Index;
-            }
-            else /* if (reg.Type == RegisterType.FpFlag) */
-            {
-                return RegisterConsts.FpFlagsOffset + reg.Index;
+                Register reg = operand.GetRegister();
+
+                if (reg.Type == RegisterType.Integer)
+                {
+                    result = reg.Index;
+                }
+                else if (reg.Type == RegisterType.Vector)
+                {
+                    result = RegisterConsts.IntRegsCount + reg.Index;
+                }
+                else if (reg.Type == RegisterType.Flag)
+                {
+                    result = RegisterConsts.IntAndVecRegsCount + reg.Index;
+                }
+                else /* if (reg.Type == RegisterType.FpFlag) */
+                {
+                    result = RegisterConsts.FpFlagsOffset + reg.Index;
+                }
+
+                return true;
             }
+
+            result = -1;
+
+            return false;
         }
 
-        private static Register GetRegisterFromId(int id)
+        private static int GetId(Operand operand)
         {
-            if (id < RegisterConsts.IntRegsCount)
+            if (!TryGetId(operand, out int key))
             {
-                return new Register(id, RegisterType.Integer);
-            }
-            else if (id < RegisterConsts.IntAndVecRegsCount)
-            {
-                return new Register(id - RegisterConsts.IntRegsCount, RegisterType.Vector);
-            }
-            else if (id < RegisterConsts.FpFlagsOffset)
-            {
-                return new Register(id - RegisterConsts.IntAndVecRegsCount, RegisterType.Flag);
-            }
-            else /* if (id < RegisterConsts.TotalCount) */
-            {
-                return new Register(id - RegisterConsts.FpFlagsOffset, RegisterType.FpFlag);
+                Debug.Fail("OperandKind must be Register.");
             }
+
+            return key;
         }
     }
 }
\ No newline at end of file