From 09593ae6d85b204b7bded803a56f0d81fbda1127 Mon Sep 17 00:00:00 2001
From: gdkchan <gab.dark.100@gmail.com>
Date: Mon, 13 Aug 2018 18:22:09 -0300
Subject: [PATCH] Add partial support for the TEX.B shader instruction (#342)

* Add partial support for the TEX.B shader instruction, and fix mixed indexed and non-indexed drawing

* Better exception
---
 Ryujinx.Graphics/Gal/Shader/GlslDecl.cs       | 74 +++++++++++++++----
 Ryujinx.Graphics/Gal/Shader/GlslDecompiler.cs | 45 ++++++++++-
 .../Gal/Shader/ShaderDecodeMem.cs             | 26 +++++--
 Ryujinx.Graphics/Gal/Shader/ShaderIrInst.cs   |  3 +-
 .../Gal/Shader/ShaderOpCodeTable.cs           |  1 +
 Ryujinx.Graphics/Gal/ShaderDeclInfo.cs        | 15 +++-
 Ryujinx.HLE/Gpu/Engines/NvGpuEngine3d.cs      | 31 +++++---
 7 files changed, 153 insertions(+), 42 deletions(-)

diff --git a/Ryujinx.Graphics/Gal/Shader/GlslDecl.cs b/Ryujinx.Graphics/Gal/Shader/GlslDecl.cs
index a0c747bac4..691ab80073 100644
--- a/Ryujinx.Graphics/Gal/Shader/GlslDecl.cs
+++ b/Ryujinx.Graphics/Gal/Shader/GlslDecl.cs
@@ -1,3 +1,4 @@
+using System;
 using System.Collections.Generic;
 
 namespace Ryujinx.Graphics.Gal.Shader
@@ -19,7 +20,7 @@ namespace Ryujinx.Graphics.Gal.Shader
         public const int PositionOutAttrLocation = 15;
 
         private const int AttrStartIndex = 8;
-        private const int TexStartIndex = 8;
+        private const int TexStartIndex  = 8;
 
         public const string PositionOutAttrName = "position";
 
@@ -46,6 +47,8 @@ namespace Ryujinx.Graphics.Gal.Shader
 
         private string StagePrefix;
 
+        private Dictionary<ShaderIrOp, ShaderDeclInfo> m_CbTextures;
+
         private Dictionary<int, ShaderDeclInfo> m_Textures;
         private Dictionary<int, ShaderDeclInfo> m_Uniforms;
 
@@ -56,6 +59,8 @@ namespace Ryujinx.Graphics.Gal.Shader
         private Dictionary<int, ShaderDeclInfo> m_Gprs;
         private Dictionary<int, ShaderDeclInfo> m_Preds;
 
+        public IReadOnlyDictionary<ShaderIrOp, ShaderDeclInfo> CbTextures => m_CbTextures;
+
         public IReadOnlyDictionary<int, ShaderDeclInfo> Textures => m_Textures;
         public IReadOnlyDictionary<int, ShaderDeclInfo> Uniforms => m_Uniforms;
 
@@ -72,8 +77,10 @@ namespace Ryujinx.Graphics.Gal.Shader
         {
             this.ShaderType = ShaderType;
 
-            m_Uniforms = new Dictionary<int, ShaderDeclInfo>();
+            m_CbTextures = new Dictionary<ShaderIrOp, ShaderDeclInfo>();
+
             m_Textures = new Dictionary<int, ShaderDeclInfo>();
+            m_Uniforms = new Dictionary<int, ShaderDeclInfo>();
 
             m_Attributes    = new Dictionary<int, ShaderDeclInfo>();
             m_InAttributes  = new Dictionary<int, ShaderDeclInfo>();
@@ -89,14 +96,16 @@ namespace Ryujinx.Graphics.Gal.Shader
 
             if (ShaderType == GalShaderType.Fragment)
             {
-                m_Gprs.Add(0, new ShaderDeclInfo(FragmentOutputName, 0, 0, 4));
+                m_Gprs.Add(0, new ShaderDeclInfo(FragmentOutputName, 0, false, 0, 4));
             }
 
             foreach (ShaderIrBlock Block in Blocks)
             {
-                foreach (ShaderIrNode Node in Block.GetNodes())
+                ShaderIrNode[] Nodes = Block.GetNodes();
+
+                foreach (ShaderIrNode Node in Nodes)
                 {
-                    Traverse(null, Node);
+                    Traverse(Nodes, null, Node);
                 }
             }
         }
@@ -152,31 +161,31 @@ namespace Ryujinx.Graphics.Gal.Shader
             }
         }
 
-        private void Traverse(ShaderIrNode Parent, ShaderIrNode Node)
+        private void Traverse(ShaderIrNode[] Nodes, ShaderIrNode Parent, ShaderIrNode Node)
         {
             switch (Node)
             {
                 case ShaderIrAsg Asg:
                 {
-                    Traverse(Asg, Asg.Dst);
-                    Traverse(Asg, Asg.Src);
+                    Traverse(Nodes, Asg, Asg.Dst);
+                    Traverse(Nodes, Asg, Asg.Src);
 
                     break;
                 }
 
                 case ShaderIrCond Cond:
                 {
-                    Traverse(Cond, Cond.Pred);
-                    Traverse(Cond, Cond.Child);
+                    Traverse(Nodes, Cond, Cond.Pred);
+                    Traverse(Nodes, Cond, Cond.Child);
 
                     break;
                 }
 
                 case ShaderIrOp Op:
                 {
-                    Traverse(Op, Op.OperandA);
-                    Traverse(Op, Op.OperandB);
-                    Traverse(Op, Op.OperandC);
+                    Traverse(Nodes, Op, Op.OperandA);
+                    Traverse(Nodes, Op, Op.OperandB);
+                    Traverse(Nodes, Op, Op.OperandC);
 
                     if (Op.Inst == ShaderIrInst.Texq ||
                         Op.Inst == ShaderIrInst.Texs ||
@@ -190,6 +199,38 @@ namespace Ryujinx.Graphics.Gal.Shader
 
                         m_Textures.TryAdd(Handle, new ShaderDeclInfo(Name, Handle));
                     }
+                    else if (Op.Inst == ShaderIrInst.Texb)
+                    {
+                        //TEX.B reads its texture handle from a register. Only a handle that an
+                        //earlier node of this block loaded directly from a constant buffer is
+                        //supported, as the sampler is named after that constant buffer slot.
+                        ShaderIrOperCbuf Cbuf = null;
+
+                        if (Op.OperandC is ShaderIrOperGpr HandleGpr)
+                        {
+                            //Scan backwards for the closest assignment to the handle register.
+                            for (int Index = Array.IndexOf(Nodes, Parent) - 1; Index >= 0; Index--)
+                            {
+                                if (Nodes[Index] is ShaderIrAsg Asg     &&
+                                    Asg.Dst      is ShaderIrOperGpr Gpr &&
+                                    Gpr.Index    == HandleGpr.Index)
+                                {
+                                    Cbuf = Asg.Src as ShaderIrOperCbuf;
+
+                                    break;
+                                }
+                            }
+                        }
+
+                        if (Cbuf == null)
+                        {
+                            throw new NotImplementedException("Shader TEX.B instruction is not fully supported!");
+                        }
+
+                        string Name = StagePrefix + TextureName + "_cb" + Cbuf.Index + "_" + Cbuf.Pos;
+
+                        m_CbTextures.Add(Op, new ShaderDeclInfo(Name, Cbuf.Pos, true, Cbuf.Index));
+                    }
                     break;
                 }
 
@@ -199,7 +240,7 @@ namespace Ryujinx.Graphics.Gal.Shader
                     {
                         string Name = StagePrefix + UniformName + Cbuf.Index;
 
-                        ShaderDeclInfo DeclInfo = new ShaderDeclInfo(Name, Cbuf.Pos, Cbuf.Index);
+                        ShaderDeclInfo DeclInfo = new ShaderDeclInfo(Name, Cbuf.Pos, true, Cbuf.Index);
 
                         m_Uniforms.Add(Cbuf.Index, DeclInfo);
                     }
@@ -252,12 +293,13 @@ namespace Ryujinx.Graphics.Gal.Shader
 
                     if (!m_Attributes.ContainsKey(Index))
                     {
-                        DeclInfo = new ShaderDeclInfo(AttrName + GlslIndex, GlslIndex, 0, 4);
+                        DeclInfo = new ShaderDeclInfo(AttrName + GlslIndex, GlslIndex, false, 0, 4);
 
                         m_Attributes.Add(Index, DeclInfo);
                     }
 
-                    Traverse(Abuf, Abuf.Vertex);
+                    Traverse(Nodes, Abuf, Abuf.Vertex);
+
                     break;
                 }
 
diff --git a/Ryujinx.Graphics/Gal/Shader/GlslDecompiler.cs b/Ryujinx.Graphics/Gal/Shader/GlslDecompiler.cs
index aa1803a5d3..94bdd2fa17 100644
--- a/Ryujinx.Graphics/Gal/Shader/GlslDecompiler.cs
+++ b/Ryujinx.Graphics/Gal/Shader/GlslDecompiler.cs
@@ -98,6 +98,7 @@ namespace Ryujinx.Graphics.Gal.Shader
                 { ShaderIrInst.Or,     GetOrExpr     },
                 { ShaderIrInst.Stof,   GetStofExpr   },
                 { ShaderIrInst.Sub,    GetSubExpr    },
+                { ShaderIrInst.Texb,   GetTexbExpr   },
                 { ShaderIrInst.Texq,   GetTexqExpr   },
                 { ShaderIrInst.Texs,   GetTexsExpr   },
                 { ShaderIrInst.Trunc,  GetTruncExpr  },
@@ -174,10 +175,12 @@ namespace Ryujinx.Graphics.Gal.Shader
 
             string GlslCode = SB.ToString();
 
-            return new GlslProgram(
-                GlslCode,
-                Decl.Textures.Values,
-                Decl.Uniforms.Values);
+            List<ShaderDeclInfo> TextureInfo = new List<ShaderDeclInfo>();
+
+            TextureInfo.AddRange(Decl.Textures.Values);
+            TextureInfo.AddRange(IterateCbTextures());
+
+            return new GlslProgram(GlslCode, TextureInfo, Decl.Uniforms.Values);
         }
 
         private void PrintDeclHeader()
@@ -213,9 +216,27 @@ namespace Ryujinx.Graphics.Gal.Shader
 
         private void PrintDeclTextures()
         {
+            foreach (ShaderDeclInfo DeclInfo in IterateCbTextures())
+            {
+                SB.AppendLine("uniform sampler2D " + DeclInfo.Name + ";");
+            }
+
             PrintDecls(Decl.Textures, "uniform sampler2D");
         }
 
+        //Returns the TEX.B (constant buffer) sampler declarations sorted with
+        //DeclKeySelector, keeping only the first declaration for each sampler
+        //name. Evaluation is deferred until the result is enumerated.
+        private IEnumerable<ShaderDeclInfo> IterateCbTextures()
+        {
+            //Several ops can share one sampler name. GroupBy yields its groups in
+            //first-occurrence order, so the sorted order of the survivors is kept.
+            return Decl.CbTextures.Values
+                .OrderBy(DeclKeySelector)
+                .GroupBy(DeclInfo => DeclInfo.Name)
+                .Select(SameName => SameName.First());
+        }
+
         private void PrintDeclUniforms()
         {
             if (Decl.ShaderType == GalShaderType.Vertex)
@@ -994,6 +1015,22 @@ namespace Ryujinx.Graphics.Gal.Shader
 
         private string GetSubExpr(ShaderIrOp Op) => GetBinaryExpr(Op, "-");
 
+        //Builds the GLSL expression that samples the texture of a TEX.B op and picks one channel.
+        private string GetTexbExpr(ShaderIrOp Op)
+        {
+            ShaderIrMetaTex Meta = (ShaderIrMetaTex)Op.MetaData;
+
+            if (!Decl.CbTextures.TryGetValue(Op, out ShaderDeclInfo DeclInfo))
+            {
+                throw new InvalidOperationException("Shader TEX.B instruction has no texture declaration!");
+            }
+
+            string Coords = GetTexSamplerCoords(Op);
+            string Ch     = "rgba".Substring(Meta.Elem, 1);
+
+            return "texture(" + DeclInfo.Name + ", " + Coords + ")." + Ch;
+        }
+
         private string GetTexqExpr(ShaderIrOp Op)
         {
             ShaderIrMetaTexq Meta = (ShaderIrMetaTexq)Op.MetaData;
diff --git a/Ryujinx.Graphics/Gal/Shader/ShaderDecodeMem.cs b/Ryujinx.Graphics/Gal/Shader/ShaderDecodeMem.cs
index aea7e744dc..5ef2e2e5f5 100644
--- a/Ryujinx.Graphics/Gal/Shader/ShaderDecodeMem.cs
+++ b/Ryujinx.Graphics/Gal/Shader/ShaderDecodeMem.cs
@@ -121,6 +121,16 @@ namespace Ryujinx.Graphics.Gal.Shader
         }
 
         public static void Tex(ShaderIrBlock Block, long OpCode)
+        {
+            EmitTex(Block, OpCode, GprHandle: false);
+        }
+
+        public static void Tex_B(ShaderIrBlock Block, long OpCode)
+        {
+            EmitTex(Block, OpCode, GprHandle: true);
+        }
+
+        private static void EmitTex(ShaderIrBlock Block, long OpCode, bool GprHandle)
         {
             //TODO: Support other formats.
             ShaderIrOperGpr[] Coords = new ShaderIrOperGpr[2];
@@ -139,7 +149,11 @@ namespace Ryujinx.Graphics.Gal.Shader
 
             int ChMask = (int)(OpCode >> 31) & 0xf;
 
-            ShaderIrNode OperC = GetOperImm13_36(OpCode);
+            ShaderIrNode OperC = GprHandle
+                ? (ShaderIrNode)GetOperGpr20   (OpCode)
+                : (ShaderIrNode)GetOperImm13_36(OpCode);
+
+            ShaderIrInst Inst = GprHandle ? ShaderIrInst.Texb : ShaderIrInst.Texs;
 
             for (int Ch = 0; Ch < 4; Ch++)
             {
@@ -147,7 +161,7 @@ namespace Ryujinx.Graphics.Gal.Shader
 
                 ShaderIrMetaTex Meta = new ShaderIrMetaTex(Ch);
 
-                ShaderIrOp Op = new ShaderIrOp(ShaderIrInst.Texs, Coords[0], Coords[1], OperC, Meta);
+                ShaderIrOp Op = new ShaderIrOp(Inst, Coords[0], Coords[1], OperC, Meta);
 
                 Block.AddNode(GetPredNode(new ShaderIrAsg(Dst, Op), OpCode));
             }
@@ -178,15 +192,15 @@ namespace Ryujinx.Graphics.Gal.Shader
 
         public static void Texs(ShaderIrBlock Block, long OpCode)
         {
-            EmitTex(Block, OpCode, ShaderIrInst.Texs);
+            EmitTexs(Block, OpCode, ShaderIrInst.Texs);
         }
 
         public static void Tlds(ShaderIrBlock Block, long OpCode)
         {
-            EmitTex(Block, OpCode, ShaderIrInst.Txlf);
+            EmitTexs(Block, OpCode, ShaderIrInst.Txlf);
         }
 
-        private static void EmitTex(ShaderIrBlock Block, long OpCode, ShaderIrInst Inst)
+        private static void EmitTexs(ShaderIrBlock Block, long OpCode, ShaderIrInst Inst)
         {
             //TODO: Support other formats.
             ShaderIrNode OperA = GetOperGpr8    (OpCode);
@@ -195,7 +209,7 @@ namespace Ryujinx.Graphics.Gal.Shader
 
             int LutIndex;
 
-            LutIndex = GetOperGpr0(OpCode).Index != ShaderIrOperGpr.ZRIndex ? 1 : 0;
+            LutIndex  = GetOperGpr0(OpCode).Index != ShaderIrOperGpr.ZRIndex ? 1 : 0;
             LutIndex |= GetOperGpr28(OpCode).Index != ShaderIrOperGpr.ZRIndex ? 2 : 0;
 
             int ChMask = MaskLut[LutIndex, (OpCode >> 50) & 7];
diff --git a/Ryujinx.Graphics/Gal/Shader/ShaderIrInst.cs b/Ryujinx.Graphics/Gal/Shader/ShaderIrInst.cs
index fd86cadb10..d197835a7a 100644
--- a/Ryujinx.Graphics/Gal/Shader/ShaderIrInst.cs
+++ b/Ryujinx.Graphics/Gal/Shader/ShaderIrInst.cs
@@ -47,6 +47,7 @@ namespace Ryujinx.Graphics.Gal.Shader
         Ftos,
         Ftou,
         Ipa,
+        Texb,
         Texs,
         Trunc,
         F_End,
@@ -83,7 +84,7 @@ namespace Ryujinx.Graphics.Gal.Shader
         Bra,
         Exit,
         Kil,
-        
+
         Emit,
         Cut
     }
diff --git a/Ryujinx.Graphics/Gal/Shader/ShaderOpCodeTable.cs b/Ryujinx.Graphics/Gal/Shader/ShaderOpCodeTable.cs
index 3f20dc4465..95b8e467d2 100644
--- a/Ryujinx.Graphics/Gal/Shader/ShaderOpCodeTable.cs
+++ b/Ryujinx.Graphics/Gal/Shader/ShaderOpCodeTable.cs
@@ -114,6 +114,7 @@ namespace Ryujinx.Graphics.Gal.Shader
             Set("0101110000101x", ShaderDecode.Shr_R);
             Set("1110111111110x", ShaderDecode.St_A);
             Set("110000xxxx111x", ShaderDecode.Tex);
+            Set("1101111010111x", ShaderDecode.Tex_B);
             Set("1101111101001x", ShaderDecode.Texq);
             Set("1101100xxxxxxx", ShaderDecode.Texs);
             Set("1101101xxxxxxx", ShaderDecode.Tlds);
diff --git a/Ryujinx.Graphics/Gal/ShaderDeclInfo.cs b/Ryujinx.Graphics/Gal/ShaderDeclInfo.cs
index d400850c86..ef47ca2e1b 100644
--- a/Ryujinx.Graphics/Gal/ShaderDeclInfo.cs
+++ b/Ryujinx.Graphics/Gal/ShaderDeclInfo.cs
@@ -4,14 +4,21 @@ namespace Ryujinx.Graphics.Gal
     {
         public string Name { get; private set; }
 
-        public int Index { get; private set; }
-        public int Cbuf  { get; private set; }
-        public int Size  { get; private set; }
+        public int  Index { get; private set; }
+        public bool IsCb  { get; private set; }
+        public int  Cbuf  { get; private set; }
+        public int  Size  { get; private set; }
 
-        public ShaderDeclInfo(string Name, int Index, int Cbuf = 0, int Size = 1)
+        public ShaderDeclInfo(
+            string Name,
+            int    Index,
+            bool   IsCb = false,
+            int    Cbuf = 0,
+            int    Size = 1)
         {
             this.Name  = Name;
             this.Index = Index;
+            this.IsCb  = IsCb;
             this.Cbuf  = Cbuf;
             this.Size  = Size;
         }
diff --git a/Ryujinx.HLE/Gpu/Engines/NvGpuEngine3d.cs b/Ryujinx.HLE/Gpu/Engines/NvGpuEngine3d.cs
index 0576601f5e..38f8d1c9b4 100644
--- a/Ryujinx.HLE/Gpu/Engines/NvGpuEngine3d.cs
+++ b/Ryujinx.HLE/Gpu/Engines/NvGpuEngine3d.cs
@@ -309,7 +309,7 @@ namespace Ryujinx.HLE.Gpu.Engines
         private void SetStencil(GalPipelineState State)
         {
             State.StencilTestEnabled = (ReadRegister(NvGpuEngine3dReg.StencilEnable) & 1) != 0;
-            
+
             if (State.StencilTestEnabled)
             {
                 State.StencilBackFuncFunc = (GalComparisonOp)ReadRegister(NvGpuEngine3dReg.StencilBackFuncFunc);
@@ -364,17 +364,26 @@ namespace Ryujinx.HLE.Gpu.Engines
 
             int TextureCbIndex = ReadRegister(NvGpuEngine3dReg.TextureCbIndex);
 
-            //Note: On the emulator renderer, Texture Unit 0 is
-            //reserved for drawing the frame buffer.
-            int TexIndex = 1;
+            int TexIndex = 0;
 
             for (int Index = 0; Index < Keys.Length; Index++)
             {
                 foreach (ShaderDeclInfo DeclInfo in Gpu.Renderer.Shader.GetTextureUsage(Keys[Index]))
                 {
-                    long Position = ConstBuffers[Index][TextureCbIndex].Position;
+                    long Position;
 
-                    UploadTexture(Vmm, Position, TexIndex, DeclInfo.Index);
+                    if (DeclInfo.IsCb)
+                    {
+                        Position = ConstBuffers[Index][DeclInfo.Cbuf].Position;
+                    }
+                    else
+                    {
+                        Position = ConstBuffers[Index][TextureCbIndex].Position;
+                    }
+
+                    int TextureHandle = Vmm.ReadInt32(Position + DeclInfo.Index * 4);
+
+                    UploadTexture(Vmm, TexIndex, TextureHandle);
 
                     Gpu.Renderer.Shader.EnsureTextureBinding(DeclInfo.Name, TexIndex);
 
@@ -383,12 +392,8 @@ namespace Ryujinx.HLE.Gpu.Engines
             }
         }
 
-        private void UploadTexture(NvGpuVmm Vmm, long BasePosition, int TexIndex, int HndIndex)
+        private void UploadTexture(NvGpuVmm Vmm, int TexIndex, int TextureHandle)
         {
-            long Position = BasePosition + HndIndex * 4;
-
-            int TextureHandle = Vmm.ReadInt32(Position);
-
             if (TextureHandle == 0)
             {
                 //TODO: Is this correct?
@@ -601,6 +606,10 @@ namespace Ryujinx.HLE.Gpu.Engines
 
                 Gpu.Renderer.Rasterizer.DrawArrays(VertexFirst, VertexCount, PrimType);
             }
+
+            //Is the GPU really clearing those registers after draw?
+            WriteRegister(NvGpuEngine3dReg.IndexBatchFirst, 0);
+            WriteRegister(NvGpuEngine3dReg.IndexBatchCount, 0);
         }
 
         private void QueryControl(NvGpuVmm Vmm, NvGpuPBEntry PBEntry)