From b8d992e5a770931382fd39108601b0abe75149cc Mon Sep 17 00:00:00 2001
From: gdkchan <gab.dark.100@gmail.com>
Date: Fri, 26 Jan 2024 13:58:57 -0300
Subject: [PATCH] Allow skipping draws with broken pipeline variants on Vulkan
 (#5807)

* Allow skipping draws with broken pipeline variants on Vulkan

* Move IsLinked check to CreatePipeline

* Restore throw on error behaviour for background compile

* Can't remove SetAlphaTest pragmas yet

* Double new line
---
 src/Ryujinx.Graphics.Vulkan/PipelineBase.cs   | 122 ++++++++++++------
 src/Ryujinx.Graphics.Vulkan/PipelineFull.cs   |   5 +-
 src/Ryujinx.Graphics.Vulkan/PipelineState.cs  |  28 ++--
 .../ShaderCollection.cs                       |   4 +-
 .../VulkanException.cs                        |   8 +-
 5 files changed, 110 insertions(+), 57 deletions(-)

diff --git a/src/Ryujinx.Graphics.Vulkan/PipelineBase.cs b/src/Ryujinx.Graphics.Vulkan/PipelineBase.cs
index af3a27e556..61215b672e 100644
--- a/src/Ryujinx.Graphics.Vulkan/PipelineBase.cs
+++ b/src/Ryujinx.Graphics.Vulkan/PipelineBase.cs
@@ -34,7 +34,8 @@ namespace Ryujinx.Graphics.Vulkan
 
         protected PipelineDynamicState DynamicState;
         private PipelineState _newState;
-        private bool _stateDirty;
+        private bool _graphicsStateDirty;
+        private bool _computeStateDirty;
         private PrimitiveTopology _topology;
 
         private ulong _currentPipelineHandle;
@@ -353,7 +354,7 @@ namespace Ryujinx.Graphics.Vulkan
             }
 
             EndRenderPass();
-            RecreatePipelineIfNeeded(PipelineBindPoint.Compute);
+            RecreateComputePipelineIfNeeded();
 
             Gd.Api.CmdDispatch(CommandBuffer, (uint)groupsX, (uint)groupsY, (uint)groupsZ);
         }
@@ -366,19 +367,23 @@ namespace Ryujinx.Graphics.Vulkan
             }
 
             EndRenderPass();
-            RecreatePipelineIfNeeded(PipelineBindPoint.Compute);
+            RecreateComputePipelineIfNeeded();
 
             Gd.Api.CmdDispatchIndirect(CommandBuffer, indirectBuffer.Get(Cbs, indirectBufferOffset, 12).Value, (ulong)indirectBufferOffset);
         }
 
         public void Draw(int vertexCount, int instanceCount, int firstVertex, int firstInstance)
         {
-            if (!_program.IsLinked || vertexCount == 0)
+            if (vertexCount == 0)
+            {
+                return;
+            }
+
+            if (!RecreateGraphicsPipelineIfNeeded())
             {
                 return;
             }
 
-            RecreatePipelineIfNeeded(PipelineBindPoint.Graphics);
             BeginRenderPass();
             DrawCount++;
 
@@ -437,13 +442,18 @@ namespace Ryujinx.Graphics.Vulkan
 
         public void DrawIndexed(int indexCount, int instanceCount, int firstIndex, int firstVertex, int firstInstance)
         {
-            if (!_program.IsLinked || indexCount == 0)
+            if (indexCount == 0)
             {
                 return;
             }
 
             UpdateIndexBufferPattern();
-            RecreatePipelineIfNeeded(PipelineBindPoint.Graphics);
+
+            if (!RecreateGraphicsPipelineIfNeeded())
+            {
+                return;
+            }
+
             BeginRenderPass();
             DrawCount++;
 
@@ -476,17 +486,17 @@ namespace Ryujinx.Graphics.Vulkan
 
         public void DrawIndexedIndirect(BufferRange indirectBuffer)
         {
-            if (!_program.IsLinked)
-            {
-                return;
-            }
-
             var buffer = Gd.BufferManager
                 .GetBuffer(CommandBuffer, indirectBuffer.Handle, indirectBuffer.Offset, indirectBuffer.Size, false)
                 .Get(Cbs, indirectBuffer.Offset, indirectBuffer.Size).Value;
 
             UpdateIndexBufferPattern();
-            RecreatePipelineIfNeeded(PipelineBindPoint.Graphics);
+
+            if (!RecreateGraphicsPipelineIfNeeded())
+            {
+                return;
+            }
+
             BeginRenderPass();
             DrawCount++;
 
@@ -522,11 +532,6 @@ namespace Ryujinx.Graphics.Vulkan
 
         public void DrawIndexedIndirectCount(BufferRange indirectBuffer, BufferRange parameterBuffer, int maxDrawCount, int stride)
         {
-            if (!_program.IsLinked)
-            {
-                return;
-            }
-
             var countBuffer = Gd.BufferManager
                 .GetBuffer(CommandBuffer, parameterBuffer.Handle, parameterBuffer.Offset, parameterBuffer.Size, false)
                 .Get(Cbs, parameterBuffer.Offset, parameterBuffer.Size).Value;
@@ -536,7 +541,12 @@ namespace Ryujinx.Graphics.Vulkan
                 .Get(Cbs, indirectBuffer.Offset, indirectBuffer.Size).Value;
 
             UpdateIndexBufferPattern();
-            RecreatePipelineIfNeeded(PipelineBindPoint.Graphics);
+
+            if (!RecreateGraphicsPipelineIfNeeded())
+            {
+                return;
+            }
+
             BeginRenderPass();
             DrawCount++;
 
@@ -614,18 +624,17 @@ namespace Ryujinx.Graphics.Vulkan
 
         public void DrawIndirect(BufferRange indirectBuffer)
         {
-            if (!_program.IsLinked)
-            {
-                return;
-            }
-
             // TODO: Support quads and other unsupported topologies.
 
             var buffer = Gd.BufferManager
                 .GetBuffer(CommandBuffer, indirectBuffer.Handle, indirectBuffer.Offset, indirectBuffer.Size, false)
                 .Get(Cbs, indirectBuffer.Offset, indirectBuffer.Size, false).Value;
 
-            RecreatePipelineIfNeeded(PipelineBindPoint.Graphics);
+            if (!RecreateGraphicsPipelineIfNeeded())
+            {
+                return;
+            }
+
             BeginRenderPass();
             ResumeTransformFeedbackInternal();
             DrawCount++;
@@ -641,11 +650,6 @@ namespace Ryujinx.Graphics.Vulkan
                 throw new NotSupportedException();
             }
 
-            if (!_program.IsLinked)
-            {
-                return;
-            }
-
             var buffer = Gd.BufferManager
                 .GetBuffer(CommandBuffer, indirectBuffer.Handle, indirectBuffer.Offset, indirectBuffer.Size, false)
                 .Get(Cbs, indirectBuffer.Offset, indirectBuffer.Size, false).Value;
@@ -656,7 +660,11 @@ namespace Ryujinx.Graphics.Vulkan
 
             // TODO: Support quads and other unsupported topologies.
 
-            RecreatePipelineIfNeeded(PipelineBindPoint.Graphics);
+            if (!RecreateGraphicsPipelineIfNeeded())
+            {
+                return;
+            }
+
             BeginRenderPass();
             ResumeTransformFeedbackInternal();
             DrawCount++;
@@ -1576,10 +1584,23 @@ namespace Ryujinx.Graphics.Vulkan
 
         protected void SignalStateChange()
         {
-            _stateDirty = true;
+            _graphicsStateDirty = true;
+            _computeStateDirty = true;
         }
 
-        private void RecreatePipelineIfNeeded(PipelineBindPoint pbp)
+        private void RecreateComputePipelineIfNeeded()
+        {
+            if (_computeStateDirty || Pbp != PipelineBindPoint.Compute)
+            {
+                CreatePipeline(PipelineBindPoint.Compute);
+                _computeStateDirty = false;
+                Pbp = PipelineBindPoint.Compute;
+            }
+
+            _descriptorSetUpdater.UpdateAndBindDescriptorSets(Cbs, PipelineBindPoint.Compute);
+        }
+
+        private bool RecreateGraphicsPipelineIfNeeded()
         {
             if (AutoFlush.ShouldFlushDraw(DrawCount))
             {
@@ -1620,17 +1641,23 @@ namespace Ryujinx.Graphics.Vulkan
                 _vertexBufferUpdater.Commit(Cbs);
             }
 
-            if (_stateDirty || Pbp != pbp)
+            if (_graphicsStateDirty || Pbp != PipelineBindPoint.Graphics)
             {
-                CreatePipeline(pbp);
-                _stateDirty = false;
-                Pbp = pbp;
+                if (!CreatePipeline(PipelineBindPoint.Graphics))
+                {
+                    return false;
+                }
+
+                _graphicsStateDirty = false;
+                Pbp = PipelineBindPoint.Graphics;
             }
 
-            _descriptorSetUpdater.UpdateAndBindDescriptorSets(Cbs, pbp);
+            _descriptorSetUpdater.UpdateAndBindDescriptorSets(Cbs, PipelineBindPoint.Graphics);
+
+            return true;
         }
 
-        private void CreatePipeline(PipelineBindPoint pbp)
+        private bool CreatePipeline(PipelineBindPoint pbp)
         {
             // We can only create a pipeline if the have the shader stages set.
             if (_newState.Stages != null)
@@ -1640,10 +1667,25 @@ namespace Ryujinx.Graphics.Vulkan
                     CreateRenderPass();
                 }
 
+                if (!_program.IsLinked)
+                {
+                    // Background compile failed, we likely can't create the pipeline because the shader is broken
+                    // or the driver failed to compile it.
+
+                    return false;
+                }
+
                 var pipeline = pbp == PipelineBindPoint.Compute
                     ? _newState.CreateComputePipeline(Gd, Device, _program, PipelineCache)
                     : _newState.CreateGraphicsPipeline(Gd, Device, _program, PipelineCache, _renderPass.Get(Cbs).Value);
 
+                if (pipeline == null)
+                {
+                    // Host failed to create the pipeline, likely due to driver bugs.
+
+                    return false;
+                }
+
                 ulong pipelineHandle = pipeline.GetUnsafe().Value.Handle;
 
                 if (_currentPipelineHandle != pipelineHandle)
@@ -1655,6 +1697,8 @@ namespace Ryujinx.Graphics.Vulkan
                     Gd.Api.CmdBindPipeline(CommandBuffer, pbp, Pipeline.Get(Cbs).Value);
                 }
             }
+
+            return true;
         }
 
         private unsafe void BeginRenderPass()
diff --git a/src/Ryujinx.Graphics.Vulkan/PipelineFull.cs b/src/Ryujinx.Graphics.Vulkan/PipelineFull.cs
index 24ca715fe9..a3e6818f3f 100644
--- a/src/Ryujinx.Graphics.Vulkan/PipelineFull.cs
+++ b/src/Ryujinx.Graphics.Vulkan/PipelineFull.cs
@@ -246,7 +246,10 @@ namespace Ryujinx.Graphics.Vulkan
 
             SignalCommandBufferChange();
 
-            DynamicState.ReplayIfDirty(Gd.Api, CommandBuffer);
+            if (Pipeline != null && Pbp == PipelineBindPoint.Graphics)
+            {
+                DynamicState.ReplayIfDirty(Gd.Api, CommandBuffer);
+            }
         }
 
         public void FlushCommandsImpl()
diff --git a/src/Ryujinx.Graphics.Vulkan/PipelineState.cs b/src/Ryujinx.Graphics.Vulkan/PipelineState.cs
index 11f5325108..25fd7168fb 100644
--- a/src/Ryujinx.Graphics.Vulkan/PipelineState.cs
+++ b/src/Ryujinx.Graphics.Vulkan/PipelineState.cs
@@ -312,7 +312,6 @@ namespace Ryujinx.Graphics.Vulkan
         }
 
         public NativeArray<PipelineShaderStageCreateInfo> Stages;
-        public NativeArray<PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT> StageRequiredSubgroupSizes;
         public PipelineLayout PipelineLayout;
         public SpecData SpecializationData;
 
@@ -321,16 +320,6 @@ namespace Ryujinx.Graphics.Vulkan
         public void Initialize()
         {
             Stages = new NativeArray<PipelineShaderStageCreateInfo>(Constants.MaxShaderStages);
-            StageRequiredSubgroupSizes = new NativeArray<PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT>(Constants.MaxShaderStages);
-
-            for (int index = 0; index < Constants.MaxShaderStages; index++)
-            {
-                StageRequiredSubgroupSizes[index] = new PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT
-                {
-                    SType = StructureType.PipelineShaderStageRequiredSubgroupSizeCreateInfoExt,
-                    RequiredSubgroupSize = RequiredSubgroupSize,
-                };
-            }
 
             AdvancedBlendSrcPreMultiplied = true;
             AdvancedBlendDstPreMultiplied = true;
@@ -397,7 +386,8 @@ namespace Ryujinx.Graphics.Vulkan
             Device device,
             ShaderCollection program,
             PipelineCache cache,
-            RenderPass renderPass)
+            RenderPass renderPass,
+            bool throwOnError = false)
         {
             if (program.TryGetGraphicsPipeline(ref Internal, out var pipeline))
             {
@@ -630,7 +620,18 @@ namespace Ryujinx.Graphics.Vulkan
                     BasePipelineIndex = -1,
                 };
 
-                gd.Api.CreateGraphicsPipelines(device, cache, 1, &pipelineCreateInfo, null, &pipelineHandle).ThrowOnError();
+                Result result = gd.Api.CreateGraphicsPipelines(device, cache, 1, &pipelineCreateInfo, null, &pipelineHandle);
+
+                if (throwOnError)
+                {
+                    result.ThrowOnError();
+                }
+                else if (result.IsError())
+                {
+                    program.AddGraphicsPipeline(ref Internal, null);
+
+                    return null;
+                }
 
                 // Restore previous blend enable values if we changed it.
                 while (blendEnables != 0)
@@ -708,7 +709,6 @@ namespace Ryujinx.Graphics.Vulkan
         public readonly void Dispose()
         {
             Stages.Dispose();
-            StageRequiredSubgroupSizes.Dispose();
         }
     }
 }
diff --git a/src/Ryujinx.Graphics.Vulkan/ShaderCollection.cs b/src/Ryujinx.Graphics.Vulkan/ShaderCollection.cs
index 0d6da0391d..7c25c6d14e 100644
--- a/src/Ryujinx.Graphics.Vulkan/ShaderCollection.cs
+++ b/src/Ryujinx.Graphics.Vulkan/ShaderCollection.cs
@@ -374,7 +374,7 @@ namespace Ryujinx.Graphics.Vulkan
             pipeline.StagesCount = (uint)_shaders.Length;
             pipeline.PipelineLayout = PipelineLayout;
 
-            pipeline.CreateGraphicsPipeline(_gd, _device, this, (_gd.Pipeline as PipelineBase).PipelineCache, renderPass.Value);
+            pipeline.CreateGraphicsPipeline(_gd, _device, this, (_gd.Pipeline as PipelineBase).PipelineCache, renderPass.Value, throwOnError: true);
             pipeline.Dispose();
         }
 
@@ -511,7 +511,7 @@ namespace Ryujinx.Graphics.Vulkan
                 {
                     foreach (Auto<DisposablePipeline> pipeline in _graphicsPipelineCache.Values)
                     {
-                        pipeline.Dispose();
+                        pipeline?.Dispose();
                     }
                 }
 
diff --git a/src/Ryujinx.Graphics.Vulkan/VulkanException.cs b/src/Ryujinx.Graphics.Vulkan/VulkanException.cs
index 0d4036802a..e203a3a216 100644
--- a/src/Ryujinx.Graphics.Vulkan/VulkanException.cs
+++ b/src/Ryujinx.Graphics.Vulkan/VulkanException.cs
@@ -6,10 +6,16 @@ namespace Ryujinx.Graphics.Vulkan
 {
     static class ResultExtensions
     {
+        public static bool IsError(this Result result)
+        {
+            // Only negative result codes are errors.
+            return result < Result.Success;
+        }
+
         public static void ThrowOnError(this Result result)
         {
             // Only negative result codes are errors.
-            if ((int)result < (int)Result.Success)
+            if (result.IsError())
             {
                 throw new VulkanException(result);
             }