container_helper.hpp Source File

container_helper.hpp Source File#

Composable Kernel: container_helper.hpp Source File
tile/core/container/container_helper.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
12
13namespace ck_tile {
14
15template <typename TData, index_t NSize>
16CK_TILE_HOST_DEVICE constexpr auto container_push_back(const array<TData, NSize>& a, const TData& x)
17{
19 static_for<0, NSize, 1>{}([&r, &a](auto i) constexpr { r(i) = a[i]; });
20 r[number<NSize>{}] = x;
21 return r;
22}
23
24template <typename... Ts, typename T>
25CK_TILE_HOST_DEVICE constexpr auto container_push_front(const tuple<Ts...>& a, const T& x)
26{
27 return container_concat(make_tuple(x), a);
28}
29
30template <typename... Ts, typename T>
31CK_TILE_HOST_DEVICE constexpr auto container_push_back(const tuple<Ts...>& a, const T& x)
32{
33 return container_concat(a, make_tuple(x));
34}
35
36// reorder array
37template <typename TData, index_t NSize, index_t... IRs>
38CK_TILE_HOST_DEVICE constexpr auto
40{
41 static_assert(NSize == sizeof...(IRs), "wrong! size not consistent");
42 static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
43 return make_array<remove_cvref_t<TData>>(old_array[IRs]...);
44}
45
46template <typename TData, index_t NSize, index_t... IRs>
47CK_TILE_HOST_DEVICE constexpr auto
49{
51 old_array, typename sequence_map_inverse<decltype(old2new)>::type{});
52}
53
54// reorder array
55template <typename TData, index_t NSize>
56CK_TILE_HOST_DEVICE constexpr auto
58 const map<index_t, index_t>& new2old)
59{
60 array<TData, NSize> new_array;
61
62 for(const auto& [new_pos, old_pos] : new2old)
63 {
64 new_array(new_pos) = old_array[old_pos];
65 }
66
67 return new_array;
68}
69
70template <typename TData, index_t NSize>
71CK_TILE_HOST_DEVICE constexpr auto
73 const map<index_t, index_t>& old2new)
74{
75 array<TData, NSize> new_array;
76
77 for(const auto& [old_pos, new_pos] : old2new)
78 {
79 new_array(new_pos) = old_array[old_pos];
80 }
81
82 return new_array;
83}
84
85// reorder tuple
86template <typename... Ts, index_t... IRs>
88 sequence<IRs...> /*new2old*/)
89{
90 static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent");
91
92 static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
93
94 return make_tuple(old_tuple[number<IRs>{}]...);
95}
96
97template <typename... Ts, index_t... IRs>
99 sequence<IRs...> old2new)
100{
102 old_tuple, typename sequence_map_inverse<decltype(old2new)>::type{});
103}
104
105// reorder sequence
106template <index_t... Is, index_t... IRs>
108 sequence<IRs...> /*new2old*/)
109{
110 static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
111
112 static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
113
115}
116
117template <index_t... Is, index_t... IRs>
119 sequence<IRs...> /* old2new */)
120{
121 static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
122
123 static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
124
125 constexpr auto new2old = typename sequence_map_inverse<sequence<IRs...>>::type{};
126
127 return container_reorder_given_new2old(old_seq, new2old);
128}
129
130#if 0
131// rocm-4.1 compiler would crash for recursive lambda
132template <typename Container,
133 typename Reduce,
134 typename Init,
135 index_t IBegin = 0,
136 index_t IEnd = Container::size(),
137 index_t IStep = 1>
138CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x,
139 Reduce reduce,
140 Init init,
142 number<IEnd> = number<Container::size()>{},
144{
145 static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
146
147 // f is recursive function, fs is a dummy of f
148 // i is index, y_old is current scan, r_old is current reduction
149 auto f = [&](auto fs, auto i, auto r_old) {
150 auto r_new = reduce(x[i], r_old);
151
152 if constexpr(i.value < IEnd - IStep)
153 {
154 // recursively call f/fs
155 return fs(fs, i + number<IStep>{}, r_new);
156 }
157 else
158 {
159 return r_new;
160 }
161 };
162
163 // start recursion
164 return f(f, number<IBegin>{}, init);
165}
166#else
167// i is index, y_old is current scan, r_old is current reduction
168template <typename Container,
169 typename Reduce,
170 typename ROld,
171 index_t I,
172 index_t IEnd,
173 index_t IStep>
175 const Container& x, Reduce reduce, ROld r_old, number<I> i, number<IEnd>, number<IStep>)
176{
177 auto r_new = reduce(x[i], r_old);
178
179 if constexpr(i.value < IEnd - IStep)
180 {
182 x, reduce, r_new, i + number<IStep>{}, number<IEnd>{}, number<IStep>{});
183 }
184 else
185 {
186 return r_new;
187 }
188}
189
190// rocm-4.1 compiler would crash for recursive lambda
191// container reduce with initial value
192template <typename Container,
193 typename Reduce,
194 typename Init,
195 index_t IBegin = 0,
196 index_t IEnd = Container::size(),
197 index_t IStep = 1>
198CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x,
199 Reduce reduce,
200 Init init,
202 number<IEnd> = number<Container::size()>{},
204{
205 static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
206
207 if constexpr(IEnd > IBegin)
208 {
210 x, reduce, init, number<IBegin>{}, number<IEnd>{}, number<IStep>{});
211 }
212 else
213 {
214 return init;
215 }
216}
217#endif
218
219template <typename TData, index_t NSize, typename Reduce>
220CK_TILE_HOST_DEVICE constexpr auto
222{
224
225 TData r = init;
226
227 static_for<NSize - 1, 0, -1>{}([&](auto i) {
228 r = f(r, x[i]);
229 y(i) = r;
230 });
231
232 r = f(r, x[number<0>{}]);
233 y(number<0>{}) = r;
234
235 return y;
236}
237
238template <typename TData, index_t NSize, typename Reduce, typename Init>
239CK_TILE_HOST_DEVICE constexpr auto
241{
242#if 0
244
245 TData r = init;
246
247 static_for<NSize - 1, 0, -1>{}([&](auto i) {
248 y(i) = r;
249 r = f(r, x[i]);
250 });
251
252 y(number<0>{}) = r;
253
254 return y;
255#else
257
258 TData r = init;
259
260 for(index_t i = NSize - 1; i > 0; --i)
261 {
262 y(i) = r;
263 r = f(r, x[i]);
264 }
265
266 y(0) = r;
267
268 return y;
269#endif
270}
271
272template <index_t... Is, typename Reduce, index_t Init>
273CK_TILE_HOST_DEVICE constexpr auto
278
279#if 0
280// rocm4.1 compiler would crash with recursive lambda
281template <typename... Xs, typename Reduce, typename Init>
282CK_TILE_HOST_DEVICE constexpr auto
283container_reverse_exclusive_scan(const tuple<Xs...>& x, Reduce reduce, Init init)
284{
285 constexpr index_t NSize = sizeof...(Xs);
286
287 // f is recursive function, fs is a dummy of f
288 // i is index, y_old is current scan, r_old is current reduction
289 auto f = [&](auto fs, auto i, auto y_old, auto r_old) {
290 auto r_new = reduce(x[i], r_old);
291
292 auto y_new = container_push_front(y_old, r_new);
293
294 if constexpr(i.value > 1)
295 {
296 // recursively call f/fs
297 return fs(fs, i - number<1>{}, y_new, r_new);
298 }
299 else
300 {
301 return y_new;
302 }
303 };
304
305 // start recursion
306 return f(f, number<NSize - 1>{}, make_tuple(init), init);
307}
308#else
309// i is index, y_old is current scan, r_old is current reduction
310template <typename... Xs, typename Reduce, index_t I, typename YOld, typename ROld>
312 const tuple<Xs...>& x, Reduce reduce, number<I> i, YOld y_old, ROld r_old)
313{
314 auto r_new = reduce(x[i], r_old);
315
316 auto y_new = container_push_front(y_old, r_new);
317
318 if constexpr(i.value > 1)
319 {
320 // recursively call f/fs
321 return container_reverse_exclusive_scan_impl(x, reduce, i - number<1>{}, y_new, r_new);
322 }
323 else
324 {
325 return y_new;
326 }
327}
328
329template <typename... Xs, typename Reduce, typename Init>
330CK_TILE_HOST_DEVICE constexpr auto
332{
333 constexpr index_t NSize = sizeof...(Xs);
334
336 x, reduce, number<NSize - 1>{}, make_tuple(init), init);
337}
338#endif
339
340// TODO: update to like container_reverse_exclusive_scan to deal with tuple of Numebr<>
341template <typename... Xs, typename Reduce, typename TData>
342CK_TILE_HOST_DEVICE constexpr auto
344{
345 constexpr index_t NSize = sizeof...(Xs);
346
347 tuple<Xs...> y;
348
349 TData r = init;
350
351 static_for<NSize - 1, 0, -1>{}([&](auto i) {
352 r = f(r, x[i]);
353 y(i) = r;
354 });
355
356 r = f(r, x[number<0>{}]);
357 y(number<0>{}) = r;
358
359 return y;
360}
361
362template <typename X, typename... Ys>
363CK_TILE_HOST_DEVICE constexpr auto container_concat(const X& x, const Ys&... ys)
364{
365 return container_concat(x, container_concat(ys...));
366}
367
368template <typename T, index_t NX, index_t NY>
370{
371 return unpack2(
372 [&](auto&&... zs) { return make_array<T>(std::forward<decltype(zs)>(zs)...); }, ax, ay);
373}
374
375template <typename... X, typename... Y>
377{
378 return unpack2(
379 [&](auto&&... zs) { return make_tuple(std::forward<decltype(zs)>(zs)...); }, tx, ty);
380}
381
382template <typename Container>
383CK_TILE_HOST_DEVICE constexpr auto container_concat(const Container& x)
384{
385 return x;
386}
387
388template <typename T, index_t N, index_t... Is>
390{
391 static_assert(N >= sizeof...(Is), "wrong! size");
392
393 if constexpr(sizeof...(Is) > 0)
394 {
395 return make_array<T>(arr[Is]...);
396 }
397 else
398 {
399 return array<T, 0>{};
400 }
401}
402
403template <typename... Ts, index_t... Is>
405{
406 static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! size");
407
408 if constexpr(sizeof...(Is) > 0)
409 {
410 return make_tuple(tup[number<Is>{}]...);
411 }
412 else
413 {
414 return tuple<>{};
415 }
416}
417
418template <typename T, index_t N, index_t... Is>
419CK_TILE_HOST_DEVICE constexpr void
420set_container_subset(array<T, N>& y, sequence<Is...> picks, const array<T, sizeof...(Is)>& x)
421{
422 static_assert(N >= sizeof...(Is), "wrong! size");
423
424 if constexpr(sizeof...(Is) > 0)
425 {
426 for(index_t i = 0; i < picks.size(); ++i)
427 {
428 y(picks[i]) = x[i];
429 }
430 }
431}
432
433template <typename Y, typename X, index_t... Is>
434CK_TILE_HOST_DEVICE constexpr void set_container_subset(Y& y, sequence<Is...> picks, const X& x)
435{
436 static_assert(Y::size() >= sizeof...(Is) && X::size() == sizeof...(Is), "wrong! size");
437
438 if constexpr(sizeof...(Is) > 0)
439 {
440 static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
441 }
442}
443
444// return the index of first occurance in the sequence.
445// return seq.size(), if not found
446template <index_t... Is>
448{
449 for(auto i = 0; i < seq.size(); i++)
450 {
451 if(seq[i] == value)
452 return i;
453 }
454
455 return seq.size();
456}
457
458template <index_t... Is>
460{
461 using Seq = sequence<Is...>;
462
463 return generate_tuple(
464 [&](auto i) {
465 constexpr index_t tmp = Seq::at(i);
466 return number<tmp>{};
467 },
468 number<Seq::size()>{});
469}
470
// TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes)
// Expands to an immediately-invoked lambda that calls ck_tile::generate_tuple over
// ck_tile::number<a_size>. Element i of the result is
// TO_SEQUENCE(a_of_b_impl[i], bs_sizes[i]); TO_SEQUENCE is not defined in this file.
#if 0
// Disabled variant: identical, except that the outer lambda also captures a_size.
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes)             \
    [a_of_b_impl, a_size, bs_sizes] {                                   \
        return ck_tile::generate_tuple(                                 \
            [=](auto i) {                                               \
                constexpr auto b_impl    = a_of_b_impl[i];              \
                constexpr index_t b_size = bs_sizes[i];                 \
                constexpr auto b         = TO_SEQUENCE(b_impl, b_size); \
                return b;                                               \
            },                                                          \
            ck_tile::number<a_size>{});                                 \
    }()
#else
// constexpr index_t can't be captured "-Wunused-lambda-capture"
// TODO: this is ugly
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes)             \
    [a_of_b_impl, bs_sizes] {                                           \
        return ck_tile::generate_tuple(                                 \
            [=](auto i) {                                               \
                constexpr auto b_impl    = a_of_b_impl[i];              \
                constexpr index_t b_size = bs_sizes[i];                 \
                constexpr auto b         = TO_SEQUENCE(b_impl, b_size); \
                return b;                                               \
            },                                                          \
            ck_tile::number<a_size>{});                                 \
    }()
#endif
498
499} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr auto unpack2(F &&f, X &&x, Y &&y)
Definition tile/core/utility/functional.hpp:209
CK_TILE_HOST_DEVICE constexpr auto container_reverse_exclusive_scan_impl(const tuple< Xs... > &x, Reduce reduce, number< I > i, YOld y_old, ROld r_old)
Definition tile/core/container/container_helper.hpp:311
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(const array< TData, NSize > &old_array, sequence< IRs... >)
Definition tile/core/container/container_helper.hpp:39
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container &x, Reduce reduce, Init init, number< IBegin >=number< 0 >{}, number< IEnd >=number< Container::size()>{}, number< IStep >=number< 1 >{})
Definition tile/core/container/container_helper.hpp:198
CK_TILE_HOST_DEVICE constexpr auto container_reduce_impl(const Container &x, Reduce reduce, ROld r_old, number< I > i, number< IEnd >, number< IStep >)
Definition tile/core/container/container_helper.hpp:174
CK_TILE_HOST_DEVICE constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce, number< Init >)
Definition tile/core/container/sequence.hpp:863
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(const array< TData, NSize > &old_array, sequence< IRs... > old2new)
Definition tile/core/container/container_helper.hpp:48
CK_TILE_HOST_DEVICE constexpr auto container_push_front(const tuple< Ts... > &a, const T &x)
Definition tile/core/container/container_helper.hpp:25
CK_TILE_HOST_DEVICE constexpr auto container_concat(const X &x, const Ys &... ys)
Definition tile/core/container/container_helper.hpp:363
CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence< Is... >)
Definition tile/core/container/container_helper.hpp:459
CK_TILE_HOST_DEVICE constexpr void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition tile/core/container/container_helper.hpp:420
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto container_reverse_inclusive_scan(const array< TData, NSize > &x, Reduce f, TData init)
Definition tile/core/container/container_helper.hpp:221
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition tile/core/container/container_helper.hpp:389
constexpr index_t container_find(sequence< Is... > seq, index_t value)
Definition tile/core/container/container_helper.hpp:447
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto container_push_back(const array< TData, NSize > &a, const TData &x)
Definition tile/core/container/container_helper.hpp:16
CK_TILE_HOST_DEVICE constexpr details::return_type< D, Ts... > make_array(Ts &&... ts)
Definition tile/core/container/array.hpp:242
CK_TILE_HOST_DEVICE constexpr auto container_reverse_exclusive_scan(const array< TData, NSize > &x, Reduce f, Init init)
Definition tile/core/container/container_helper.hpp:240
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition reduce2d_kernel.hpp:20
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
static constexpr value_type value
Definition tile/core/numeric/integral_constant.hpp:16
Definition tile/core/container/sequence.hpp:670
Definition map.hpp:16
Definition tile/core/container/sequence.hpp:675
Definition tile/core/container/sequence.hpp:49
static CK_TILE_HOST_DEVICE constexpr auto at()
Definition tile/core/container/sequence.hpp:78
static CK_TILE_HOST_DEVICE constexpr index_t size()
Definition tile/core/container/sequence.hpp:53
Definition tile/core/utility/functional.hpp:43
Definition tile/core/container/tuple.hpp:192