From e5ad1dfa48590685fd93443a2adfd8568f6c1db0 Mon Sep 17 00:00:00 2001
From: gdkchan <gab.dark.100@gmail.com>
Date: Mon, 14 Mar 2022 23:42:08 -0300
Subject: [PATCH] Implement S8D24 texture format and tweak depth range
 detection (#2458)

---
 Ryujinx.Graphics.GAL/Format.cs                |   4 +-
 .../Engine/Types/ZetaFormat.cs                |  14 +-
 Ryujinx.Graphics.Gpu/Image/FormatTable.cs     |   1 +
 .../Image/TextureCompatibility.cs             |   2 +-
 Ryujinx.Graphics.Gpu/Image/TexturePool.cs     |   2 +-
 Ryujinx.Graphics.OpenGL/FormatTable.cs        |  18 +--
 Ryujinx.Graphics.OpenGL/Framebuffer.cs        |   7 +-
 .../Image/FormatConverter.cs                  | 149 ++++++++++++++++++
 Ryujinx.Graphics.OpenGL/Image/TextureCopy.cs  |   6 +-
 Ryujinx.Graphics.OpenGL/Image/TextureView.cs  |  23 ++-
 10 files changed, 196 insertions(+), 30 deletions(-)
 create mode 100644 Ryujinx.Graphics.OpenGL/Image/FormatConverter.cs

diff --git a/Ryujinx.Graphics.GAL/Format.cs b/Ryujinx.Graphics.GAL/Format.cs
index a454413bf7..50cc6d40c0 100644
--- a/Ryujinx.Graphics.GAL/Format.cs
+++ b/Ryujinx.Graphics.GAL/Format.cs
@@ -52,7 +52,7 @@ namespace Ryujinx.Graphics.GAL
         R32G32B32A32Sint,
         S8Uint,
         D16Unorm,
-        D24X8Unorm,
+        S8UintD24Unorm,
         D32Float,
         D24UnormS8Uint,
         D32FloatS8Uint,
@@ -266,7 +266,7 @@ namespace Ryujinx.Graphics.GAL
             {
                 case Format.D16Unorm:
                 case Format.D24UnormS8Uint:
-                case Format.D24X8Unorm:
+                case Format.S8UintD24Unorm:
                 case Format.D32Float:
                 case Format.D32FloatS8Uint:
                 case Format.S8Uint:
diff --git a/Ryujinx.Graphics.Gpu/Engine/Types/ZetaFormat.cs b/Ryujinx.Graphics.Gpu/Engine/Types/ZetaFormat.cs
index 2de38fd20d..1de1621fc7 100644
--- a/Ryujinx.Graphics.Gpu/Engine/Types/ZetaFormat.cs
+++ b/Ryujinx.Graphics.Gpu/Engine/Types/ZetaFormat.cs
@@ -28,13 +28,13 @@ namespace Ryujinx.Graphics.Gpu.Engine.Types
         {
             return format switch
             {
-                ZetaFormat.D32Float          => new FormatInfo(Format.D32Float,          1, 1, 4,  1),
-                ZetaFormat.D16Unorm          => new FormatInfo(Format.D16Unorm,          1, 1, 2,  1),
-                ZetaFormat.D24UnormS8Uint    => new FormatInfo(Format.D24UnormS8Uint,    1, 1, 4,  2),
-                ZetaFormat.D24Unorm          => new FormatInfo(Format.D24UnormS8Uint,    1, 1, 4,  1),
-                ZetaFormat.S8UintD24Unorm    => new FormatInfo(Format.D24UnormS8Uint,    1, 1, 4,  2),
-                ZetaFormat.S8Uint            => new FormatInfo(Format.S8Uint,            1, 1, 1,  1),
-                ZetaFormat.D32FloatS8Uint    => new FormatInfo(Format.D32FloatS8Uint,    1, 1, 8,  2),
+                ZetaFormat.D32Float          => new FormatInfo(Format.D32Float,       1, 1, 4, 1),
+                ZetaFormat.D16Unorm          => new FormatInfo(Format.D16Unorm,       1, 1, 2, 1),
+                ZetaFormat.D24UnormS8Uint    => new FormatInfo(Format.D24UnormS8Uint, 1, 1, 4, 2),
+                ZetaFormat.D24Unorm          => new FormatInfo(Format.D24UnormS8Uint, 1, 1, 4, 1),
+                ZetaFormat.S8UintD24Unorm    => new FormatInfo(Format.S8UintD24Unorm, 1, 1, 4, 2),
+                ZetaFormat.S8Uint            => new FormatInfo(Format.S8Uint,         1, 1, 1, 1),
+                ZetaFormat.D32FloatS8Uint    => new FormatInfo(Format.D32FloatS8Uint, 1, 1, 8, 2),
                 _                            => FormatInfo.Default
             };
         }
diff --git a/Ryujinx.Graphics.Gpu/Image/FormatTable.cs b/Ryujinx.Graphics.Gpu/Image/FormatTable.cs
index 3c97e2e274..e5a4badc9d 100644
--- a/Ryujinx.Graphics.Gpu/Image/FormatTable.cs
+++ b/Ryujinx.Graphics.Gpu/Image/FormatTable.cs
@@ -55,6 +55,7 @@ namespace Ryujinx.Graphics.Gpu.Image
             { 0x24a0e, new FormatInfo(Format.D24UnormS8Uint,    1,  1,  4,  2) },
             { 0x24a29, new FormatInfo(Format.D24UnormS8Uint,    1,  1,  4,  2) },
             { 0x48a29, new FormatInfo(Format.D24UnormS8Uint,    1,  1,  4,  2) },
+            { 0x4912b, new FormatInfo(Format.S8UintD24Unorm,    1,  1,  4,  2) },
             { 0x25385, new FormatInfo(Format.D32FloatS8Uint,    1,  1,  8,  2) },
             { 0x253b0, new FormatInfo(Format.D32FloatS8Uint,    1,  1,  8,  2) },
             { 0xa4908, new FormatInfo(Format.R8G8B8A8Srgb,      1,  1,  4,  4) },
diff --git a/Ryujinx.Graphics.Gpu/Image/TextureCompatibility.cs b/Ryujinx.Graphics.Gpu/Image/TextureCompatibility.cs
index 188e1e090a..b798441f13 100644
--- a/Ryujinx.Graphics.Gpu/Image/TextureCompatibility.cs
+++ b/Ryujinx.Graphics.Gpu/Image/TextureCompatibility.cs
@@ -203,7 +203,7 @@ namespace Ryujinx.Graphics.Gpu.Image
                 }
 
                 if ((lhs.FormatInfo.Format == Format.D24UnormS8Uint ||
-                     lhs.FormatInfo.Format == Format.D24X8Unorm) && rhs.FormatInfo.Format == Format.B8G8R8A8Unorm)
+                     lhs.FormatInfo.Format == Format.S8UintD24Unorm) && rhs.FormatInfo.Format == Format.B8G8R8A8Unorm)
                 {
                     return TextureMatchQuality.FormatAlias;
                 }
diff --git a/Ryujinx.Graphics.Gpu/Image/TexturePool.cs b/Ryujinx.Graphics.Gpu/Image/TexturePool.cs
index f936615f10..10a6ff82af 100644
--- a/Ryujinx.Graphics.Gpu/Image/TexturePool.cs
+++ b/Ryujinx.Graphics.Gpu/Image/TexturePool.cs
@@ -362,7 +362,7 @@ namespace Ryujinx.Graphics.Gpu.Image
                 return DepthStencilMode.Depth;
             }
 
-            if (format == Format.D24X8Unorm || format == Format.D24UnormS8Uint)
+            if (format == Format.D24UnormS8Uint)
             {
                 return component == SwizzleComponent.Red
                     ? DepthStencilMode.Stencil
diff --git a/Ryujinx.Graphics.OpenGL/FormatTable.cs b/Ryujinx.Graphics.OpenGL/FormatTable.cs
index c054ba57a9..1a739b5ce2 100644
--- a/Ryujinx.Graphics.OpenGL/FormatTable.cs
+++ b/Ryujinx.Graphics.OpenGL/FormatTable.cs
@@ -6,15 +6,15 @@ namespace Ryujinx.Graphics.OpenGL
 {
     struct FormatTable
     {
-        private static FormatInfo[] Table;
-        private static SizedInternalFormat[] TableImage;
+        private static FormatInfo[] _table;
+        private static SizedInternalFormat[] _tableImage;
 
         static FormatTable()
         {
             int tableSize = Enum.GetNames<Format>().Length;
 
-            Table = new FormatInfo[tableSize];
-            TableImage = new SizedInternalFormat[tableSize];
+            _table = new FormatInfo[tableSize];
+            _tableImage = new SizedInternalFormat[tableSize];
 
             Add(Format.R8Unorm,             new FormatInfo(1, true,  false, All.R8,                PixelFormat.Red,            PixelType.UnsignedByte));
             Add(Format.R8Snorm,             new FormatInfo(1, true,  false, All.R8Snorm,           PixelFormat.Red,            PixelType.Byte));
@@ -66,7 +66,7 @@ namespace Ryujinx.Graphics.OpenGL
             Add(Format.R32G32B32A32Sint,    new FormatInfo(4, false, false, All.Rgba32i,           PixelFormat.RgbaInteger,    PixelType.Int));
             Add(Format.S8Uint,              new FormatInfo(1, false, false, All.StencilIndex8,     PixelFormat.StencilIndex,   PixelType.UnsignedByte));
             Add(Format.D16Unorm,            new FormatInfo(1, false, false, All.DepthComponent16,  PixelFormat.DepthComponent, PixelType.UnsignedShort));
-            Add(Format.D24X8Unorm,          new FormatInfo(1, false, false, All.DepthComponent24,  PixelFormat.DepthComponent, PixelType.UnsignedInt));
+            Add(Format.S8UintD24Unorm,      new FormatInfo(1, false, false, All.Depth24Stencil8,   PixelFormat.DepthStencil,   PixelType.UnsignedInt248));
             Add(Format.D32Float,            new FormatInfo(1, false, false, All.DepthComponent32f, PixelFormat.DepthComponent, PixelType.Float));
             Add(Format.D24UnormS8Uint,      new FormatInfo(1, false, false, All.Depth24Stencil8,   PixelFormat.DepthStencil,   PixelType.UnsignedInt248));
             Add(Format.D32FloatS8Uint,      new FormatInfo(1, false, false, All.Depth32fStencil8,  PixelFormat.DepthStencil,   PixelType.Float32UnsignedInt248Rev));
@@ -218,22 +218,22 @@ namespace Ryujinx.Graphics.OpenGL
 
         private static void Add(Format format, FormatInfo info)
         {
-            Table[(int)format] = info;
+            _table[(int)format] = info;
         }
 
         private static void Add(Format format, SizedInternalFormat sif)
         {
-            TableImage[(int)format] = sif;
+            _tableImage[(int)format] = sif;
         }
 
         public static FormatInfo GetFormatInfo(Format format)
         {
-            return Table[(int)format];
+            return _table[(int)format];
         }
 
         public static SizedInternalFormat GetImageFormat(Format format)
         {
-            return TableImage[(int)format];
+            return _tableImage[(int)format];
         }
     }
 }
diff --git a/Ryujinx.Graphics.OpenGL/Framebuffer.cs b/Ryujinx.Graphics.OpenGL/Framebuffer.cs
index 76b321b7a9..da928b4c87 100644
--- a/Ryujinx.Graphics.OpenGL/Framebuffer.cs
+++ b/Ryujinx.Graphics.OpenGL/Framebuffer.cs
@@ -127,14 +127,13 @@ namespace Ryujinx.Graphics.OpenGL
         private static bool IsPackedDepthStencilFormat(Format format)
         {
             return format == Format.D24UnormS8Uint ||
-                   format == Format.D32FloatS8Uint;
+                   format == Format.D32FloatS8Uint ||
+                   format == Format.S8UintD24Unorm;
         }
 
         private static bool IsDepthOnlyFormat(Format format)
         {
-            return format == Format.D16Unorm ||
-                   format == Format.D24X8Unorm ||
-                   format == Format.D32Float;
+            return format == Format.D16Unorm || format == Format.D32Float;
         }
 
         public void Dispose()
diff --git a/Ryujinx.Graphics.OpenGL/Image/FormatConverter.cs b/Ryujinx.Graphics.OpenGL/Image/FormatConverter.cs
new file mode 100644
index 0000000000..c4bbf74566
--- /dev/null
+++ b/Ryujinx.Graphics.OpenGL/Image/FormatConverter.cs
@@ -0,0 +1,149 @@
+using System;
+using System.Numerics;
+using System.Runtime.InteropServices;
+using System.Runtime.Intrinsics;
+using System.Runtime.Intrinsics.X86;
+
+namespace Ryujinx.Graphics.OpenGL.Image
+{
+    /// <summary>
+    /// CPU-side conversion between the S8D24 and D24S8 packed depth-stencil data layouts.
+    /// </summary>
+    /// <remarks>
+    /// The data is processed as 32-bit words, and converting between the layouts rotates every word by 8 bits.
+    /// Bytes past the last complete 32-bit word are not converted, they are left as zero on the output.
+    /// The input is never modified, the converted data is returned on a new array.
+    /// </remarks>
+    static class FormatConverter
+    {
+        /// <summary>
+        /// Converts S8D24 data to D24S8, by rotating every 32-bit word left by 8 bits.
+        /// </summary>
+        /// <param name="data">Data to be converted</param>
+        /// <returns>New array with the converted data, with the same length as <paramref name="data"/></returns>
+        public static byte[] ConvertS8D24ToD24S8(ReadOnlySpan<byte> data)
+        {
+            return RotateWords(data, rotateLeft: true);
+        }
+
+        /// <summary>
+        /// Converts D24S8 data to S8D24, by rotating every 32-bit word right by 8 bits.
+        /// </summary>
+        /// <param name="data">Data to be converted</param>
+        /// <returns>New array with the converted data, with the same length as <paramref name="data"/></returns>
+        public static byte[] ConvertD24S8ToS8D24(ReadOnlySpan<byte> data)
+        {
+            return RotateWords(data, rotateLeft: false);
+        }
+
+        /// <summary>
+        /// Rotates every 32-bit word of the data by 8 bits, using AVX2 or SSSE3 when they are supported.
+        /// </summary>
+        /// <remarks>
+        /// This is the shared implementation of both conversions, they only differ on the rotation direction.
+        /// Bytes past the last complete 32-bit word are left as zero.
+        /// </remarks>
+        /// <param name="data">Data to be converted</param>
+        /// <param name="rotateLeft">True to rotate the words left, false to rotate them right</param>
+        /// <returns>New array with the rotated words, with the same length as <paramref name="data"/></returns>
+        private static unsafe byte[] RotateWords(ReadOnlySpan<byte> data, bool rotateLeft)
+        {
+            byte[] output = new byte[data.Length];
+
+            // Offset of the first byte that was not processed by one of the vector paths below.
+            int start = 0;
+
+            if (Avx2.IsSupported)
+            {
+                // The AVX2 byte shuffle works on each 128-bit lane separately, taking lane relative indices,
+                // so the 256-bit mask is just the 128-bit mask repeated on both lanes.
+                Vector128<byte> laneMask = CreateShuffleMask(rotateLeft);
+                Vector256<byte> mask = Vector256.Create(laneMask, laneMask);
+
+                // Only whole 32 byte vectors are processed here, the remaining bytes are left to the scalar path.
+                int sizeAligned = data.Length & ~31;
+
+                fixed (byte* pInput = data, pOutput = output)
+                {
+                    for (int offset = 0; offset < sizeAligned; offset += 32)
+                    {
+                        Vector256<byte> dataVec = Avx.LoadVector256(pInput + offset);
+
+                        dataVec = Avx2.Shuffle(dataVec, mask);
+
+                        Avx.Store(pOutput + offset, dataVec);
+                    }
+                }
+
+                start = sizeAligned;
+            }
+            else if (Ssse3.IsSupported)
+            {
+                // There is a single 128-bit lane here, so the mask is used as is.
+                Vector128<byte> mask = CreateShuffleMask(rotateLeft);
+
+                // Only whole 16 byte vectors are processed here, the remaining bytes are left to the scalar path.
+                int sizeAligned = data.Length & ~15;
+
+                fixed (byte* pInput = data, pOutput = output)
+                {
+                    for (int offset = 0; offset < sizeAligned; offset += 16)
+                    {
+                        Vector128<byte> dataVec = Sse2.LoadVector128(pInput + offset);
+
+                        dataVec = Ssse3.Shuffle(dataVec, mask);
+
+                        Sse2.Store(pOutput + offset, dataVec);
+                    }
+                }
+
+                start = sizeAligned;
+            }
+
+            // Scalar path, handles the words that the vector paths did not process (all of them without SIMD).
+            // Rotating a 32-bit word right by 8 bits is the same as rotating it left by 24 bits.
+            int rotation = rotateLeft ? 8 : 24;
+
+            // The casted spans only cover complete 32-bit words, trailing bytes are not part of them.
+            var outSpan = MemoryMarshal.Cast<byte, uint>(output);
+            var dataSpan = MemoryMarshal.Cast<byte, uint>(data);
+
+            for (int i = start / sizeof(uint); i < dataSpan.Length; i++)
+            {
+                outSpan[i] = BitOperations.RotateLeft(dataSpan[i], rotation);
+            }
+
+            return output;
+        }
+
+        /// <summary>
+        /// Creates the byte shuffle mask that rotates the four 32-bit words of a 128-bit vector by 8 bits.
+        /// </summary>
+        /// <remarks>
+        /// Each mask element is the index of the source byte that is copied to that position.
+        /// </remarks>
+        /// <param name="rotateLeft">True to rotate the words left, false to rotate them right</param>
+        /// <returns>Mask to be used with the byte shuffle instructions</returns>
+        private static Vector128<byte> CreateShuffleMask(bool rotateLeft)
+        {
+            if (rotateLeft)
+            {
+                // The last byte of each word in memory (most significant) wraps around to the first position.
+                return Vector128.Create(
+                    (byte)3, (byte)0, (byte)1, (byte)2,
+                    (byte)7, (byte)4, (byte)5, (byte)6,
+                    (byte)11, (byte)8, (byte)9, (byte)10,
+                    (byte)15, (byte)12, (byte)13, (byte)14);
+            }
+            else
+            {
+                // The first byte of each word in memory (least significant) wraps around to the last position.
+                return Vector128.Create(
+                    (byte)1, (byte)2, (byte)3, (byte)0,
+                    (byte)5, (byte)6, (byte)7, (byte)4,
+                    (byte)9, (byte)10, (byte)11, (byte)8,
+                    (byte)13, (byte)14, (byte)15, (byte)12);
+            }
+        }
+    }
+}
diff --git a/Ryujinx.Graphics.OpenGL/Image/TextureCopy.cs b/Ryujinx.Graphics.OpenGL/Image/TextureCopy.cs
index 7811d02156..9be8656139 100644
--- a/Ryujinx.Graphics.OpenGL/Image/TextureCopy.cs
+++ b/Ryujinx.Graphics.OpenGL/Image/TextureCopy.cs
@@ -291,7 +291,7 @@ namespace Ryujinx.Graphics.OpenGL.Image
 
         private static ClearBufferMask GetMask(Format format)
         {
-            if (format == Format.D24UnormS8Uint || format == Format.D32FloatS8Uint)
+            if (format == Format.D24UnormS8Uint || format == Format.D32FloatS8Uint || format == Format.S8UintD24Unorm)
             {
                 return ClearBufferMask.DepthBufferBit | ClearBufferMask.StencilBufferBit;
             }
@@ -311,9 +311,7 @@ namespace Ryujinx.Graphics.OpenGL.Image
 
         private static bool IsDepthOnly(Format format)
         {
-            return format == Format.D16Unorm   ||
-                   format == Format.D24X8Unorm ||
-                   format == Format.D32Float;
+            return format == Format.D16Unorm || format == Format.D32Float;
         }
 
         public TextureView BgraSwap(TextureView from)
diff --git a/Ryujinx.Graphics.OpenGL/Image/TextureView.cs b/Ryujinx.Graphics.OpenGL/Image/TextureView.cs
index f03653c43a..909a06200d 100644
--- a/Ryujinx.Graphics.OpenGL/Image/TextureView.cs
+++ b/Ryujinx.Graphics.OpenGL/Image/TextureView.cs
@@ -140,9 +140,11 @@ namespace Ryujinx.Graphics.OpenGL.Image
                 size += Info.GetMipSize(level);
             }
 
+            ReadOnlySpan<byte> data;
+
             if (HwCapabilities.UsePersistentBufferForFlush)
             {
-                return _renderer.PersistentBuffers.Default.GetTextureData(this, size);
+                data = _renderer.PersistentBuffers.Default.GetTextureData(this, size);
             }
             else
             {
@@ -150,8 +152,15 @@ namespace Ryujinx.Graphics.OpenGL.Image
 
                 WriteTo(target);
 
-                return new ReadOnlySpan<byte>(target.ToPointer(), size);
+                data = new ReadOnlySpan<byte>(target.ToPointer(), size);
             }
+
+            if (Format == Format.S8UintD24Unorm)
+            {
+                data = FormatConverter.ConvertD24S8ToS8D24(data);
+            }
+
+            return data;
         }
 
         public unsafe ReadOnlySpan<byte> GetData(int layer, int level)
@@ -285,6 +294,11 @@ namespace Ryujinx.Graphics.OpenGL.Image
 
         public void SetData(ReadOnlySpan<byte> data)
         {
+            if (Format == Format.S8UintD24Unorm)
+            {
+                data = FormatConverter.ConvertS8D24ToD24S8(data);
+            }
+
             unsafe
             {
                 fixed (byte* ptr = data)
@@ -296,6 +310,11 @@ namespace Ryujinx.Graphics.OpenGL.Image
 
         public void SetData(ReadOnlySpan<byte> data, int layer, int level)
         {
+            if (Format == Format.S8UintD24Unorm)
+            {
+                data = FormatConverter.ConvertS8D24ToD24S8(data);
+            }
+
             unsafe
             {
                 fixed (byte* ptr = data)