From 1ae3479e42a2f52ef07410b16d963ccdae8b62f7 Mon Sep 17 00:00:00 2001
From: drisspg <drisspguessous@gmail.com>
Date: Mon, 16 Dec 2024 09:15:07 -0800
Subject: [PATCH 1/5] Update

[ghstack-poisoned]
---
 test/inductor/test_flex_attention.py     | 5 +++++
 torch/_inductor/kernel/flex_attention.py | 3 ---
 2 files changed, 5 insertions(+), 3 deletions(-)

Index: pytorch/test/inductor/test_flex_attention.py
===================================================================
--- pytorch.orig/test/inductor/test_flex_attention.py
+++ pytorch/test/inductor/test_flex_attention.py
@@ -3231,6 +3231,11 @@ def forward(self, arg0_1, arg1_1, arg2_1
 
         self.run_test_with_call(attention, Q_S=Q_S, KV_S=KV_S)
 
+    @supported_platform
+    def test_num_warps_8_error(self):
+        attention = functools.partial(flex_attention, score_mod=_identity)
+        self.run_test_with_call(attention, Q_S=128, KV_S=128, Q_D=128, V_D=128)
+
     @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
     def test_qkv_and_block_mask_on_the_same_device(self):
         make_tensor = functools.partial(
Index: pytorch/torch/_inductor/kernel/flex_attention.py
===================================================================
--- pytorch.orig/torch/_inductor/kernel/flex_attention.py
+++ pytorch/torch/_inductor/kernel/flex_attention.py
@@ -746,11 +746,13 @@ def _get_nv_config(query, mode: Mode) ->
                 return (64, 128, 8, 3)
             else:
                 return (64, 64, 4, 2)
-        elif capability >= (8, 0):  # A100
-            if head_dim == 64:
+        elif capability >= (8, 0):
+            if head_dim >= 64:
                 return (32, 128, 4, 3)
             elif head_dim == 128:
-                return (64, 128, 8, 3)
+                # SM86/89 have smaller shared memory sizes
+                num_stages = 3 if capability[-1] == 0 else 2
+                return (64, 64, 4, num_stages)
             else:
                 return (64, 64, 4, 2)
         else:  # modest hardware or extremely large head_dim
@@ -2273,9 +2275,6 @@ def flex_attention_backward(*args, **kwa
             or SPARSE_Q_BLOCK_SIZE % BLOCK2 != 0
         ):
             continue
-        if num_warps == 8:
-            # Working around https://github.com/pytorch/pytorch/issues/141603
-            continue
 
         # Performance tuning
         cur_kernel_options = original_kernel_options.copy()
