From 4905101df1b3dcb19682a2f9e83c81afb0627003 Mon Sep 17 00:00:00 2001
From: gdkchan <gab.dark.100@gmail.com>
Date: Wed, 30 Nov 2022 18:24:15 -0300
Subject: [PATCH] Remove shader dependency on SPV_KHR_shader_ballot and
 SPV_KHR_subgroup_vote extensions (#3943)

* Remove shader dependency on SPV_KHR_shader_ballot and SPV_KHR_subgroup_vote extensions

* Shader cache version bump
---
 .../Shader/DiskCache/DiskCacheHostStorage.cs  |  2 +-
 .../CodeGen/Spirv/Instructions.cs             | 19 +++++++++++--------
 .../CodeGen/Spirv/SpirvGenerator.cs           |  7 ++-----
 .../VulkanInitialization.cs                   |  1 -
 4 files changed, 14 insertions(+), 15 deletions(-)

diff --git a/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs b/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs
index 1d4842aadc..617e47652d 100644
--- a/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs
+++ b/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs
@@ -22,7 +22,7 @@ namespace Ryujinx.Graphics.Gpu.Shader.DiskCache
         private const ushort FileFormatVersionMajor = 1;
         private const ushort FileFormatVersionMinor = 2;
         private const uint FileFormatVersionPacked = ((uint)FileFormatVersionMajor << 16) | FileFormatVersionMinor;
-        private const uint CodeGenVersion = 3897;
+        private const uint CodeGenVersion = 3943;
 
         private const string SharedTocFileName = "shared.toc";
         private const string SharedDataFileName = "shared.data";
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs
index ea83061ec2..d4a3102e22 100644
--- a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs
+++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs
@@ -234,7 +234,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
             var source = operation.GetSource(0);
 
             var uvec4Type = context.TypeVector(context.TypeU32(), 4);
-            var execution = context.Constant(context.TypeU32(), 3); // Subgroup
+            var execution = context.Constant(context.TypeU32(), Scope.Subgroup);
 
             var maskVector = context.GroupNonUniformBallot(uvec4Type, execution, context.Get(AggregateType.Bool, source));
             var mask = context.CompositeExtract(context.TypeU32(), maskVector, (SpvLiteralInteger)0);
@@ -1233,7 +1233,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
             var maxThreadId = context.BitwiseOr(context.TypeU32(), minThreadId, clampNotSegMask);
             var srcThreadId = context.BitwiseOr(context.TypeU32(), indexNotSegMask, minThreadId);
             var valid = context.ULessThanEqual(context.TypeBool(), srcThreadId, maxThreadId);
-            var value = context.SubgroupReadInvocationKHR(context.TypeFP32(), x, srcThreadId);
+            var value = context.GroupNonUniformShuffle(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), x, srcThreadId);
             var result = context.Select(context.TypeFP32(), valid, value, x);
 
             var validLocal = (AstOperand)operation.GetSource(3);
@@ -1263,7 +1263,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
             var maxThreadId = context.BitwiseOr(context.TypeU32(), minThreadId, clampNotSegMask);
             var srcThreadId = context.IAdd(context.TypeU32(), threadId, index);
             var valid = context.ULessThanEqual(context.TypeBool(), srcThreadId, maxThreadId);
-            var value = context.SubgroupReadInvocationKHR(context.TypeFP32(), x, srcThreadId);
+            var value = context.GroupNonUniformShuffle(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), x, srcThreadId);
             var result = context.Select(context.TypeFP32(), valid, value, x);
 
             var validLocal = (AstOperand)operation.GetSource(3);
@@ -1289,7 +1289,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
             var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask);
             var srcThreadId = context.ISub(context.TypeU32(), threadId, index);
             var valid = context.SGreaterThanEqual(context.TypeBool(), srcThreadId, minThreadId);
-            var value = context.SubgroupReadInvocationKHR(context.TypeFP32(), x, srcThreadId);
+            var value = context.GroupNonUniformShuffle(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), x, srcThreadId);
             var result = context.Select(context.TypeFP32(), valid, value, x);
 
             var validLocal = (AstOperand)operation.GetSource(3);
@@ -1319,7 +1319,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
             var maxThreadId = context.BitwiseOr(context.TypeU32(), minThreadId, clampNotSegMask);
             var srcThreadId = context.BitwiseXor(context.TypeU32(), threadId, index);
             var valid = context.ULessThanEqual(context.TypeBool(), srcThreadId, maxThreadId);
-            var value = context.SubgroupReadInvocationKHR(context.TypeFP32(), x, srcThreadId);
+            var value = context.GroupNonUniformShuffle(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), x, srcThreadId);
             var result = context.Select(context.TypeFP32(), valid, value, x);
 
             var validLocal = (AstOperand)operation.GetSource(3);
@@ -1861,19 +1861,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
 
         private static OperationResult GenerateVoteAll(CodeGenContext context, AstOperation operation)
         {
-            var result = context.SubgroupAllKHR(context.TypeBool(), context.Get(AggregateType.Bool, operation.GetSource(0)));
+            var execution = context.Constant(context.TypeU32(), Scope.Subgroup);
+            var result = context.GroupNonUniformAll(context.TypeBool(), execution, context.Get(AggregateType.Bool, operation.GetSource(0)));
             return new OperationResult(AggregateType.Bool, result);
         }
 
         private static OperationResult GenerateVoteAllEqual(CodeGenContext context, AstOperation operation)
         {
-            var result = context.SubgroupAllEqualKHR(context.TypeBool(), context.Get(AggregateType.Bool, operation.GetSource(0)));
+            var execution = context.Constant(context.TypeU32(), Scope.Subgroup);
+            var result = context.GroupNonUniformAllEqual(context.TypeBool(), execution, context.Get(AggregateType.Bool, operation.GetSource(0)));
             return new OperationResult(AggregateType.Bool, result);
         }
 
         private static OperationResult GenerateVoteAny(CodeGenContext context, AstOperation operation)
         {
-            var result = context.SubgroupAnyKHR(context.TypeBool(), context.Get(AggregateType.Bool, operation.GetSource(0)));
+            var execution = context.Constant(context.TypeU32(), Scope.Subgroup);
+            var result = context.GroupNonUniformAny(context.TypeBool(), execution, context.Get(AggregateType.Bool, operation.GetSource(0)));
             return new OperationResult(AggregateType.Bool, result);
         }
 
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs
index 95df077bfd..6e1db972d4 100644
--- a/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs
+++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs
@@ -50,12 +50,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
             CodeGenContext context = new CodeGenContext(info, config, instPool, integerPool);
 
             context.AddCapability(Capability.GroupNonUniformBallot);
+            context.AddCapability(Capability.GroupNonUniformShuffle);
+            context.AddCapability(Capability.GroupNonUniformVote);
             context.AddCapability(Capability.ImageBuffer);
             context.AddCapability(Capability.ImageGatherExtended);
             context.AddCapability(Capability.ImageQuery);
             context.AddCapability(Capability.SampledBuffer);
-            context.AddCapability(Capability.SubgroupBallotKHR);
-            context.AddCapability(Capability.SubgroupVoteKHR);
 
             if (config.TransformFeedbackEnabled && config.LastInVertexPipeline)
             {
@@ -94,9 +94,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
                 context.AddCapability(Capability.DrawParameters);
             }
 
-            context.AddExtension("SPV_KHR_shader_ballot");
-            context.AddExtension("SPV_KHR_subgroup_vote");
-
             Declarations.DeclareAll(context, info);
 
             if ((info.HelperFunctionsMask & NeedsInvocationIdMask) != 0)
diff --git a/Ryujinx.Graphics.Vulkan/VulkanInitialization.cs b/Ryujinx.Graphics.Vulkan/VulkanInitialization.cs
index 48f80ad485..190221c78f 100644
--- a/Ryujinx.Graphics.Vulkan/VulkanInitialization.cs
+++ b/Ryujinx.Graphics.Vulkan/VulkanInitialization.cs
@@ -37,7 +37,6 @@ namespace Ryujinx.Graphics.Vulkan
         public static string[] RequiredExtensions { get; } = new string[]
         {
             KhrSwapchain.ExtensionName,
-            "VK_EXT_shader_subgroup_vote",
             ExtTransformFeedback.ExtensionName
         };