From 65611e5b51924008c7c611f75290db12a705e22d Mon Sep 17 00:00:00 2001
From: SachinVin <sachinvinayak2000@gmail.com>
Date: Sat, 14 May 2022 17:08:38 +0530
Subject: [PATCH] Shader JIT: Save and restore `LOOPCOUNT_REG` for nested
 loops

Also add the assert back for nested loops and update the nested loop
test.
---
 .../shader/shader_jit_x64_compiler.cpp        | 110 +++++-------------
 .../shader/shader_jit_x64_compiler.cpp        |  12 +-
 2 files changed, 37 insertions(+), 85 deletions(-)

diff --git a/src/tests/video_core/shader/shader_jit_x64_compiler.cpp b/src/tests/video_core/shader/shader_jit_x64_compiler.cpp
index 85eab4428..147e0030a 100644
--- a/src/tests/video_core/shader/shader_jit_x64_compiler.cpp
+++ b/src/tests/video_core/shader/shader_jit_x64_compiler.cpp
@@ -108,11 +108,9 @@ TEST_CASE("Nested Loop", "[video_core][shader][shader_jit]") {
     const auto sh_temp = SourceRegister::MakeTemporary(0);
     const auto sh_output = DestRegister::MakeOutput(0);
 
-    std::array<Common::Vec4<u8>, 2> loop_parms{Common::Vec4<u8>{4, 0, 1, 0},
-                                               Common::Vec4<u8>{4, 0, 1, 0}};
-
     auto shader_test = ShaderTest({
         // clang-format off
+        {OpCode::Id::MOV, sh_temp, sh_input},
         {OpCode::Id::LOOP, 0},
             {OpCode::Id::LOOP, 1},
                 {OpCode::Id::ADD, sh_temp, sh_temp, sh_input},
@@ -123,87 +121,39 @@ TEST_CASE("Nested Loop", "[video_core][shader][shader_jit]") {
         // clang-format on
     });
 
-    shader_test.shader_setup->uniforms.i[0] = loop_parms[0];
-    shader_test.shader_setup->uniforms.i[1] = loop_parms[0];
-
-    const auto run_test_helper = [&shader_test](float input) {
-        Pica::Shader::UnitState shader_unit_jit;
-        Pica::Shader::UnitState shader_unit_inter;
-        shader_test.RunJit(shader_unit_jit, input);
-        shader_test.RunInterpreter(shader_unit_inter, input);
-
-        REQUIRE(shader_unit_jit.registers.output[0].x.ToFloat32() ==
-                Approx(shader_unit_inter.registers.output[0].x.ToFloat32()));
-        REQUIRE(shader_unit_jit.address_registers[2] == shader_unit_inter.address_registers[2]);
-    };
     {
-        // Sanity check
+        shader_test.shader_setup->uniforms.i[0] = {4, 0, 1, 0};
+        shader_test.shader_setup->uniforms.i[1] = {4, 0, 1, 0};
+        Common::Vec4<u8> loop_parms{shader_test.shader_setup->uniforms.i[0]};
+
+        const int expected_aL = loop_parms[1] + ((loop_parms[0] + 1) * loop_parms[2]);
+        const float input = 1.0f;
+        const float expected_out = (((shader_test.shader_setup->uniforms.i[0][0] + 1) *
+                                     (shader_test.shader_setup->uniforms.i[1][0] + 1)) *
+                                    input) +
+                                   input;
+
         Pica::Shader::UnitState shader_unit_jit;
-        shader_test.RunJit(shader_unit_jit, 1.0f);
-        REQUIRE(shader_unit_jit.address_registers[2] == 6);
-        REQUIRE(shader_unit_jit.registers.output[0].x.ToFloat32() == Approx(25.0f));
-
-        Pica::Shader::UnitState shader_unit_inter;
-        shader_test.RunInterpreter(shader_unit_inter, 2.0f);
-        REQUIRE(shader_unit_inter.address_registers[2] == 6);
-        REQUIRE(shader_unit_inter.registers.output[0].x.ToFloat32() == Approx(50.0f));
-    }
-    run_test_helper(-5.f);
-    run_test_helper(0.f);
-    run_test_helper(2.f);
-    run_test_helper(6.f);
-    run_test_helper(79.7262742773f);
-}
-
-TEST_CASE("Nested Loop Randomized", "[video_core][shader][shader_jit]") {
-    const auto sh_input = SourceRegister::MakeInput(0);
-    const auto sh_temp = SourceRegister::MakeTemporary(0);
-    const auto sh_output = DestRegister::MakeOutput(0);
-
-    auto shader_test = ShaderTest({
-        // clang-format off
-        {OpCode::Id::LOOP, 0},
-            {OpCode::Id::LOOP, 1},
-                 {OpCode::Id::LOOP, 2},
-                    {OpCode::Id::LOOP, 3},
-                        {OpCode::Id::ADD, sh_temp, sh_temp, sh_input},
-                    {Type::EndLoop},
-                {Type::EndLoop},
-            {Type::EndLoop},
-        {Type::EndLoop},
-
-        {OpCode::Id::MOV, sh_output, sh_temp},
-        {OpCode::Id::END},
-        // clang-format on
-    });
-
-    const auto generate_loop_parms = [] {
-        u8 iterations = 1 + rand();
-        u8 initial = 1 + rand();
-        u8 increment = 1 + rand();
-
-        Common::Vec4<u8> loop_parm{iterations, initial, increment, 0};
-        return Common::Vec4<u8>{iterations, initial, increment, 0};
-    };
-
-    const auto run_test_helper = [&shader_test](float input) {
-        Pica::Shader::UnitState shader_unit_jit;
-        Pica::Shader::UnitState shader_unit_inter;
         shader_test.RunJit(shader_unit_jit, input);
-        shader_test.RunInterpreter(shader_unit_inter, input);
 
-        REQUIRE(shader_unit_jit.registers.output[0].x.ToFloat32() ==
-                Approx(shader_unit_inter.registers.output[0].x.ToFloat32()));
-        REQUIRE(shader_unit_jit.address_registers[2] == shader_unit_inter.address_registers[2]);
-    };
+        REQUIRE(shader_unit_jit.address_registers[2] == expected_aL);
+        REQUIRE(shader_unit_jit.registers.output[0].x.ToFloat32() == Approx(expected_out));
+    }
+    {
+        shader_test.shader_setup->uniforms.i[0] = {9, 0, 2, 0};
+        shader_test.shader_setup->uniforms.i[1] = {7, 0, 1, 0};
 
-    srand(time(0));
-    for (int i = 0; i < 10; i++) {
-        shader_test.shader_setup->uniforms.i[0] = generate_loop_parms();
-        shader_test.shader_setup->uniforms.i[1] = generate_loop_parms();
-        shader_test.shader_setup->uniforms.i[2] = generate_loop_parms();
-        shader_test.shader_setup->uniforms.i[3] = generate_loop_parms();
-        float input = -(RAND_MAX / 2) + rand();
-        run_test_helper(input);
+        const Common::Vec4<u8> loop_parms{shader_test.shader_setup->uniforms.i[0]};
+        const int expected_aL = loop_parms[1] + ((loop_parms[0] + 1) * loop_parms[2]);
+        const float input = 1.0f;
+        const float expected_out = (((shader_test.shader_setup->uniforms.i[0][0] + 1) *
+                                     (shader_test.shader_setup->uniforms.i[1][0] + 1)) *
+                                    input) +
+                                   input;
+        Pica::Shader::UnitState shader_unit_jit;
+        shader_test.RunJit(shader_unit_jit, input);
+
+        REQUIRE(shader_unit_jit.address_registers[2] == expected_aL);
+        REQUIRE(shader_unit_jit.registers.output[0].x.ToFloat32() == Approx(expected_out));
     }
 }
diff --git a/src/video_core/shader/shader_jit_x64_compiler.cpp b/src/video_core/shader/shader_jit_x64_compiler.cpp
index 5753187e9..b010aaa6f 100644
--- a/src/video_core/shader/shader_jit_x64_compiler.cpp
+++ b/src/video_core/shader/shader_jit_x64_compiler.cpp
@@ -164,8 +164,10 @@ static void LogCritical(const char* msg) {
 
 void JitShader::Compile_Assert(bool condition, const char* msg) {
     if (!condition) {
+        ABI_PushRegistersAndAdjustStack(*this, PersistentCallerSavedRegs(), 0);
         mov(ABI_PARAM1, reinterpret_cast<std::size_t>(msg));
         CallFarFunction(*this, LogCritical);
+        ABI_PopRegistersAndAdjustStack(*this, PersistentCallerSavedRegs(), 0);
     }
 }
 
@@ -725,10 +727,10 @@ void JitShader::Compile_IF(Instruction instr) {
 void JitShader::Compile_LOOP(Instruction instr) {
     Compile_Assert(instr.flow_control.dest_offset >= program_counter,
                    "Backwards loops not supported");
+    Compile_Assert(loop_depth < 1, "Nested loops may not be supported");
     if (loop_depth++) {
-        // LOOPCOUNT_REG is a "global", so we don't save it here.
-        push(LOOPINC.cvt64());
-        push(LOOPCOUNT.cvt64());
+        const auto loop_save_regs = BuildRegSet({LOOPCOUNT_REG, LOOPINC, LOOPCOUNT});
+        ABI_PushRegistersAndAdjustStack(*this, loop_save_regs, 0);
     }
 
     // This decodes the fields from the integer uniform at index instr.flow_control.int_uniform_id.
@@ -759,8 +761,8 @@ void JitShader::Compile_LOOP(Instruction instr) {
     loop_break_labels.pop_back();
 
     if (--loop_depth) {
-        pop(LOOPCOUNT.cvt64());
-        pop(LOOPINC.cvt64());
+        const auto loop_save_regs = BuildRegSet({LOOPCOUNT_REG, LOOPINC, LOOPCOUNT});
+        ABI_PopRegistersAndAdjustStack(*this, loop_save_regs, 0);
     }
 }