diff --git a/Ryujinx.Graphics.Gpu/Engine/Threed/DrawManager.cs b/Ryujinx.Graphics.Gpu/Engine/Threed/DrawManager.cs
index ed8ed2064b..5a659f5550 100644
--- a/Ryujinx.Graphics.Gpu/Engine/Threed/DrawManager.cs
+++ b/Ryujinx.Graphics.Gpu/Engine/Threed/DrawManager.cs
@@ -19,6 +19,7 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
         private readonly GpuChannel _channel;
         private readonly DeviceStateWithShadow<ThreedClassState> _state;
         private readonly DrawState _drawState;
+        private readonly SpecializationStateUpdater _currentSpecState;
         private bool _topologySet;
 
         private bool _instancedDrawPending;
@@ -44,12 +45,14 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
         /// <param name="channel">GPU channel</param>
         /// <param name="state">Channel state</param>
         /// <param name="drawState">Draw state</param>
-        public DrawManager(GpuContext context, GpuChannel channel, DeviceStateWithShadow<ThreedClassState> state, DrawState drawState)
+        /// <param name="spec">Specialization state updater</param>
+        public DrawManager(GpuContext context, GpuChannel channel, DeviceStateWithShadow<ThreedClassState> state, DrawState drawState, SpecializationStateUpdater spec)
         {
             _context = context;
             _channel = channel;
             _state = state;
             _drawState = drawState;
+            _currentSpecState = spec;
         }
 
         /// <summary>
@@ -132,6 +135,7 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
 
             _drawState.FirstIndex = firstIndex;
             _drawState.IndexCount = indexCount;
+            _currentSpecState.SetHasConstantBufferDrawParameters(false);
 
             engine.UpdateState();
 
@@ -256,6 +260,7 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
             if (_drawState.Topology != topology || !_topologySet)
             {
                 _context.Renderer.Pipeline.SetPrimitiveTopology(topology);
+                _currentSpecState.SetTopology(topology);
                 _drawState.Topology = topology;
                 _topologySet = true;
             }
@@ -452,7 +457,7 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
             _state.State.FirstInstance = (uint)firstInstance;
 
             _drawState.DrawIndexed = indexed;
-            _drawState.HasConstantBufferDrawParameters = true;
+            _currentSpecState.SetHasConstantBufferDrawParameters(true);
 
             engine.UpdateState();
 
@@ -469,7 +474,6 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
             _state.State.FirstInstance = 0;
 
             _drawState.DrawIndexed = false;
-            _drawState.HasConstantBufferDrawParameters = false;
 
             if (renderEnable == ConditionalRenderEnabled.Host)
             {
@@ -527,7 +531,7 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
 
             _drawState.DrawIndexed = indexed;
             _drawState.DrawIndirect = true;
-            _drawState.HasConstantBufferDrawParameters = true;
+            _currentSpecState.SetHasConstantBufferDrawParameters(true);
 
             engine.UpdateState();
 
@@ -561,7 +565,6 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
 
             _drawState.DrawIndexed = false;
             _drawState.DrawIndirect = false;
-            _drawState.HasConstantBufferDrawParameters = false;
 
             if (renderEnable == ConditionalRenderEnabled.Host)
             {
diff --git a/Ryujinx.Graphics.Gpu/Engine/Threed/SpecializationStateUpdater.cs b/Ryujinx.Graphics.Gpu/Engine/Threed/SpecializationStateUpdater.cs
new file mode 100644
index 0000000000..9e888f506c
--- /dev/null
+++ b/Ryujinx.Graphics.Gpu/Engine/Threed/SpecializationStateUpdater.cs
@@ -0,0 +1,280 @@
+using Ryujinx.Common.Memory;
+using Ryujinx.Graphics.GAL;
+using Ryujinx.Graphics.Gpu.Shader;
+using Ryujinx.Graphics.Shader;
+
+namespace Ryujinx.Graphics.Gpu.Engine.Threed
+{
+    /// <summary>
+    /// Maintains a "current" specialiation state, and provides a flag to check if it has changed meaningfully.
+    /// </summary>
+    internal class SpecializationStateUpdater
+    {
+        private GpuChannelGraphicsState _graphics;
+        private GpuChannelPoolState _pool;
+
+        private bool _usesDrawParameters;
+        private bool _usesTopology;
+
+        private bool _changed;
+
+        /// <summary>
+        /// Signal that the specialization state has changed.
+        /// </summary>
+        private void Signal()
+        {
+            _changed = true;
+        }
+
+        /// <summary>
+        /// Checks if the specialization state has changed since the last check.
+        /// </summary>
+        /// <returns>True if it has changed, false otherwise</returns>
+        public bool HasChanged()
+        {
+            if (_changed)
+            {
+                _changed = false;
+                return true;
+            }
+            else
+            {
+                return false;
+            }
+        }
+
+        /// <summary>
+        /// Sets the active shader, clearing the dirty state and recording if certain specializations are noteworthy.
+        /// </summary>
+        /// <param name="gs">The active shader</param>
+        public void SetShader(CachedShaderProgram gs)
+        {
+            _usesDrawParameters = gs.Shaders[1]?.Info.UsesDrawParameters ?? false;
+            _usesTopology = gs.SpecializationState.IsPrimitiveTopologyQueried();
+
+            _changed = false;
+        }
+
+        /// <summary>
+        /// Get the current graphics state.
+        /// </summary>
+        /// <returns>GPU graphics state</returns>
+        public ref GpuChannelGraphicsState GetGraphicsState()
+        {
+            return ref _graphics;
+        }
+
+        /// <summary>
+        /// Get the current pool state.
+        /// </summary>
+        /// <returns>GPU pool state</returns>
+        public ref GpuChannelPoolState GetPoolState()
+        {
+            return ref _pool;
+        }
+
+        /// <summary>
+        /// Early Z force enable.
+        /// </summary>
+        /// <param name="value">The new value</param>
+        public void SetEarlyZForce(bool value)
+        {
+            _graphics.EarlyZForce = value;
+
+            Signal();
+        }
+
+        /// <summary>
+        /// Primitive topology of current draw.
+        /// </summary>
+        /// <param name="value">The new value</param>
+        public void SetTopology(PrimitiveTopology value)
+        {
+            if (value != _graphics.Topology)
+            {
+                _graphics.Topology = value;
+
+                if (_usesTopology)
+                {
+                    Signal();
+                }
+            }
+        }
+
+        /// <summary>
+        /// Tessellation mode.
+        /// </summary>
+        /// <param name="value">The new value</param>
+        public void SetTessellationMode(TessMode value)
+        {
+            if (value.Packed != _graphics.TessellationMode.Packed)
+            {
+                _graphics.TessellationMode = value;
+
+                Signal();
+            }
+        }
+
+        /// <summary>
+        /// Updates alpha-to-coverage state, and sets it as changed.
+        /// </summary>
+        /// <param name="enable">Whether alpha-to-coverage is enabled</param>
+        /// <param name="ditherEnable">Whether alpha-to-coverage dithering is enabled</param>
+        public void SetAlphaToCoverageEnable(bool enable, bool ditherEnable)
+        {
+            _graphics.AlphaToCoverageEnable = enable;
+            _graphics.AlphaToCoverageDitherEnable = ditherEnable;
+
+            Signal();
+        }
+
+        /// <summary>
+        /// Indicates whether the viewport transform is disabled.
+        /// </summary>
+        /// <param name="value">The new value</param>
+        public void SetViewportTransformDisable(bool value)
+        {
+            if (value != _graphics.ViewportTransformDisable)
+            {
+                _graphics.ViewportTransformDisable = value;
+
+                Signal();
+            }
+        }
+
+        /// <summary>
+        /// Depth mode zero to one or minus one to one.
+        /// </summary>
+        /// <param name="value">The new value</param>
+        public void SetDepthMode(bool value)
+        {
+            if (value != _graphics.DepthMode)
+            {
+                _graphics.DepthMode = value;
+
+                Signal();
+            }
+        }
+
+        /// <summary>
+        /// Indicates if the point size is set on the shader or is fixed.
+        /// </summary>
+        /// <param name="value">The new value</param>
+        public void SetProgramPointSizeEnable(bool value)
+        {
+            if (value != _graphics.ProgramPointSizeEnable)
+            {
+                _graphics.ProgramPointSizeEnable = value;
+
+                Signal();
+            }
+        }
+
+        /// <summary>
+        /// Point size used if <see cref="SetProgramPointSizeEnable" /> is provided false.
+        /// </summary>
+        /// <param name="value">The new value</param>
+        public void SetPointSize(float value)
+        {
+            if (value != _graphics.PointSize)
+            {
+                _graphics.PointSize = value;
+
+                Signal();
+            }
+        }
+
+        /// <summary>
+        /// Updates alpha test specialization state, and sets it as changed.
+        /// </summary>
+        /// <param name="enable">Whether alpha test is enabled</param>
+        /// <param name="reference">The value to compare with the fragment output alpha</param>
+        /// <param name="op">The comparison that decides if the fragment should be discarded</param>
+        public void SetAlphaTest(bool enable, float reference, CompareOp op)
+        {
+            _graphics.AlphaTestEnable = enable;
+            _graphics.AlphaTestReference = reference;
+            _graphics.AlphaTestCompare = op;
+
+            Signal();
+        }
+
+        /// <summary>
+        /// Updates the type of the vertex attributes consumed by the shader.
+        /// </summary>
+        /// <param name="state">The new state</param>
+        public void SetAttributeTypes(ref Array32<VertexAttribState> state)
+        {
+            bool changed = false;
+            ref Array32<AttributeType> attributeTypes = ref _graphics.AttributeTypes;
+
+            for (int location = 0; location < state.Length; location++)
+            {
+                VertexAttribType type = state[location].UnpackType();
+
+                AttributeType value = type switch
+                {
+                    VertexAttribType.Sint => AttributeType.Sint,
+                    VertexAttribType.Uint => AttributeType.Uint,
+                    _ => AttributeType.Float
+                };
+
+                if (attributeTypes[location] != value)
+                {
+                    attributeTypes[location] = value;
+                    changed = true;
+                }
+            }
+
+            if (changed)
+            {
+                Signal();
+            }
+        }
+
+        /// <summary>
+        /// Indicates that the draw is writing the base vertex, base instance and draw index to Constant Buffer 0.
+        /// </summary>
+        /// <param name="value">The new value</param>
+        public void SetHasConstantBufferDrawParameters(bool value)
+        {
+            if (value != _graphics.HasConstantBufferDrawParameters)
+            {
+                _graphics.HasConstantBufferDrawParameters = value;
+
+                if (_usesDrawParameters)
+                {
+                    Signal();
+                }
+            }
+        }
+
+        /// <summary>
+        /// Indicates that any storage buffer use is unaligned.
+        /// </summary>
+        /// <param name="value">The new value</param>
+        public void SetHasUnalignedStorageBuffer(bool value)
+        {
+            if (value != _graphics.HasUnalignedStorageBuffer)
+            {
+                _graphics.HasUnalignedStorageBuffer = value;
+
+                Signal();
+            }
+        }
+
+        /// <summary>
+        /// Sets the GPU pool state.
+        /// </summary>
+        /// <param name="state">The new state</param>
+        public void SetPoolState(GpuChannelPoolState state)
+        {
+            if (!state.Equals(_pool))
+            {
+                _pool = state;
+
+                Signal();
+            }
+        }
+    }
+}
diff --git a/Ryujinx.Graphics.Gpu/Engine/Threed/StateUpdater.cs b/Ryujinx.Graphics.Gpu/Engine/Threed/StateUpdater.cs
index fe7e0d09f0..b611f4e700 100644
--- a/Ryujinx.Graphics.Gpu/Engine/Threed/StateUpdater.cs
+++ b/Ryujinx.Graphics.Gpu/Engine/Threed/StateUpdater.cs
@@ -1,6 +1,7 @@
 using Ryujinx.Common.Logging;
 using Ryujinx.Common.Memory;
 using Ryujinx.Graphics.GAL;
+using Ryujinx.Graphics.Gpu.Engine.GPFifo;
 using Ryujinx.Graphics.Gpu.Engine.Types;
 using Ryujinx.Graphics.Gpu.Image;
 using Ryujinx.Graphics.Gpu.Shader;
@@ -16,9 +17,9 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
     /// </summary>
     class StateUpdater
     {
-        public const int ShaderStateIndex = 16;
+        public const int ShaderStateIndex = 26;
         public const int RasterizerStateIndex = 15;
-        public const int ScissorStateIndex = 18;
+        public const int ScissorStateIndex = 16;
         public const int VertexBufferStateIndex = 0;
         public const int PrimitiveRestartStateIndex = 12;
 
@@ -31,6 +32,7 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
 
         private readonly ShaderProgramInfo[] _currentProgramInfo;
         private ShaderSpecializationState _shaderSpecState;
+        private SpecializationStateUpdater _currentSpecState;
 
         private ProgramPipelineState _pipeline;
 
@@ -54,15 +56,17 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
         /// <param name="channel">GPU channel</param>
         /// <param name="state">3D engine state</param>
         /// <param name="drawState">Draw state</param>
-        public StateUpdater(GpuContext context, GpuChannel channel, DeviceStateWithShadow<ThreedClassState> state, DrawState drawState)
+        /// <param name="spec">Specialization state updater</param>
+        public StateUpdater(GpuContext context, GpuChannel channel, DeviceStateWithShadow<ThreedClassState> state, DrawState drawState, SpecializationStateUpdater spec)
         {
             _context = context;
             _channel = channel;
             _state = state;
             _drawState = drawState;
             _currentProgramInfo = new ShaderProgramInfo[Constants.ShaderStages];
+            _currentSpecState = spec;
 
-            // ShaderState must be updated after other state updates, as pipeline state is sent to the backend when compiling new shaders.
+            // ShaderState must be updated after other state updates, as specialization/pipeline state is used when fetching shaders.
             // Render target state must appear after shader state as it depends on information from the currently bound shader.
             // Rasterizer and scissor states are checked by render target clear, their indexes
             // must be updated on the constants "RasterizerStateIndex" and "ScissorStateIndex" if modified.
@@ -101,6 +105,7 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
                     nameof(ThreedClassState.DepthTestFunc)),
 
                 new StateUpdateCallbackEntry(UpdateTessellationState,
+                    nameof(ThreedClassState.TessMode),
                     nameof(ThreedClassState.TessOuterLevel),
                     nameof(ThreedClassState.TessInnerLevel),
                     nameof(ThreedClassState.PatchVertices)),
@@ -138,17 +143,6 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
 
                 new StateUpdateCallbackEntry(UpdateRasterizerState, nameof(ThreedClassState.RasterizeEnable)),
 
-                new StateUpdateCallbackEntry(UpdateShaderState,
-                    nameof(ThreedClassState.ShaderBaseAddress),
-                    nameof(ThreedClassState.ShaderState)),
-
-                new StateUpdateCallbackEntry(UpdateRenderTargetState,
-                    nameof(ThreedClassState.RtColorState),
-                    nameof(ThreedClassState.RtDepthStencilState),
-                    nameof(ThreedClassState.RtControl),
-                    nameof(ThreedClassState.RtDepthStencilSize),
-                    nameof(ThreedClassState.RtDepthStencilEnable)),
-
                 new StateUpdateCallbackEntry(UpdateScissorState,
                     nameof(ThreedClassState.ScissorState),
                     nameof(ThreedClassState.ScreenScissorState)),
@@ -179,7 +173,21 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
 
                 new StateUpdateCallbackEntry(UpdateMultisampleState,
                     nameof(ThreedClassState.AlphaToCoverageDitherEnable),
-                    nameof(ThreedClassState.MultisampleControl))
+                    nameof(ThreedClassState.MultisampleControl)),
+
+                new StateUpdateCallbackEntry(UpdateEarlyZState,
+                    nameof(ThreedClassState.EarlyZForce)),
+
+                new StateUpdateCallbackEntry(UpdateShaderState,
+                    nameof(ThreedClassState.ShaderBaseAddress),
+                    nameof(ThreedClassState.ShaderState)),
+
+                new StateUpdateCallbackEntry(UpdateRenderTargetState,
+                    nameof(ThreedClassState.RtColorState),
+                    nameof(ThreedClassState.RtDepthStencilState),
+                    nameof(ThreedClassState.RtControl),
+                    nameof(ThreedClassState.RtDepthStencilSize),
+                    nameof(ThreedClassState.RtDepthStencilEnable)),
             });
         }
 
@@ -209,17 +217,6 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         public void Update()
         {
-            // If any state that the shader depends on changed,
-            // then we may need to compile/bind a different version
-            // of the shader for the new state.
-            if (_shaderSpecState != null)
-            {
-                if (!_shaderSpecState.MatchesGraphics(_channel, GetPoolState(), GetGraphicsState(), _vsUsesDrawParameters, false))
-                {
-                    ForceShaderUpdate();
-                }
-            }
-
             // The vertex buffer size is calculated using a different
             // method when doing indexed draws, so we need to make sure
             // to update the vertex buffers if we are doing a regular
@@ -271,6 +268,18 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
 
             _updateTracker.Update(ulong.MaxValue);
 
+            // If any state that the shader depends on changed,
+            // then we may need to compile/bind a different version
+            // of the shader for the new state.
+            if (_shaderSpecState != null && _currentSpecState.HasChanged())
+            {
+                if (!_shaderSpecState.MatchesGraphics(_channel, ref _currentSpecState.GetPoolState(), ref _currentSpecState.GetGraphicsState(), _vsUsesDrawParameters, false))
+                {
+                    // Shader must be reloaded. _vtgWritesRtLayer should not change.
+                    UpdateShaderState();
+                }
+            }
+
             CommitBindings();
 
             if (tfEnable && !_prevTfEnable)
@@ -302,7 +311,8 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
 
             if (!_channel.TextureManager.CommitGraphicsBindings(_shaderSpecState) || (buffers.HasUnalignedStorageBuffers != hasUnaligned))
             {
-                // Shader must be reloaded.
+                _currentSpecState.SetHasUnalignedStorageBuffer(buffers.HasUnalignedStorageBuffers);
+                // Shader must be reloaded. _vtgWritesRtLayer should not change.
                 UpdateShaderState();
             }
 
@@ -351,6 +361,8 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
                 _state.State.PatchVertices,
                 _state.State.TessOuterLevel.AsSpan(),
                 _state.State.TessInnerLevel.AsSpan());
+
+            _currentSpecState.SetTessellationMode(_state.State.TessMode);
         }
 
         /// <summary>
@@ -611,6 +623,11 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
                 _state.State.AlphaTestEnable,
                 _state.State.AlphaTestRef,
                 _state.State.AlphaTestFunc);
+
+            _currentSpecState.SetAlphaTest(
+                _state.State.AlphaTestEnable,
+                _state.State.AlphaTestRef,
+                _state.State.AlphaTestFunc);
         }
 
         /// <summary>
@@ -710,6 +727,9 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
 
             _context.Renderer.Pipeline.SetDepthMode(GetDepthMode());
             _context.Renderer.Pipeline.SetViewports(viewports, disableTransform);
+
+            _currentSpecState.SetViewportTransformDisable(_state.State.ViewportTransformEnable == 0);
+            _currentSpecState.SetDepthMode(GetDepthMode() == DepthMode.MinusOneToOne);
         }
 
         /// <summary>
@@ -847,6 +867,8 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
 
             _channel.TextureManager.SetGraphicsTexturePool(texturePool.Address.Pack(), texturePool.MaximumId);
             _channel.TextureManager.SetGraphicsTextureBufferIndex((int)_state.State.TextureBufferIndex);
+
+            _currentSpecState.SetPoolState(GetPoolState());
         }
 
         /// <summary>
@@ -887,6 +909,7 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
 
             _pipeline.SetVertexAttribs(vertexAttribs);
             _context.Renderer.Pipeline.SetVertexAttribs(vertexAttribs);
+            _currentSpecState.SetAttributeTypes(ref _state.State.VertexAttribState);
         }
 
         /// <summary>
@@ -914,6 +937,9 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
             Origin origin = (_state.State.PointCoordReplace & 4) == 0 ? Origin.LowerLeft : Origin.UpperLeft;
 
             _context.Renderer.Pipeline.SetPointParameters(size, isProgramPointSize, enablePointSprite, origin);
+
+            _currentSpecState.SetProgramPointSizeEnable(isProgramPointSize);
+            _currentSpecState.SetPointSize(size);
         }
 
         /// <summary>
@@ -1212,6 +1238,16 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
                 alphaToCoverageEnable,
                 _state.State.AlphaToCoverageDitherEnable,
                 alphaToOneEnable));
+
+            _currentSpecState.SetAlphaToCoverageEnable(alphaToCoverageEnable, _state.State.AlphaToCoverageDitherEnable);
+        }
+
+        /// <summary>
+        /// Updates the early z flag, based on guest state.
+        /// </summary>
+        private void UpdateEarlyZState()
+        {
+            _currentSpecState.SetEarlyZForce(_state.State.EarlyZForce);
         }
 
         /// <summary>
@@ -1239,10 +1275,10 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
                 addressesSpan[index] = baseAddress + shader.Offset;
             }
 
-            GpuChannelPoolState poolState = GetPoolState();
-            GpuChannelGraphicsState graphicsState = GetGraphicsState();
+            CachedShaderProgram gs = shaderCache.GetGraphicsShader(ref _state.State, ref _pipeline, _channel, ref _currentSpecState.GetPoolState(), ref _currentSpecState.GetGraphicsState(), addresses);
 
-            CachedShaderProgram gs = shaderCache.GetGraphicsShader(ref _state.State, ref _pipeline, _channel, poolState, graphicsState, addresses);
+            // Consume the modified flag for spec state so that it isn't checked again.
+            _currentSpecState.SetShader(gs);
 
             _shaderSpecState = gs.SpecializationState;
 
@@ -1289,46 +1325,6 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
                 (int)_state.State.TextureBufferIndex);
         }
 
-        /// <summary>
-        /// Gets the current GPU channel state for shader creation or compatibility verification.
-        /// </summary>
-        /// <returns>Current GPU channel state</returns>
-        private GpuChannelGraphicsState GetGraphicsState()
-        {
-            ref var vertexAttribState = ref _state.State.VertexAttribState;
-
-            Array32<AttributeType> attributeTypes = new Array32<AttributeType>();
-
-            for (int location = 0; location < attributeTypes.Length; location++)
-            {
-                VertexAttribType type = vertexAttribState[location].UnpackType();
-
-                attributeTypes[location] = type switch
-                {
-                    VertexAttribType.Sint => AttributeType.Sint,
-                    VertexAttribType.Uint => AttributeType.Uint,
-                    _ => AttributeType.Float
-                };
-            }
-
-            return new GpuChannelGraphicsState(
-                _state.State.EarlyZForce,
-                _drawState.Topology,
-                _state.State.TessMode,
-                (_state.State.MultisampleControl & 1) != 0,
-                _state.State.AlphaToCoverageDitherEnable,
-                _state.State.ViewportTransformEnable == 0,
-                GetDepthMode() == DepthMode.MinusOneToOne,
-                _state.State.VertexProgramPointSize,
-                _state.State.PointSize,
-                _state.State.AlphaTestEnable,
-                _state.State.AlphaTestFunc,
-                _state.State.AlphaTestRef,
-                ref attributeTypes,
-                _drawState.HasConstantBufferDrawParameters,
-                _channel.BufferManager.HasUnalignedStorageBuffers);
-        }
-
         /// <summary>
         /// Gets the depth mode that is currently being used (zero to one or minus one to one).
         /// </summary>
diff --git a/Ryujinx.Graphics.Gpu/Engine/Threed/ThreedClass.cs b/Ryujinx.Graphics.Gpu/Engine/Threed/ThreedClass.cs
index 106a6f3f4c..b254e95e0c 100644
--- a/Ryujinx.Graphics.Gpu/Engine/Threed/ThreedClass.cs
+++ b/Ryujinx.Graphics.Gpu/Engine/Threed/ThreedClass.cs
@@ -67,12 +67,13 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
 
             _i2mClass = new InlineToMemoryClass(context, channel, initializeState: false);
 
+            var spec = new SpecializationStateUpdater();
             var drawState = new DrawState();
 
-            _drawManager = new DrawManager(context, channel, _state, drawState);
+            _drawManager = new DrawManager(context, channel, _state, drawState, spec);
             _semaphoreUpdater = new SemaphoreUpdater(context, channel, _state);
             _cbUpdater = new ConstantBufferUpdater(channel, _state);
-            _stateUpdater = new StateUpdater(context, channel, _state, drawState);
+            _stateUpdater = new StateUpdater(context, channel, _state, drawState, spec);
 
             // This defaults to "always", even without any register write.
             // Reads just return 0, regardless of what was set there.
diff --git a/Ryujinx.Graphics.Gpu/Shader/GpuChannelGraphicsState.cs b/Ryujinx.Graphics.Gpu/Shader/GpuChannelGraphicsState.cs
index 511f4c2359..e5e4862602 100644
--- a/Ryujinx.Graphics.Gpu/Shader/GpuChannelGraphicsState.cs
+++ b/Ryujinx.Graphics.Gpu/Shader/GpuChannelGraphicsState.cs
@@ -15,62 +15,62 @@ namespace Ryujinx.Graphics.Gpu.Shader
         /// <summary>
         /// Early Z force enable.
         /// </summary>
-        public readonly bool EarlyZForce;
+        public bool EarlyZForce;
 
         /// <summary>
         /// Primitive topology of current draw.
         /// </summary>
-        public readonly PrimitiveTopology Topology;
+        public PrimitiveTopology Topology;
 
         /// <summary>
         /// Tessellation mode.
         /// </summary>
-        public readonly TessMode TessellationMode;
+        public TessMode TessellationMode;
 
         /// <summary>
         /// Indicates whether alpha-to-coverage is enabled.
         /// </summary>
-        public readonly bool AlphaToCoverageEnable;
+        public bool AlphaToCoverageEnable;
 
         /// <summary>
         /// Indicates whether alpha-to-coverage dithering is enabled.
         /// </summary>
-        public readonly bool AlphaToCoverageDitherEnable;
+        public bool AlphaToCoverageDitherEnable;
 
         /// <summary>
         /// Indicates whether the viewport transform is disabled.
         /// </summary>
-        public readonly bool ViewportTransformDisable;
+        public bool ViewportTransformDisable;
 
         /// <summary>
         /// Depth mode zero to one or minus one to one.
         /// </summary>
-        public readonly bool DepthMode;
+        public bool DepthMode;
 
         /// <summary>
         /// Indicates if the point size is set on the shader or is fixed.
         /// </summary>
-        public readonly bool ProgramPointSizeEnable;
+        public bool ProgramPointSizeEnable;
 
         /// <summary>
         /// Point size used if <see cref="ProgramPointSizeEnable" /> is false.
         /// </summary>
-        public readonly float PointSize;
+        public float PointSize;
 
         /// <summary>
         /// Indicates whether alpha test is enabled.
         /// </summary>
-        public readonly bool AlphaTestEnable;
+        public bool AlphaTestEnable;
 
         /// <summary>
         /// When alpha test is enabled, indicates the comparison that decides if the fragment should be discarded.
         /// </summary>
-        public readonly CompareOp AlphaTestCompare;
+        public CompareOp AlphaTestCompare;
 
         /// <summary>
         /// When alpha test is enabled, indicates the value to compare with the fragment output alpha.
         /// </summary>
-        public readonly float AlphaTestReference;
+        public float AlphaTestReference;
 
         /// <summary>
         /// Type of the vertex attributes consumed by the shader.
@@ -80,12 +80,12 @@ namespace Ryujinx.Graphics.Gpu.Shader
         /// <summary>
         /// Indicates that the draw is writing the base vertex, base instance and draw index to Constant Buffer 0.
         /// </summary>
-        public readonly bool HasConstantBufferDrawParameters;
+        public bool HasConstantBufferDrawParameters;
 
         /// <summary>
         /// Indicates that any storage buffer use is unaligned.
         /// </summary>
-        public readonly bool HasUnalignedStorageBuffer;
+        public bool HasUnalignedStorageBuffer;
 
         /// <summary>
         /// Creates a new GPU graphics state.
diff --git a/Ryujinx.Graphics.Gpu/Shader/GpuChannelPoolState.cs b/Ryujinx.Graphics.Gpu/Shader/GpuChannelPoolState.cs
index 0b36227ac9..b894c57e78 100644
--- a/Ryujinx.Graphics.Gpu/Shader/GpuChannelPoolState.cs
+++ b/Ryujinx.Graphics.Gpu/Shader/GpuChannelPoolState.cs
@@ -1,9 +1,11 @@
+using System;
+
 namespace Ryujinx.Graphics.Gpu.Shader
 {
     /// <summary>
     /// State used by the <see cref="GpuAccessor"/>.
     /// </summary>
-    struct GpuChannelPoolState
+    struct GpuChannelPoolState : IEquatable<GpuChannelPoolState>
     {
         /// <summary>
         /// GPU virtual address of the texture pool.
@@ -32,5 +34,17 @@ namespace Ryujinx.Graphics.Gpu.Shader
             TexturePoolMaximumId = texturePoolMaximumId;
             TextureBufferIndex = textureBufferIndex;
         }
+
+        /// <summary>
+        /// Check if the pool states are equal.
+        /// </summary>
+        /// <param name="other">Pool state to compare with</param>
+        /// <returns>True if they are equal, false otherwise</returns>
+        public bool Equals(GpuChannelPoolState other)
+        {
+            return TexturePoolGpuVa == other.TexturePoolGpuVa &&
+                TexturePoolMaximumId == other.TexturePoolMaximumId &&
+                TextureBufferIndex == other.TextureBufferIndex;
+        }
     }
 }
\ No newline at end of file
diff --git a/Ryujinx.Graphics.Gpu/Shader/ShaderCache.cs b/Ryujinx.Graphics.Gpu/Shader/ShaderCache.cs
index 3eaab79f05..23b213b4b4 100644
--- a/Ryujinx.Graphics.Gpu/Shader/ShaderCache.cs
+++ b/Ryujinx.Graphics.Gpu/Shader/ShaderCache.cs
@@ -300,16 +300,16 @@ namespace Ryujinx.Graphics.Gpu.Shader
             ref ThreedClassState state,
             ref ProgramPipelineState pipeline,
             GpuChannel channel,
-            GpuChannelPoolState poolState,
-            GpuChannelGraphicsState graphicsState,
+            ref GpuChannelPoolState poolState,
+            ref GpuChannelGraphicsState graphicsState,
             ShaderAddresses addresses)
         {
-            if (_gpPrograms.TryGetValue(addresses, out var gpShaders) && IsShaderEqual(channel, poolState, graphicsState, gpShaders, addresses))
+            if (_gpPrograms.TryGetValue(addresses, out var gpShaders) && IsShaderEqual(channel, ref poolState, ref graphicsState, gpShaders, addresses))
             {
                 return gpShaders;
             }
 
-            if (_graphicsShaderCache.TryFind(channel, poolState, graphicsState, addresses, out gpShaders, out var cachedGuestCode))
+            if (_graphicsShaderCache.TryFind(channel, ref poolState, ref graphicsState, addresses, out gpShaders, out var cachedGuestCode))
             {
                 _gpPrograms[addresses] = gpShaders;
                 return gpShaders;
@@ -498,7 +498,7 @@ namespace Ryujinx.Graphics.Gpu.Shader
         {
             if (IsShaderEqual(channel.MemoryManager, cpShader.Shaders[0], gpuVa))
             {
-                return cpShader.SpecializationState.MatchesCompute(channel, poolState, computeState, true);
+                return cpShader.SpecializationState.MatchesCompute(channel, ref poolState, computeState, true);
             }
 
             return false;
@@ -515,8 +515,8 @@ namespace Ryujinx.Graphics.Gpu.Shader
         /// <returns>True if the code is different, false otherwise</returns>
         private static bool IsShaderEqual(
             GpuChannel channel,
-            GpuChannelPoolState poolState,
-            GpuChannelGraphicsState graphicsState,
+            ref GpuChannelPoolState poolState,
+            ref GpuChannelGraphicsState graphicsState,
             CachedShaderProgram gpShaders,
             ShaderAddresses addresses)
         {
@@ -536,7 +536,7 @@ namespace Ryujinx.Graphics.Gpu.Shader
 
             bool usesDrawParameters = gpShaders.Shaders[1]?.Info.UsesDrawParameters ?? false;
 
-            return gpShaders.SpecializationState.MatchesGraphics(channel, poolState, graphicsState, usesDrawParameters, true);
+            return gpShaders.SpecializationState.MatchesGraphics(channel, ref poolState, ref graphicsState, usesDrawParameters, true);
         }
 
         /// <summary>
diff --git a/Ryujinx.Graphics.Gpu/Shader/ShaderCacheHashTable.cs b/Ryujinx.Graphics.Gpu/Shader/ShaderCacheHashTable.cs
index 3d74e53a10..e35c06b133 100644
--- a/Ryujinx.Graphics.Gpu/Shader/ShaderCacheHashTable.cs
+++ b/Ryujinx.Graphics.Gpu/Shader/ShaderCacheHashTable.cs
@@ -215,8 +215,8 @@ namespace Ryujinx.Graphics.Gpu.Shader
         /// <returns>True if a cached host program was found, false otherwise</returns>
         public bool TryFind(
             GpuChannel channel,
-            GpuChannelPoolState poolState,
-            GpuChannelGraphicsState graphicsState,
+            ref GpuChannelPoolState poolState,
+            ref GpuChannelGraphicsState graphicsState,
             ShaderAddresses addresses,
             out CachedShaderProgram program,
             out CachedGraphicsGuestCode guestCode)
@@ -236,7 +236,7 @@ namespace Ryujinx.Graphics.Gpu.Shader
 
             if (found && _shaderPrograms.TryGetValue(idTable, out ShaderSpecializationList specList))
             {
-                return specList.TryFindForGraphics(channel, poolState, graphicsState, out program);
+                return specList.TryFindForGraphics(channel, ref poolState, ref graphicsState, out program);
             }
 
             return false;
diff --git a/Ryujinx.Graphics.Gpu/Shader/ShaderSpecializationList.cs b/Ryujinx.Graphics.Gpu/Shader/ShaderSpecializationList.cs
index cb6ab49a81..7d61332e57 100644
--- a/Ryujinx.Graphics.Gpu/Shader/ShaderSpecializationList.cs
+++ b/Ryujinx.Graphics.Gpu/Shader/ShaderSpecializationList.cs
@@ -29,15 +29,15 @@ namespace Ryujinx.Graphics.Gpu.Shader
         /// <returns>True if a compatible program is found, false otherwise</returns>
         public bool TryFindForGraphics(
             GpuChannel channel,
-            GpuChannelPoolState poolState,
-            GpuChannelGraphicsState graphicsState,
+            ref GpuChannelPoolState poolState,
+            ref GpuChannelGraphicsState graphicsState,
             out CachedShaderProgram program)
         {
             foreach (var entry in _entries)
             {
                 bool usesDrawParameters = entry.Shaders[1]?.Info.UsesDrawParameters ?? false;
 
-                if (entry.SpecializationState.MatchesGraphics(channel, poolState, graphicsState, usesDrawParameters, true))
+                if (entry.SpecializationState.MatchesGraphics(channel, ref poolState, ref graphicsState, usesDrawParameters, true))
                 {
                     program = entry;
                     return true;
@@ -60,7 +60,7 @@ namespace Ryujinx.Graphics.Gpu.Shader
         {
             foreach (var entry in _entries)
             {
-                if (entry.SpecializationState.MatchesCompute(channel, poolState, computeState, true))
+                if (entry.SpecializationState.MatchesCompute(channel, ref poolState, computeState, true))
                 {
                     program = entry;
                     return true;
diff --git a/Ryujinx.Graphics.Gpu/Shader/ShaderSpecializationState.cs b/Ryujinx.Graphics.Gpu/Shader/ShaderSpecializationState.cs
index 8f931507aa..14f64bbf40 100644
--- a/Ryujinx.Graphics.Gpu/Shader/ShaderSpecializationState.cs
+++ b/Ryujinx.Graphics.Gpu/Shader/ShaderSpecializationState.cs
@@ -392,6 +392,15 @@ namespace Ryujinx.Graphics.Gpu.Shader
             state.Value.QueriedFlags |= QueriedTextureStateFlags.CoordNormalized;
         }
 
+        /// <summary>
+        /// Checks if primitive topology was queried by the shader.
+        /// </summary>
+        /// <returns>True if queried, false otherwise</returns>
+        public bool IsPrimitiveTopologyQueried()
+        {
+            return _queriedState.HasFlag(QueriedStateFlags.PrimitiveTopology);
+        }
+
         /// <summary>
         /// Checks if a given texture was registerd on this specialization state.
         /// </summary>
@@ -486,8 +495,8 @@ namespace Ryujinx.Graphics.Gpu.Shader
         /// <returns>True if the state matches, false otherwise</returns>
         public bool MatchesGraphics(
             GpuChannel channel,
-            GpuChannelPoolState poolState,
-            GpuChannelGraphicsState graphicsState,
+            ref GpuChannelPoolState poolState,
+            ref GpuChannelGraphicsState graphicsState,
             bool usesDrawParameters,
             bool checkTextures)
         {
@@ -536,7 +545,7 @@ namespace Ryujinx.Graphics.Gpu.Shader
                 return false;
             }
 
-            return Matches(channel, poolState, checkTextures, isCompute: false);
+            return Matches(channel, ref poolState, checkTextures, isCompute: false);
         }
 
         /// <summary>
@@ -547,14 +556,14 @@ namespace Ryujinx.Graphics.Gpu.Shader
         /// <param name="computeState">Compute state</param>
         /// <param name="checkTextures">Indicates whether texture descriptors should be checked</param>
         /// <returns>True if the state matches, false otherwise</returns>
-        public bool MatchesCompute(GpuChannel channel, GpuChannelPoolState poolState, GpuChannelComputeState computeState, bool checkTextures)
+        public bool MatchesCompute(GpuChannel channel, ref GpuChannelPoolState poolState, GpuChannelComputeState computeState, bool checkTextures)
         {
             if (computeState.HasUnalignedStorageBuffer != ComputeState.HasUnalignedStorageBuffer)
             {
                 return false;
             }
 
-            return Matches(channel, poolState, checkTextures, isCompute: true);
+            return Matches(channel, ref poolState, checkTextures, isCompute: true);
         }
 
         /// <summary>
@@ -618,7 +627,7 @@ namespace Ryujinx.Graphics.Gpu.Shader
         /// <param name="checkTextures">Indicates whether texture descriptors should be checked</param>
         /// <param name="isCompute">Indicates whenever the check is requested by the 3D or compute engine</param>
         /// <returns>True if the state matches, false otherwise</returns>
-        private bool Matches(GpuChannel channel, GpuChannelPoolState poolState, bool checkTextures, bool isCompute)
+        private bool Matches(GpuChannel channel, ref GpuChannelPoolState poolState, bool checkTextures, bool isCompute)
         {
             int constantBufferUsePerStageMask = _constantBufferUsePerStage;