diff --git a/src/Ryujinx.Graphics.Gpu/Engine/Compute/ComputeClass.cs b/src/Ryujinx.Graphics.Gpu/Engine/Compute/ComputeClass.cs
index 8227a7ff18..d8103ac719 100644
--- a/src/Ryujinx.Graphics.Gpu/Engine/Compute/ComputeClass.cs
+++ b/src/Ryujinx.Graphics.Gpu/Engine/Compute/ComputeClass.cs
@@ -151,8 +151,6 @@ namespace Ryujinx.Graphics.Gpu.Engine.Compute
 
             ShaderProgramInfo info = cs.Shaders[0].Info;
 
-            bool hasUnaligned = _channel.BufferManager.HasUnalignedStorageBuffers;
-
             for (int index = 0; index < info.SBuffers.Count; index++)
             {
                 BufferDescriptor sb = info.SBuffers[index];
@@ -177,9 +175,17 @@ namespace Ryujinx.Graphics.Gpu.Engine.Compute
                 _channel.BufferManager.SetComputeStorageBuffer(sb.Slot, sbDescriptor.PackAddress(), size, sb.Flags);
             }
 
-            if ((_channel.BufferManager.HasUnalignedStorageBuffers) != hasUnaligned)
+            if (_channel.BufferManager.HasUnalignedStorageBuffers != computeState.HasUnalignedStorageBuffer)
             {
                 // Refetch the shader, as assumptions about storage buffer alignment have changed.
+                computeState = new GpuChannelComputeState(
+                    qmd.CtaThreadDimension0,
+                    qmd.CtaThreadDimension1,
+                    qmd.CtaThreadDimension2,
+                    localMemorySize,
+                    sharedMemorySize,
+                    _channel.BufferManager.HasUnalignedStorageBuffers);
+
                 cs = memoryManager.Physical.ShaderCache.GetComputeShader(_channel, poolState, computeState, shaderGpuVa);
 
                 _context.Renderer.Pipeline.SetProgram(cs.HostProgram);