HIP: Heterogeneous-computing Interface for Portability
Loading...
Searching...
No Matches
amd_warp_sync_functions.h
1/*
2Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
3
4Permission is hereby granted, free of charge, to any person obtaining a copy
5of this software and associated documentation files (the "Software"), to deal
6in the Software without restriction, including without limitation the rights
7to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8copies of the Software, and to permit persons to whom the Software is
9furnished to do so, subject to the following conditions:
10
11The above copyright notice and this permission notice shall be included in
12all copies or substantial portions of the Software.
13
14THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20THE SOFTWARE.
21*/
22
23#pragma once
24
25// Warp sync builtins (with explicit mask argument) introduced in ROCm 6.2 as a
26// preview to allow end-users to adapt to the new interface involving 64-bit
27// masks. These are disabled by default, and can be enabled by setting the macro
28// "HIP_ENABLE_WARP_SYNC_BUILTINS". This arrangement also applies to the
29// __activemask() builtin defined in amd_warp_functions.h.
30#ifdef HIP_ENABLE_WARP_SYNC_BUILTINS
31
32#if !defined(__HIPCC_RTC__)
33#include "amd_warp_functions.h"
34#include "hip_assert.h"
35#endif
36
37template <typename T>
38__device__ inline
39T __hip_readfirstlane(T val) {
40 // In theory, behaviour is undefined when reading from a union member other
41 // than the member that was last assigned to, but it works in practice because
42 // we rely on the compiler to do the reasonable thing.
43 union {
44 unsigned long long l;
45 T d;
46 } u;
47 u.d = val;
48 // NOTE: The builtin returns int, so we first cast it to unsigned int and only
49 // then extend it to 64 bits.
50 unsigned long long lower = (unsigned)__builtin_amdgcn_readfirstlane(u.l);
51 unsigned long long upper =
52 (unsigned)__builtin_amdgcn_readfirstlane(u.l >> 32);
53 u.l = (upper << 32) | lower;
54 return u.d;
55}
56
// When warpSize is 32, clear the upper half of the 64-bit mask in place.
// MASK must be a modifiable lvalue; it is parenthesized so that any lvalue
// expression can be passed safely.
#define __hip_adjust_mask_for_wave32(MASK)                                     \
  do {                                                                         \
    if (warpSize == 32) {                                                      \
      (MASK) &= 0xFFFFFFFF;                                                    \
    }                                                                          \
  } while (0)
62
63// We use a macro to expand each builtin into a waterfall that implements the
64// mask semantics:
65//
66// 1. The mask argument may be divergent.
67// 2. Each active thread must have its own bit set in its own mask value.
68// 3. For a given mask value, all threads that are mentioned in the mask must
69// execute the same static instance of the builtin with the same mask.
70// 4. The union of all mask values supplied at a static instance must be equal
71// to the activemask at the program point.
72//
73// Thus, the mask argument partitions the set of currently active threads in the
74// wave into disjoint subsets that cover all active threads.
75//
76// Implementation notes:
77// ---------------------
78//
79// We implement this as a waterfall loop that executes the builtin for each
80// subset separately. The return value is a divergent value across the active
81// threads. The value for inactive threads is defined by each builtin
82// separately.
83//
84// As long as every mask value is non-zero, we don't need to check if a lane
85// specifies itself in the mask; that is done by the later assertion where all
86// chosen lanes must be in the chosen mask.
87
// Waterfall loop that only validates MASK; it produces no value.
//
// Each iteration takes the mask reported by __hip_readfirstlane among the lanes
// that are not yet done. The lanes holding that same mask assert that it equals
// __ballot(true) as seen inside the branch, then retire.
//
// MASK is evaluated exactly once into a local, and all locals carry a __hip_
// prefix so they cannot capture identifiers used in the caller's expression.
#define __hip_check_mask(MASK)                                                 \
  do {                                                                         \
    const auto __hip_mask = (MASK);                                            \
    __hip_assert(__hip_mask && "mask must be non-zero");                       \
    bool __hip_done = false;                                                   \
    while (__any(!__hip_done)) {                                               \
      if (!__hip_done) {                                                       \
        auto __hip_chosen = __hip_readfirstlane(__hip_mask);                   \
        if (__hip_mask == __hip_chosen) {                                      \
          __hip_assert(__hip_mask == __ballot(true) &&                         \
                       "all threads specified in the mask"                     \
                       " must execute the same operation with the same mask"); \
          __hip_done = true;                                                   \
        }                                                                      \
      }                                                                        \
    }                                                                          \
  } while (0)
104
// Waterfall loop that validates MASK and runs FUNC once per group of lanes
// sharing the same mask value.
//
// Inside the branch taken by the lanes whose mask equals the one reported by
// __hip_readfirstlane, the mask is asserted to equal __ballot(true), and
// FUNC(__VA_ARGS__) is assigned to RETVAL. RETVAL must be an assignable lvalue.
//
// MASK is evaluated exactly once into a local, and all locals carry a __hip_
// prefix so they cannot capture identifiers appearing in RETVAL or in the
// forwarded arguments.
#define __hip_do_sync(RETVAL, FUNC, MASK, ...)                                 \
  do {                                                                         \
    const auto __hip_mask = (MASK);                                            \
    __hip_assert(__hip_mask && "mask must be non-zero");                       \
    bool __hip_done = false;                                                   \
    while (__any(!__hip_done)) {                                               \
      if (!__hip_done) {                                                       \
        auto __hip_chosen = __hip_readfirstlane(__hip_mask);                   \
        if (__hip_mask == __hip_chosen) {                                      \
          __hip_assert(__hip_mask == __ballot(true) &&                         \
                       "all threads specified in the mask"                     \
                       " must execute the same operation with the same mask"); \
          (RETVAL) = FUNC(__VA_ARGS__);                                        \
          __hip_done = true;                                                   \
        }                                                                      \
      }                                                                        \
    }                                                                          \
  } while (0)
122
123// __all_sync, __any_sync, __ballot_sync
124
125template <typename MaskT>
126__device__ inline
127unsigned long long __ballot_sync(MaskT mask, int predicate) {
128 static_assert(
129 __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
130 "The mask must be a 64-bit integer. "
131 "Implicitly promoting a smaller integer is almost always an error.");
132 __hip_adjust_mask_for_wave32(mask);
133 __hip_check_mask(mask);
134 return __ballot(predicate) & mask;
135}
136
137template <typename MaskT>
138__device__ inline
139int __all_sync(MaskT mask, int predicate) {
140 static_assert(
141 __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
142 "The mask must be a 64-bit integer. "
143 "Implicitly promoting a smaller integer is almost always an error.");
144 __hip_adjust_mask_for_wave32(mask);
145 return __ballot_sync(mask, predicate) == mask;
146}
147
148template <typename MaskT>
149__device__ inline
150int __any_sync(MaskT mask, int predicate) {
151 static_assert(
152 __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
153 "The mask must be a 64-bit integer. "
154 "Implicitly promoting a smaller integer is almost always an error.");
155 __hip_adjust_mask_for_wave32(mask);
156 return __ballot_sync(mask, predicate) != 0;
157}
158
159// __match_any, __match_all and sync variants
160
161template <typename T>
162__device__ inline
163unsigned long long __match_any(T value) {
164 static_assert(
165 (__hip_internal::is_integral<T>::value || __hip_internal::is_floating_point<T>::value) &&
166 (sizeof(T) == 4 || sizeof(T) == 8),
167 "T can be int, unsigned int, long, unsigned long, long long, unsigned "
168 "long long, float or double.");
169 bool done = false;
170 unsigned long long retval = 0;
171
172 while (__any(!done)) {
173 if (!done) {
174 T chosen = __hip_readfirstlane(value);
175 if (chosen == value) {
176 retval = __activemask();
177 done = true;
178 }
179 }
180 }
181
182 return retval;
183}
184
185template <typename MaskT, typename T>
186__device__ inline
187unsigned long long __match_any_sync(MaskT mask, T value) {
188 static_assert(
189 __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
190 "The mask must be a 64-bit integer. "
191 "Implicitly promoting a smaller integer is almost always an error.");
192 __hip_adjust_mask_for_wave32(mask);
193 __hip_check_mask(mask);
194 return __match_any(value) & mask;
195}
196
197template <typename T>
198__device__ inline
199unsigned long long __match_all(T value, int* pred) {
200 static_assert(
201 (__hip_internal::is_integral<T>::value || __hip_internal::is_floating_point<T>::value) &&
202 (sizeof(T) == 4 || sizeof(T) == 8),
203 "T can be int, unsigned int, long, unsigned long, long long, unsigned "
204 "long long, float or double.");
205 T first = __hip_readfirstlane(value);
206 if (__all(first == value)) {
207 *pred = true;
208 return __activemask();
209 } else {
210 *pred = false;
211 return 0;
212 }
213}
214
215template <typename MaskT, typename T>
216__device__ inline
217unsigned long long __match_all_sync(MaskT mask, T value, int* pred) {
218 static_assert(
219 __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
220 "The mask must be a 64-bit integer. "
221 "Implicitly promoting a smaller integer is almost always an error.");
222 MaskT retval = 0;
223 __hip_adjust_mask_for_wave32(mask);
224 __hip_do_sync(retval, __match_all, mask, value, pred);
225 return retval;
226}
227
228// various variants of shfl
229
230template <typename MaskT, typename T>
231__device__ inline
232T __shfl_sync(MaskT mask, T var, int srcLane,
233 int width = __AMDGCN_WAVEFRONT_SIZE) {
234 static_assert(
235 __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
236 "The mask must be a 64-bit integer. "
237 "Implicitly promoting a smaller integer is almost always an error.");
238 __hip_adjust_mask_for_wave32(mask);
239 __hip_check_mask(mask);
240 return __shfl(var, srcLane, width);
241}
242
243template <typename MaskT, typename T>
244__device__ inline
245T __shfl_up_sync(MaskT mask, T var, unsigned int delta,
246 int width = __AMDGCN_WAVEFRONT_SIZE) {
247 static_assert(
248 __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
249 "The mask must be a 64-bit integer. "
250 "Implicitly promoting a smaller integer is almost always an error.");
251 __hip_adjust_mask_for_wave32(mask);
252 __hip_check_mask(mask);
253 return __shfl_up(var, delta, width);
254}
255
256template <typename MaskT, typename T>
257__device__ inline
258T __shfl_down_sync(MaskT mask, T var, unsigned int delta,
259 int width = __AMDGCN_WAVEFRONT_SIZE) {
260 static_assert(
261 __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
262 "The mask must be a 64-bit integer. "
263 "Implicitly promoting a smaller integer is almost always an error.");
264 __hip_adjust_mask_for_wave32(mask);
265 __hip_check_mask(mask);
266 return __shfl_down(var, delta, width);
267}
268
269template <typename MaskT, typename T>
270__device__ inline
271T __shfl_xor_sync(MaskT mask, T var, int laneMask,
272 int width = __AMDGCN_WAVEFRONT_SIZE) {
273 static_assert(
274 __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
275 "The mask must be a 64-bit integer. "
276 "Implicitly promoting a smaller integer is almost always an error.");
277 __hip_adjust_mask_for_wave32(mask);
278 __hip_check_mask(mask);
279 return __shfl_xor(var, laneMask, width);
280}
281
// The helper macros are private to this header; remove them (in definition
// order) so they do not leak into code that includes it.
#undef __hip_adjust_mask_for_wave32
#undef __hip_check_mask
#undef __hip_do_sync
285
286#endif // HIP_ENABLE_WARP_SYNC_BUILTINS