1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
|
From: Cordell Bloor <cgmb@debian.org>
Date: Mon, 3 Nov 2025 01:36:27 -0700
Subject: fix asm thread load store
Backport ROCm 7.1 ASM thread load/store behaviours.
Applied-Upstream: 7.1.0
Forwarded: not-needed
---
rocprim/include/rocprim/thread/thread_load.hpp | 38 +++++++------------------
rocprim/include/rocprim/thread/thread_store.hpp | 35 ++++++++---------------
2 files changed, 23 insertions(+), 50 deletions(-)
diff --git a/rocprim/include/rocprim/thread/thread_load.hpp b/rocprim/include/rocprim/thread/thread_load.hpp
index 17348d3..107a692 100644
--- a/rocprim/include/rocprim/thread/thread_load.hpp
+++ b/rocprim/include/rocprim/thread/thread_load.hpp
@@ -90,42 +90,26 @@ ROCPRIM_DEVICE __forceinline__ T AsmThreadLoad(void * ptr)
return *bit_cast<type*>(&retval); \
}
-// TODO Add specialization for custom larger data types
+ // TODO Add specialization for custom larger data types
+ // clang-format off
#define ROCPRIM_ASM_THREAD_LOAD_GROUP(cache_modifier, llvm_cache_modifier, wait_inst, wait_cmd) \
- ROCPRIM_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, int8_t, int16_t, flat_load_sbyte, v, wait_inst, wait_cmd); \
- ROCPRIM_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, int16_t, int16_t, flat_load_sshort, v, wait_inst, wait_cmd); \
- ROCPRIM_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint8_t, uint16_t, flat_load_ubyte, v, wait_inst, wait_cmd); \
- ROCPRIM_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint16_t, uint16_t, flat_load_ushort, v, wait_inst, wait_cmd); \
+ ROCPRIM_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, int8_t, int32_t, flat_load_sbyte, v, wait_inst, wait_cmd); \
+ ROCPRIM_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, int16_t, int32_t, flat_load_sshort, v, wait_inst, wait_cmd); \
+ ROCPRIM_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint8_t, uint32_t, flat_load_ubyte, v, wait_inst, wait_cmd); \
+ ROCPRIM_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint16_t, uint32_t, flat_load_ushort, v, wait_inst, wait_cmd); \
ROCPRIM_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint32_t, uint32_t, flat_load_dword, v, wait_inst, wait_cmd); \
ROCPRIM_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, float, uint32_t, flat_load_dword, v, wait_inst, wait_cmd); \
ROCPRIM_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint64_t, uint64_t, flat_load_dwordx2, v, wait_inst, wait_cmd); \
ROCPRIM_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, double, uint64_t, flat_load_dwordx2, v, wait_inst, wait_cmd);
+ // clang-format on
-#if defined(__gfx940__) || defined(__gfx941__)
-ROCPRIM_ASM_THREAD_LOAD_GROUP(load_ca, "sc0", "s_waitcnt", "");
-ROCPRIM_ASM_THREAD_LOAD_GROUP(load_cg, "sc1", "s_waitcnt", "");
-ROCPRIM_ASM_THREAD_LOAD_GROUP(load_cv, "sc0 sc1", "s_waitcnt", "vmcnt");
-ROCPRIM_ASM_THREAD_LOAD_GROUP(load_volatile, "sc0 sc1", "s_waitcnt", "vmcnt");
-#elif defined(__gfx942__)
-ROCPRIM_ASM_THREAD_LOAD_GROUP(load_ca, "sc0", "s_waitcnt", "");
+ #if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx9_4_generic__)
ROCPRIM_ASM_THREAD_LOAD_GROUP(load_cg, "sc0 nt", "s_waitcnt", "");
-ROCPRIM_ASM_THREAD_LOAD_GROUP(load_cv, "sc0", "s_waitcnt", "vmcnt");
-ROCPRIM_ASM_THREAD_LOAD_GROUP(load_volatile, "sc0", "s_waitcnt", "vmcnt");
-#elif defined(__gfx1200__) || defined(__gfx1201__)
-ROCPRIM_ASM_THREAD_LOAD_GROUP(load_ca, "scope:SCOPE_DEV", "s_wait_loadcnt_dscnt", "");
+ #elif defined(__GFX12__)
ROCPRIM_ASM_THREAD_LOAD_GROUP(load_cg, "th:TH_DEFAULT scope:SCOPE_DEV", "s_wait_loadcnt_dscnt", "");
-ROCPRIM_ASM_THREAD_LOAD_GROUP(load_cv, "th:TH_DEFAULT scope:SCOPE_DEV", "s_wait_loadcnt_dscnt", "");
-ROCPRIM_ASM_THREAD_LOAD_GROUP(load_volatile, "th:TH_DEFAULT scope:SCOPE_DEV", "s_wait_loadcnt_dscnt", "");
-#else
-ROCPRIM_ASM_THREAD_LOAD_GROUP(load_ca, "glc", "s_waitcnt", "");
+ #else
ROCPRIM_ASM_THREAD_LOAD_GROUP(load_cg, "glc slc", "s_waitcnt", "");
-ROCPRIM_ASM_THREAD_LOAD_GROUP(load_cv, "glc", "s_waitcnt", "vmcnt");
-ROCPRIM_ASM_THREAD_LOAD_GROUP(load_volatile, "glc", "s_waitcnt", "vmcnt");
-#endif
-
-// TODO find correct modifiers to match these
-ROCPRIM_ASM_THREAD_LOAD_GROUP(load_ldg, "", "s_waitcnt", "");
-ROCPRIM_ASM_THREAD_LOAD_GROUP(load_cs, "", "s_waitcnt", "");
+ #endif
#endif
diff --git a/rocprim/include/rocprim/thread/thread_store.hpp b/rocprim/include/rocprim/thread/thread_store.hpp
index 7f8ba76..740ec35 100644
--- a/rocprim/include/rocprim/thread/thread_store.hpp
+++ b/rocprim/include/rocprim/thread/thread_store.hpp
@@ -91,8 +91,9 @@ ROCPRIM_DEVICE __forceinline__ void AsmThreadStore(void * ptr, T val)
: "v"(ptr), #output_modifier(temp_val), "I"(0x00)); \
}
-// TODO fix flat_store_ubyte and flat_store_sbyte issues
-// TODO Add specialization for custom larger data types
+ // TODO fix flat_store_ubyte and flat_store_sbyte issues
+ // TODO Add specialization for custom larger data types
+ // clang-format off
#define ROCPRIM_ASM_THREAD_STORE_GROUP(cache_modifier, llvm_cache_modifier, wait_inst, wait_cmd) \
ROCPRIM_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, int8_t, int16_t, flat_store_byte, v, wait_inst, wait_cmd); \
ROCPRIM_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, int16_t, int16_t, flat_store_short, v, wait_inst, wait_cmd); \
@@ -102,30 +103,18 @@ ROCPRIM_DEVICE __forceinline__ void AsmThreadStore(void * ptr, T val)
ROCPRIM_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, float, uint32_t, flat_store_dword, v, wait_inst, wait_cmd); \
ROCPRIM_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint64_t, uint64_t, flat_store_dwordx2, v, wait_inst, wait_cmd); \
ROCPRIM_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, double, uint64_t, flat_store_dwordx2, v, wait_inst, wait_cmd);
+ // clang-format on
-#if defined(__gfx940__) || defined(__gfx941__)
-ROCPRIM_ASM_THREAD_STORE_GROUP(store_wb, "sc0 sc1", "s_waitcnt", ""); // TODO: gfx942 validation
-ROCPRIM_ASM_THREAD_STORE_GROUP(store_cg, "sc0 sc1", "s_waitcnt", "");
-ROCPRIM_ASM_THREAD_STORE_GROUP(store_wt, "sc0 sc1", "s_waitcnt", "vmcnt");
-ROCPRIM_ASM_THREAD_STORE_GROUP(store_volatile, "sc0 sc1", "s_waitcnt", "vmcnt");
-#elif defined(__gfx942__)
-ROCPRIM_ASM_THREAD_STORE_GROUP(store_wb, "sc0", "s_waitcnt", "");
+ #if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx9_4_generic__)
ROCPRIM_ASM_THREAD_STORE_GROUP(store_cg, "sc0 nt", "s_waitcnt", "");
-ROCPRIM_ASM_THREAD_STORE_GROUP(store_wt, "sc0", "s_waitcnt", "vmcnt");
-ROCPRIM_ASM_THREAD_STORE_GROUP(store_volatile, "sc0", "s_waitcnt", "vmcnt");
-#elif defined(__gfx1200__) || defined(__gfx1201__)
-ROCPRIM_ASM_THREAD_STORE_GROUP(store_wb, "scope:SCOPE_DEV", "s_wait_storecnt_dscnt", ""); // TODO: gfx942 validation
-ROCPRIM_ASM_THREAD_STORE_GROUP(store_cg, "th:TH_DEFAULT scope:SCOPE_DEV", "s_wait_storecnt_dscnt", "");
-ROCPRIM_ASM_THREAD_STORE_GROUP(store_wt, "scope:SCOPE_DEV", "s_wait_storecnt_dscnt", "");
-ROCPRIM_ASM_THREAD_STORE_GROUP(store_volatile, "scope:SCOPE_DEV", "s_wait_storecnt_dscnt", "");
-#else
-ROCPRIM_ASM_THREAD_STORE_GROUP(store_wb, "glc", "s_waitcnt", "");
+ #elif defined(__GFX12__)
+ROCPRIM_ASM_THREAD_STORE_GROUP(store_cg,
+ "th:TH_DEFAULT scope:SCOPE_DEV",
+ "s_wait_storecnt_dscnt",
+ "");
+ #else
ROCPRIM_ASM_THREAD_STORE_GROUP(store_cg, "glc slc", "s_waitcnt", "");
-ROCPRIM_ASM_THREAD_STORE_GROUP(store_wt, "glc", "s_waitcnt", "vmcnt");
-ROCPRIM_ASM_THREAD_STORE_GROUP(store_volatile, "glc", "s_waitcnt", "vmcnt");
-#endif
-// TODO find correct modifiers to match these
-ROCPRIM_ASM_THREAD_STORE_GROUP(store_cs, "", "s_waitcnt", "");
+ #endif
#endif
|