From 331c07807fd0db5d4452d6ef02962a6d19a56d7f Mon Sep 17 00:00:00 2001
From: riperiperi <rhy3756547@hotmail.com>
Date: Sat, 20 Jan 2024 14:07:33 +0000
Subject: [PATCH] Vulkan: Use templates for descriptor updates (#6014)

* WIP: Descriptor template update

* Make configurable

* Wording

* Simplify template creation

* Whitespace

* UTF-8 whatever

* Leave only templated path, better template updater
---
 .../DescriptorSetTemplate.cs                  | 145 ++++++++++++++++++
 .../DescriptorSetTemplateUpdater.cs           |  65 ++++++++
 .../DescriptorSetUpdater.cs                   |  28 +++-
 src/Ryujinx.Graphics.Vulkan/PipelineBase.cs   |   2 +-
 .../ShaderCollection.cs                       |  24 +++
 5 files changed, 256 insertions(+), 8 deletions(-)
 create mode 100644 src/Ryujinx.Graphics.Vulkan/DescriptorSetTemplate.cs
 create mode 100644 src/Ryujinx.Graphics.Vulkan/DescriptorSetTemplateUpdater.cs

diff --git a/src/Ryujinx.Graphics.Vulkan/DescriptorSetTemplate.cs b/src/Ryujinx.Graphics.Vulkan/DescriptorSetTemplate.cs
new file mode 100644
index 0000000000..0c0004b95c
--- /dev/null
+++ b/src/Ryujinx.Graphics.Vulkan/DescriptorSetTemplate.cs
@@ -0,0 +1,145 @@
+using Ryujinx.Graphics.GAL;
+using Silk.NET.Vulkan;
+using System;
+using System.Runtime.CompilerServices;
+
+namespace Ryujinx.Graphics.Vulkan
+{
+    class DescriptorSetTemplate : IDisposable
+    {
+        private readonly VulkanRenderer _gd;
+        private readonly Device _device;
+
+        public readonly DescriptorUpdateTemplate Template;
+        public readonly int Size;
+
+        public unsafe DescriptorSetTemplate(VulkanRenderer gd, Device device, ResourceBindingSegment[] segments, PipelineLayoutCacheEntry plce, PipelineBindPoint pbp, int setIndex)
+        {
+            _gd = gd;
+            _device = device;
+
+            // Create a template from the set usages. Assumes the descriptor set is updated in segment order then binding order.
+
+            DescriptorUpdateTemplateEntry* entries = stackalloc DescriptorUpdateTemplateEntry[segments.Length];
+            nuint structureOffset = 0;
+
+            for (int seg = 0; seg < segments.Length; seg++)
+            {
+                ResourceBindingSegment segment = segments[seg];
+
+                int binding = segment.Binding;
+                int count = segment.Count;
+
+                if (setIndex == PipelineBase.UniformSetIndex)
+                {
+                    entries[seg] = new DescriptorUpdateTemplateEntry()
+                    {
+                        DescriptorType = DescriptorType.UniformBuffer,
+                        DstBinding = (uint)binding,
+                        DescriptorCount = (uint)count,
+                        Offset = structureOffset,
+                        Stride = (nuint)Unsafe.SizeOf<DescriptorBufferInfo>()
+                    };
+
+                    structureOffset += (nuint)(Unsafe.SizeOf<DescriptorBufferInfo>() * count);
+                }
+                else if (setIndex == PipelineBase.StorageSetIndex)
+                {
+                    entries[seg] = new DescriptorUpdateTemplateEntry()
+                    {
+                        DescriptorType = DescriptorType.StorageBuffer,
+                        DstBinding = (uint)binding,
+                        DescriptorCount = (uint)count,
+                        Offset = structureOffset,
+                        Stride = (nuint)Unsafe.SizeOf<DescriptorBufferInfo>()
+                    };
+
+                    structureOffset += (nuint)(Unsafe.SizeOf<DescriptorBufferInfo>() * count);
+                }
+                else if (setIndex == PipelineBase.TextureSetIndex)
+                {
+                    if (segment.Type != ResourceType.BufferTexture)
+                    {
+                        entries[seg] = new DescriptorUpdateTemplateEntry()
+                        {
+                            DescriptorType = DescriptorType.CombinedImageSampler,
+                            DstBinding = (uint)binding,
+                            DescriptorCount = (uint)count,
+                            Offset = structureOffset,
+                            Stride = (nuint)Unsafe.SizeOf<DescriptorImageInfo>()
+                        };
+
+                        structureOffset += (nuint)(Unsafe.SizeOf<DescriptorImageInfo>() * count);
+                    }
+                    else
+                    {
+                        entries[seg] = new DescriptorUpdateTemplateEntry()
+                        {
+                            DescriptorType = DescriptorType.UniformTexelBuffer,
+                            DstBinding = (uint)binding,
+                            DescriptorCount = (uint)count,
+                            Offset = structureOffset,
+                            Stride = (nuint)Unsafe.SizeOf<BufferView>()
+                        };
+
+                        structureOffset += (nuint)(Unsafe.SizeOf<BufferView>() * count);
+                    }
+                }
+                else if (setIndex == PipelineBase.ImageSetIndex)
+                {
+                    if (segment.Type != ResourceType.BufferImage)
+                    {
+                        entries[seg] = new DescriptorUpdateTemplateEntry()
+                        {
+                            DescriptorType = DescriptorType.StorageImage,
+                            DstBinding = (uint)binding,
+                            DescriptorCount = (uint)count,
+                            Offset = structureOffset,
+                            Stride = (nuint)Unsafe.SizeOf<DescriptorImageInfo>()
+                        };
+
+                        structureOffset += (nuint)(Unsafe.SizeOf<DescriptorImageInfo>() * count);
+                    }
+                    else
+                    {
+                        entries[seg] = new DescriptorUpdateTemplateEntry()
+                        {
+                            DescriptorType = DescriptorType.StorageTexelBuffer,
+                            DstBinding = (uint)binding,
+                            DescriptorCount = (uint)count,
+                            Offset = structureOffset,
+                            Stride = (nuint)Unsafe.SizeOf<BufferView>()
+                        };
+
+                        structureOffset += (nuint)(Unsafe.SizeOf<BufferView>() * count);
+                    }
+                }
+            }
+
+            Size = (int)structureOffset;
+
+            var info = new DescriptorUpdateTemplateCreateInfo()
+            {
+                SType = StructureType.DescriptorUpdateTemplateCreateInfo,
+                DescriptorUpdateEntryCount = (uint)segments.Length,
+                PDescriptorUpdateEntries = entries,
+
+                TemplateType = DescriptorUpdateTemplateType.DescriptorSet,
+                DescriptorSetLayout = plce.DescriptorSetLayouts[setIndex],
+                PipelineBindPoint = pbp,
+                PipelineLayout = plce.PipelineLayout,
+                Set = (uint)setIndex,
+            };
+
+            DescriptorUpdateTemplate result;
+            gd.Api.CreateDescriptorUpdateTemplate(device, &info, null, &result).ThrowOnError();
+
+            Template = result;
+        }
+
+        public unsafe void Dispose()
+        {
+            _gd.Api.DestroyDescriptorUpdateTemplate(_device, Template, null);
+        }
+    }
+}
diff --git a/src/Ryujinx.Graphics.Vulkan/DescriptorSetTemplateUpdater.cs b/src/Ryujinx.Graphics.Vulkan/DescriptorSetTemplateUpdater.cs
new file mode 100644
index 0000000000..1eb9dce75c
--- /dev/null
+++ b/src/Ryujinx.Graphics.Vulkan/DescriptorSetTemplateUpdater.cs
@@ -0,0 +1,65 @@
+using Ryujinx.Common;
+using Silk.NET.Vulkan;
+using System;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+
+namespace Ryujinx.Graphics.Vulkan
+{
+    ref struct DescriptorSetTemplateWriter
+    {
+        private Span<byte> _data;
+
+        public DescriptorSetTemplateWriter(Span<byte> data)
+        {
+            _data = data;
+        }
+
+        public void Push<T>(ReadOnlySpan<T> values) where T : unmanaged
+        {
+            Span<T> target = MemoryMarshal.Cast<byte, T>(_data);
+
+            values.CopyTo(target);
+
+            _data = _data[(Unsafe.SizeOf<T>() * values.Length)..];
+        }
+    }
+
+    unsafe class DescriptorSetTemplateUpdater : IDisposable
+    {
+        private const int SizeGranularity = 512;
+
+        private DescriptorSetTemplate _activeTemplate;
+        private NativeArray<byte> _data;
+
+        private void EnsureSize(int size)
+        {
+            if (_data == null || _data.Length < size)
+            {
+                _data?.Dispose();
+
+                int dataSize = BitUtils.AlignUp(size, SizeGranularity);
+                _data = new NativeArray<byte>(dataSize);
+            }
+        }
+
+        public DescriptorSetTemplateWriter Begin(DescriptorSetTemplate template)
+        {
+            _activeTemplate = template;
+
+            EnsureSize(template.Size);
+
+            return new DescriptorSetTemplateWriter(new Span<byte>(_data.Pointer, template.Size));
+        }
+
+        public void Commit(VulkanRenderer gd, Device device, DescriptorSet set)
+        {
+            gd.Api.UpdateDescriptorSetWithTemplate(device, set, _activeTemplate.Template, _data.Pointer);
+        }
+
+        public void Dispose()
+        {
+            _data?.Dispose();
+        }
+    }
+}
diff --git a/src/Ryujinx.Graphics.Vulkan/DescriptorSetUpdater.cs b/src/Ryujinx.Graphics.Vulkan/DescriptorSetUpdater.cs
index e0fe5d89a4..6615d8ce0e 100644
--- a/src/Ryujinx.Graphics.Vulkan/DescriptorSetUpdater.cs
+++ b/src/Ryujinx.Graphics.Vulkan/DescriptorSetUpdater.cs
@@ -35,6 +35,7 @@ namespace Ryujinx.Graphics.Vulkan
         }
 
         private readonly VulkanRenderer _gd;
+        private readonly Device _device;
         private readonly PipelineBase _pipeline;
         private ShaderCollection _program;
 
@@ -54,6 +55,8 @@ namespace Ryujinx.Graphics.Vulkan
         private readonly BufferView[] _bufferTextures;
         private readonly BufferView[] _bufferImages;
 
+        private readonly DescriptorSetTemplateUpdater _templateUpdater;
+
         private BitMapStruct<Array2<long>> _uniformSet;
         private BitMapStruct<Array2<long>> _storageSet;
         private BitMapStruct<Array2<long>> _uniformMirrored;
@@ -78,9 +81,10 @@ namespace Ryujinx.Graphics.Vulkan
         private readonly TextureView _dummyTexture;
         private readonly SamplerHolder _dummySampler;
 
-        public DescriptorSetUpdater(VulkanRenderer gd, PipelineBase pipeline)
+        public DescriptorSetUpdater(VulkanRenderer gd, Device device, PipelineBase pipeline)
         {
             _gd = gd;
+            _device = device;
             _pipeline = pipeline;
 
             // Some of the bindings counts needs to be multiplied by 2 because we have buffer and
@@ -152,6 +156,8 @@ namespace Ryujinx.Graphics.Vulkan
                 0,
                 0,
                 1f));
+
+            _templateUpdater = new();
         }
 
         public void Initialize()
@@ -509,6 +515,10 @@ namespace Ryujinx.Graphics.Vulkan
                 }
             }
 
+            DescriptorSetTemplate template = program.Templates[setIndex];
+
+            DescriptorSetTemplateWriter tu = _templateUpdater.Begin(template);
+
             foreach (ResourceBindingSegment segment in bindingSegments)
             {
                 int binding = segment.Binding;
@@ -531,7 +541,8 @@ namespace Ryujinx.Graphics.Vulkan
                     }
 
                     ReadOnlySpan<DescriptorBufferInfo> uniformBuffers = _uniformBuffers;
-                    dsc.UpdateBuffers(0, binding, uniformBuffers.Slice(binding, count), DescriptorType.UniformBuffer);
+
+                    tu.Push(uniformBuffers.Slice(binding, count));
                 }
                 else if (setIndex == PipelineBase.StorageSetIndex)
                 {
@@ -556,7 +567,8 @@ namespace Ryujinx.Graphics.Vulkan
                     }
 
                     ReadOnlySpan<DescriptorBufferInfo> storageBuffers = _storageBuffers;
-                    dsc.UpdateBuffers(0, binding, storageBuffers.Slice(binding, count), DescriptorType.StorageBuffer);
+
+                    tu.Push(storageBuffers.Slice(binding, count));
                 }
                 else if (setIndex == PipelineBase.TextureSetIndex)
                 {
@@ -582,7 +594,7 @@ namespace Ryujinx.Graphics.Vulkan
                             }
                         }
 
-                        dsc.UpdateImages(0, binding, textures[..count], DescriptorType.CombinedImageSampler);
+                        tu.Push<DescriptorImageInfo>(textures[..count]);
                     }
                     else
                     {
@@ -593,7 +605,7 @@ namespace Ryujinx.Graphics.Vulkan
                             bufferTextures[i] = _bufferTextureRefs[binding + i]?.GetBufferView(cbs, false) ?? default;
                         }
 
-                        dsc.UpdateBufferImages(0, binding, bufferTextures[..count], DescriptorType.UniformTexelBuffer);
+                        tu.Push<BufferView>(bufferTextures[..count]);
                     }
                 }
                 else if (setIndex == PipelineBase.ImageSetIndex)
@@ -607,7 +619,7 @@ namespace Ryujinx.Graphics.Vulkan
                             images[i].ImageView = _imageRefs[binding + i]?.Get(cbs).Value ?? default;
                         }
 
-                        dsc.UpdateImages(0, binding, images[..count], DescriptorType.StorageImage);
+                        tu.Push<DescriptorImageInfo>(images[..count]);
                     }
                     else
                     {
@@ -618,12 +630,13 @@ namespace Ryujinx.Graphics.Vulkan
                             bufferImages[i] = _bufferImageRefs[binding + i]?.GetBufferView(cbs, _bufferImageFormats[binding + i], true) ?? default;
                         }
 
-                        dsc.UpdateBufferImages(0, binding, bufferImages[..count], DescriptorType.StorageTexelBuffer);
+                        tu.Push<BufferView>(bufferImages[..count]);
                     }
                 }
             }
 
             var sets = dsc.GetSets();
+            _templateUpdater.Commit(_gd, _device, sets[0]);
 
             _gd.Api.CmdBindDescriptorSets(cbs.CommandBuffer, pbp, _program.PipelineLayout, (uint)setIndex, 1, sets, 0, ReadOnlySpan<uint>.Empty);
         }
@@ -736,6 +749,7 @@ namespace Ryujinx.Graphics.Vulkan
             {
                 _dummyTexture.Dispose();
                 _dummySampler.Dispose();
+                _templateUpdater.Dispose();
             }
         }
 
diff --git a/src/Ryujinx.Graphics.Vulkan/PipelineBase.cs b/src/Ryujinx.Graphics.Vulkan/PipelineBase.cs
index b05dd1a699..af3a27e556 100644
--- a/src/Ryujinx.Graphics.Vulkan/PipelineBase.cs
+++ b/src/Ryujinx.Graphics.Vulkan/PipelineBase.cs
@@ -102,7 +102,7 @@ namespace Ryujinx.Graphics.Vulkan
 
             gd.Api.CreatePipelineCache(device, pipelineCacheCreateInfo, null, out PipelineCache).ThrowOnError();
 
-            _descriptorSetUpdater = new DescriptorSetUpdater(gd, this);
+            _descriptorSetUpdater = new DescriptorSetUpdater(gd, device, this);
             _vertexBufferUpdater = new VertexBufferUpdater(gd);
 
             _transformFeedbackBuffers = new BufferState[Constants.MaxTransformFeedbackBuffers];
diff --git a/src/Ryujinx.Graphics.Vulkan/ShaderCollection.cs b/src/Ryujinx.Graphics.Vulkan/ShaderCollection.cs
index d01eebf3af..0d6da0391d 100644
--- a/src/Ryujinx.Graphics.Vulkan/ShaderCollection.cs
+++ b/src/Ryujinx.Graphics.Vulkan/ShaderCollection.cs
@@ -26,6 +26,7 @@ namespace Ryujinx.Graphics.Vulkan
 
         public ResourceBindingSegment[][] ClearSegments { get; }
         public ResourceBindingSegment[][] BindingSegments { get; }
+        public DescriptorSetTemplate[] Templates { get; }
 
         public ProgramLinkStatus LinkStatus { get; private set; }
 
@@ -118,6 +119,7 @@ namespace Ryujinx.Graphics.Vulkan
 
             ClearSegments = BuildClearSegments(resourceLayout.Sets);
             BindingSegments = BuildBindingSegments(resourceLayout.SetUsages);
+            Templates = BuildTemplates();
 
             _compileTask = Task.CompletedTask;
             _firstBackgroundUse = false;
@@ -241,6 +243,23 @@ namespace Ryujinx.Graphics.Vulkan
             return segments;
         }
 
+        private DescriptorSetTemplate[] BuildTemplates()
+        {
+            var templates = new DescriptorSetTemplate[BindingSegments.Length];
+
+            for (int setIndex = 0; setIndex < BindingSegments.Length; setIndex++)
+            {
+                ResourceBindingSegment[] segments = BindingSegments[setIndex];
+
+                if (segments != null && segments.Length > 0)
+                {
+                    templates[setIndex] = new DescriptorSetTemplate(_gd, _device, segments, _plce, IsCompute ? PipelineBindPoint.Compute : PipelineBindPoint.Graphics, setIndex);
+                }
+            }
+
+            return templates;
+        }
+
         private async Task BackgroundCompilation()
         {
             await Task.WhenAll(_shaders.Select(shader => shader.CompileTask));
@@ -504,6 +523,11 @@ namespace Ryujinx.Graphics.Vulkan
                     }
                 }
 
+                for (int i = 0; i < Templates.Length; i++)
+                {
+                    Templates[i]?.Dispose();
+                }
+
                 if (_dummyRenderPass.Value.Handle != 0)
                 {
                     _dummyRenderPass.Dispose();