gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp Source File

gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp Source File#

Composable Kernel: gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp Source File
gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
12
13namespace ck {
14
15template <typename GridwiseWelfordSecondHalfReduceFirstHalf_,
16 typename XDataType,
17 typename DyDataType,
18 typename AccDataType,
19 typename ScaleDataType,
20 typename DscaleDbiasDataType,
21 typename MeanVarDataType,
22 typename DyElementwiseOp,
23 typename XYGridDesc_M_K,
24 typename MeanVarGridDesc_M,
25 typename MeanVarCountGridDesc_M_K,
26 typename DscaleDbiasGridDesc_M_G>
28 const XYGridDesc_M_K x_grid_desc_m_k,
29 const XYGridDesc_M_K dy_grid_desc_m_k,
30 const MeanVarGridDesc_M mean_var_grid_desc_m,
31 const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k,
32 const DscaleDbiasGridDesc_M_G dscale_dbias_grid_desc_m_g,
33 index_t blkgroup_size,
34 index_t num_xy_k_block_tile_iteration,
35 index_t num_mean_var_count_k_block_tile_iteration,
36 AccDataType epsilon,
37 bool haveSavedMeanInvVar,
38 const MeanVarDataType* const __restrict__ p_savedMean,
39 const MeanVarDataType* const __restrict__ p_savedInvVar,
40 const MeanVarDataType* const __restrict__ p_in_welford_mean,
41 const MeanVarDataType* const __restrict__ p_in_welford_variance,
42 const int32_t* const __restrict__ p_in_welford_count,
43 const DyElementwiseOp dy_elementwise_op,
44 MeanVarDataType* const __restrict__ p_out_welford_mean,
45 MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
46 const XDataType* const __restrict__ p_x,
47 const DyDataType* const __restrict__ p_dy,
48 DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
49 DscaleDbiasDataType* const __restrict__ p_reduce_dbias)
50{
51 GridwiseWelfordSecondHalfReduceFirstHalf_::Run(x_grid_desc_m_k,
52 dy_grid_desc_m_k,
53 mean_var_grid_desc_m,
54 mean_var_count_grid_desc_m_k,
55 dscale_dbias_grid_desc_m_g,
56 blkgroup_size,
57 num_xy_k_block_tile_iteration,
58 num_mean_var_count_k_block_tile_iteration,
59 epsilon,
60 haveSavedMeanInvVar,
61 p_savedMean,
62 p_savedInvVar,
63 p_in_welford_mean,
64 p_in_welford_variance,
65 p_in_welford_count,
66 dy_elementwise_op,
67 p_out_welford_mean,
68 p_out_welford_inv_variance,
69 p_x,
70 p_dy,
71 p_reduce_dscale,
72 p_reduce_dbias);
73};
74
75template <typename XDataType,
76 typename DyDataType,
77 typename AccDataType,
78 typename ScaleDataType,
79 typename DscaleDbiasDataType,
80 typename MeanVarDataType,
81 typename DyElementwiseOp,
82 typename XYGridDesc_M_K,
83 typename MeanVarGridDesc_M,
84 typename MeanVarCountGridDesc_M_K,
85 typename DscaleDbiasGridDesc_M_G,
86 index_t BlockSize,
87 index_t MThreadClusterSize,
88 index_t KThreadClusterSize,
89 index_t MThreadSliceSize,
90 index_t KThreadSliceSize,
91 index_t XDyVectorDim,
92 index_t XSrcVectorSize,
93 index_t DySrcVectorSize,
94 index_t MeanVarSrcVectorSize>
96{
97 static_assert((XDyVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 &&
98 MThreadSliceSize % DySrcVectorSize == 0) ||
99 (XDyVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 &&
100 KThreadSliceSize % DySrcVectorSize == 0),
101 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
102
103 static constexpr bool reorder_thread_cluster = (XDyVectorDim == 0);
104
106
109
112
113 static constexpr auto thread_cluster_desc =
115
122
125
127 BlockSize,
130
132 BlockSize,
136 false>;
137
142 false>;
143
145
146 static constexpr auto I0 = Number<0>{};
147 static constexpr auto I1 = Number<1>{};
148
149 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
150 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
151
152 // clang-format off
153 // Two of the steps of Multiblock BatchNorm Backward
154 // Step 1: Second half of the Welford method: merge the partial mean/variance/count results into the final mean and variance, then compute inv-variance = 1/sqrt(epsilon+variance)
155 // Step 2: First half of Reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
156 // clang-format on
157 __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
158 const XYGridDesc_M_K& dy_grid_desc_m_k,
159 const MeanVarGridDesc_M& mean_var_grid_desc_m,
160 const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k,
161 const DscaleDbiasGridDesc_M_G& dscale_dbias_grid_desc_m_g,
162 index_t blkgroup_size,
163 index_t num_xy_k_block_tile_iteration,
164 index_t num_mean_var_count_k_block_tile_iteration,
165 AccDataType epsilon,
166 bool haveSavedMeanInvVar,
167 const MeanVarDataType* const __restrict__ p_savedMean,
168 const MeanVarDataType* const __restrict__ p_savedInvVar,
169 const MeanVarDataType* const __restrict__ p_in_welford_mean,
170 const MeanVarDataType* const __restrict__ p_in_welford_variance,
171 const int32_t* const __restrict__ p_in_welford_count,
172 const DyElementwiseOp dy_elementwise_op,
173 MeanVarDataType* const __restrict__ p_out_welford_mean,
174 MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
175 const XDataType* const __restrict__ p_x,
176 const DyDataType* const __restrict__ p_dy,
177 DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
178 DscaleDbiasDataType* const __restrict__ p_reduce_dbias)
179 {
180 __shared__ AccDataType p_reduce_work_buffer[BlockSize];
181
182 auto reduce_work_buf =
183 make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
184
186 in_welford_mean_thread_buf;
188 in_welford_var_thread_buf;
190 in_welford_count_thread_buf;
191
193 welford_mean_thread_buf;
195 welford_var_thread_buf;
197 welford_count_thread_buf;
198
200 welford_mean_thread_buf;
202 inv_var_thread_buf = welford_var_thread_buf;
203
205 x_thread_buf;
207 dy_thread_buf;
208
209 // buffer of values of dy * (x-mean) * inv-variance, used as input of the threadwise reduction that accumulates dscale (the blockwise reduction then combines the per-thread results)
211 tmp1_thread_buf;
212
214 reduce_dscale_thread_buf;
216 reduce_dbias_thread_buf;
217
218 const index_t thread_local_id = get_thread_local_1d_id();
219 const index_t block_global_id = get_block_1d_id();
220 const index_t blkgroup_id = block_global_id / blkgroup_size;
221 const index_t block_local_id = block_global_id % blkgroup_size;
222
223 const auto thread_cluster_idx =
224 thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
225
226 const auto thread_m_cluster_id = thread_cluster_idx[I0];
227 const auto thread_k_cluster_id = thread_cluster_idx[I1];
228
229 using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
230 using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
231 using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
232 constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
234 constexpr auto thread_buffer_desc_m =
236 constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
238
239 // clang-format off
240 // Step 1: load the saved mean and inv-variance if they are available; otherwise do the final welford reduction on mean and variance and compute inv-variance = 1/sqrt(epsilon+variance)
241 // clang-format on
242
243 if(haveSavedMeanInvVar)
244 {
245 const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
246 p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize());
247
248 const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
249 p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize());
250
251 auto threadwise_mean_inv_var_load =
252 ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
253 AccDataType,
254 MeanVarGridDesc_M,
255 decltype(thread_buffer_desc_m),
256 ThreadBufferLengths_M,
258 0,
259 MeanVarSrcVectorSize,
260 1,
261 true>(
262 mean_var_grid_desc_m,
263 make_multi_index(blkgroup_id * M_BlockTileSize +
264 thread_m_cluster_id * MThreadSliceSize));
265
266 threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
267 mean_global_buf,
268 thread_buffer_desc_m,
269 make_tuple(I0),
270 mean_thread_buf);
271
272 threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
273 inv_var_global_buf,
274 thread_buffer_desc_m,
275 make_tuple(I0),
276 inv_var_thread_buf);
277 }
278 else
279 {
280 const auto welford_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
281 p_in_welford_mean, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
282
283 const auto welford_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
284 p_in_welford_variance, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
285
286 const auto welford_count_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
287 p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
288
289 auto threadwise_mean_var_load_m_k =
291 AccDataType,
292 MeanVarCountGridDesc_M_K,
293 decltype(thread_buffer_desc_m_1),
294 ThreadBufferLengths_M_1,
296 1,
297 1,
298 1,
299 true>(
300 mean_var_count_grid_desc_m_k,
301 make_multi_index(blkgroup_id * M_BlockTileSize +
302 thread_m_cluster_id * MThreadSliceSize,
303 thread_k_cluster_id * 1));
304
305 auto threadwise_count_load_m_k =
307 int32_t,
308 MeanVarCountGridDesc_M_K,
309 decltype(thread_buffer_desc_m_1),
310 ThreadBufferLengths_M_1,
312 1,
313 1,
314 1,
315 true>(
316 mean_var_count_grid_desc_m_k,
317 make_multi_index(blkgroup_id * M_BlockTileSize +
318 thread_m_cluster_id * MThreadSliceSize,
319 thread_k_cluster_id * 1));
320
321 constexpr auto mean_var_count_thread_copy_step_m_k =
322 make_multi_index(0, KThreadClusterSize * 1);
323
325 welford_mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
326 welford_var_thread_buf(I) = type_convert<AccDataType>(0.0f);
327 welford_count_thread_buf(I) = 0;
328 });
329
330 for(index_t reducedTiles = 0; reducedTiles < num_mean_var_count_k_block_tile_iteration;
331 ++reducedTiles)
332 {
333 threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
334 welford_mean_global_buf,
335 thread_buffer_desc_m_1,
336 make_tuple(I0, I0),
337 in_welford_mean_thread_buf);
338
339 threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
340 welford_var_global_buf,
341 thread_buffer_desc_m_1,
342 make_tuple(I0, I0),
343 in_welford_var_thread_buf);
344
345 threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k,
346 welford_count_global_buf,
347 thread_buffer_desc_m_1,
348 make_tuple(I0, I0),
349 in_welford_count_thread_buf);
350
351 ThreadwiseWelford::Run(in_welford_mean_thread_buf,
352 in_welford_var_thread_buf,
353 in_welford_count_thread_buf,
354 welford_mean_thread_buf,
355 welford_var_thread_buf,
356 welford_count_thread_buf);
357
358 threadwise_mean_var_load_m_k.MoveSrcSliceWindow(
359 mean_var_count_grid_desc_m_k, mean_var_count_thread_copy_step_m_k);
360 threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
361 mean_var_count_thread_copy_step_m_k);
362 }
363
365 if constexpr(I > 0)
367
368 BlockwiseWelford::Run(welford_mean_thread_buf(I),
369 welford_var_thread_buf(I),
370 welford_count_thread_buf(I));
371 });
372
373 // calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance
375 welford_var_thread_buf(I) =
376 type_convert<AccDataType>(1.0) / sqrt(welford_var_thread_buf[I] + epsilon);
377 });
378
379 if(block_local_id == 0 && thread_k_cluster_id == 0)
380 {
381
382 auto threadwise_mean_inv_var_store =
384 MeanVarDataType,
385 decltype(thread_buffer_desc_m),
386 MeanVarGridDesc_M,
388 ThreadBufferLengths_M,
390 0,
391 1,
393 1,
394 true>(
395 mean_var_grid_desc_m,
396 make_multi_index(blkgroup_id * M_BlockTileSize +
397 thread_m_cluster_id * MThreadSliceSize),
398 PassThroughOp{});
399
401 p_out_welford_mean, mean_var_grid_desc_m.GetElementSpaceSize());
402
403 auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
404 p_out_welford_inv_variance, mean_var_grid_desc_m.GetElementSpaceSize());
405
406 threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
407 make_tuple(I0),
408 mean_thread_buf,
409 mean_var_grid_desc_m,
410 mean_global_buf);
411
412 threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
413 make_tuple(I0),
414 inv_var_thread_buf,
415 mean_var_grid_desc_m,
416 inv_var_global_buf);
417 };
418 };
419
420 const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration;
421
422 auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
423 AccDataType,
424 XYGridDesc_M_K,
425 decltype(thread_buffer_desc_m_k),
426 ThreadBufferLengths_M_K,
428 XDyVectorDim,
429 XSrcVectorSize,
430 1,
431 true>(
432 x_grid_desc_m_k,
433 make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
434 workSizePerBlock * block_local_id +
435 thread_k_cluster_id * KThreadSliceSize));
436
437 auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<DyDataType,
438 AccDataType,
439 XYGridDesc_M_K,
440 decltype(thread_buffer_desc_m_k),
441 ThreadBufferLengths_M_K,
443 XDyVectorDim,
444 DySrcVectorSize,
445 1,
446 true>(
447 dy_grid_desc_m_k,
448 make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
449 workSizePerBlock * block_local_id +
450 thread_k_cluster_id * KThreadSliceSize));
451
452 const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
453 p_x, x_grid_desc_m_k.GetElementSpaceSize());
454
455 const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
456 p_dy, dy_grid_desc_m_k.GetElementSpaceSize());
457
458 constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize);
459
461 reduce_dscale_thread_buf(I) = type_convert<AccDataType>(0);
462 reduce_dbias_thread_buf(I) = type_convert<AccDataType>(0);
463 });
464
465 // clang-format off
466 // Step 2: first-half of reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance), where dy is first transformed in place by dy_elementwise_op
467 // clang-format on
468
469 for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles)
470 {
471 threadwise_x_load.Run(x_grid_desc_m_k,
472 x_global_buf,
473 thread_buffer_desc_m_k,
474 make_tuple(I0, I0),
475 x_thread_buf);
476
477 threadwise_dy_load.Run(dy_grid_desc_m_k,
478 dy_global_buf,
479 thread_buffer_desc_m_k,
480 make_tuple(I0, I0),
481 dy_thread_buf);
482
485 constexpr auto offset =
486 thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
487
488 dy_elementwise_op(dy_thread_buf(Number<offset>{}),
489 dy_thread_buf[Number<offset>{}]);
490
491 AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
492 inv_var_thread_buf[iM];
493
494 tmp1_thread_buf(Number<offset>{}) = norm_x * dy_thread_buf[Number<offset>{}];
495 });
496 });
497
498 ThreadwiseReduce::Reduce(tmp1_thread_buf, reduce_dscale_thread_buf);
499 ThreadwiseReduce::Reduce(dy_thread_buf, reduce_dbias_thread_buf);
500
501 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
502 threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k);
503 };
504
506 if constexpr(I > 0)
508
509 BlockwiseReduce::Reduce(reduce_work_buf, reduce_dscale_thread_buf(I));
511 BlockwiseReduce::Reduce(reduce_work_buf, reduce_dbias_thread_buf(I));
512 });
513
514 auto threadwise_dscale_dbias_store =
516 DscaleDbiasDataType,
517 decltype(thread_buffer_desc_m_1),
518 DscaleDbiasGridDesc_M_G,
520 ThreadBufferLengths_M_1,
522 1,
523 1,
525 1,
526 true>(
527 dscale_dbias_grid_desc_m_g,
528 make_multi_index(blkgroup_id * M_BlockTileSize +
529 thread_m_cluster_id * MThreadSliceSize,
530 block_local_id),
531 PassThroughOp{});
532
533 auto reduce_dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
534 p_reduce_dscale, dscale_dbias_grid_desc_m_g.GetElementSpaceSize());
535
536 auto reduce_dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
537 p_reduce_dbias, dscale_dbias_grid_desc_m_g.GetElementSpaceSize());
538
539 if(thread_k_cluster_id == 0)
540 {
541 threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1,
542 make_tuple(I0, I0),
543 reduce_dscale_thread_buf,
544 dscale_dbias_grid_desc_m_g,
545 reduce_dscale_global_buf);
546
547 threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1,
548 make_tuple(I0, I0),
549 reduce_dbias_thread_buf,
550 dscale_dbias_grid_desc_m_g,
551 reduce_dbias_global_buf);
552 };
553 };
554};
555
556} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
__global__ void kernel_welford_second_half_reduce_first_half(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k, const MeanVarGridDesc_M mean_var_grid_desc_m, const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k, const DscaleDbiasGridDesc_M_G dscale_dbias_grid_desc_m_g, index_t blkgroup_size, index_t num_xy_k_block_tile_iteration, index_t num_mean_var_count_k_block_tile_iteration, AccDataType epsilon, bool haveSavedMeanInvVar, const MeanVarDataType *const __restrict__ p_savedMean, const MeanVarDataType *const __restrict__ p_savedInvVar, const MeanVarDataType *const __restrict__ p_in_welford_mean, const MeanVarDataType *const __restrict__ p_in_welford_variance, const int32_t *const __restrict__ p_in_welford_count, const DyElementwiseOp dy_elementwise_op, MeanVarDataType *const __restrict__ p_out_welford_mean, MeanVarDataType *const __restrict__ p_out_welford_inv_variance, const XDataType *const __restrict__ p_x, const DyDataType *const __restrict__ p_dy, DscaleDbiasDataType *const __restrict__ p_reduce_dscale, DscaleDbiasDataType *const __restrict__ p_reduce_dbias)
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:27
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
signed int int32_t
Definition stdint.h:123
static __device__ void Run(T &mean_value, T &var_value, CountDataType &count)
Definition blockwise_welford.hpp:51
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:96
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadBufferDimAccessOrder
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:107
Sequence< MThreadClusterSize, KThreadClusterSize > ThreadClusterLengths_M_K
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:105
static constexpr index_t K_BlockTileSize
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:150
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}))) ThreadReduceDstDesc_M
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:120
ThreadwiseWelfordMerge< AccDataType, ThreadReduceSrcDesc_M_1, ThreadReduceDstDesc_M > ThreadwiseWelford
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:123
decltype(make_naive_tensor_descriptor_packed( make_tuple(Number< MThreadSliceSize >{}, Number< 1 >{}))) ThreadReduceSrcDesc_M_1
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:118
static constexpr auto I0
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:146
static __device__ void Run(const XYGridDesc_M_K &x_grid_desc_m_k, const XYGridDesc_M_K &dy_grid_desc_m_k, const MeanVarGridDesc_M &mean_var_grid_desc_m, const MeanVarCountGridDesc_M_K &mean_var_count_grid_desc_m_k, const DscaleDbiasGridDesc_M_G &dscale_dbias_grid_desc_m_g, index_t blkgroup_size, index_t num_xy_k_block_tile_iteration, index_t num_mean_var_count_k_block_tile_iteration, AccDataType epsilon, bool haveSavedMeanInvVar, const MeanVarDataType *const __restrict__ p_savedMean, const MeanVarDataType *const __restrict__ p_savedInvVar, const MeanVarDataType *const __restrict__ p_in_welford_mean, const MeanVarDataType *const __restrict__ p_in_welford_variance, const int32_t *const __restrict__ p_in_welford_count, const DyElementwiseOp dy_elementwise_op, MeanVarDataType *const __restrict__ p_out_welford_mean, MeanVarDataType *const __restrict__ p_out_welford_inv_variance, const XDataType *const __restrict__ p_x, const DyDataType *const __restrict__ p_dy, DscaleDbiasDataType *const __restrict__ p_reduce_dscale, DscaleDbiasDataType *const __restrict__ p_reduce_dbias)
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:157
static constexpr bool reorder_thread_cluster
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:103
tensor_operation::element_wise::PassThrough PassThroughOp
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:144
static constexpr auto I1
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:147
ThreadwiseReduction< AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M, ck::reduce::Add, false > ThreadwiseReduce
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:138
static constexpr auto thread_cluster_desc
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:113
BlockwiseWelford< AccDataType, BlockSize, ThreadClusterLengths_M_K, ThreadClusterArrangeOrder > BlockwiseWelford
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:126
decltype(make_naive_tensor_descriptor_packed( make_tuple(Number< MThreadSliceSize >{}, Number< KThreadSliceSize >{}))) ThreadReduceSrcDesc_M_K
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:116
static constexpr index_t M_BlockTileSize
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:149
PartitionedBlockwiseReduction< AccDataType, BlockSize, ThreadClusterLengths_M_K, ThreadClusterArrangeOrder, ck::reduce::Add, false > BlockwiseReduce
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:131
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadClusterArrangeOrder
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:110
Definition reduction_functions_blockwise.hpp:28
static __device__ void Reduce(BufferType &work_buffer, AccDataType &in_out_value)
Definition reduction_functions_blockwise.hpp:44
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition reduction_functions_threadwise.hpp:23
static __device__ void Reduce(const SrcBufferType &src_buf, DstBufferType &dst_buf)
Definition reduction_functions_threadwise.hpp:36
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition threadwise_welford.hpp:83
static __device__ void Run(const SrcMeanBufferType &src_mean_buf, const SrcVarBufferType &src_var_buf, const SrcCountBufferType &src_count_buf, DstMeanBufferType &dst_mean_buf, DstVarBufferType &dst_var_buf, DstCountBufferType &dst_count_buf)
Definition threadwise_welford.hpp:110
Definition utility/functional.hpp:100
Definition reduction_operator.hpp:37
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340