gridwise_gemm_xdl_cshuffle_streamk_v3.hpp Source File

gridwise_gemm_xdl_cshuffle_streamk_v3.hpp Source File#

Composable Kernel: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp Source File
gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7#include "ck/utility/env.hpp"
20
21namespace ck {
22
// Currently we do not have an elegant way to put the single-LDS-buffer and the double-LDS-buffer
// pipelines in the same kernel function. Blockers:
// 1. Two separate declarations of __shared__ pointers are the key to making sure data accesses
//    operate on two distinct LDS chunks.
// 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may
//    not use the same LDS buffer when we declare __shared__ inside blkgemmpipe.
29template <typename GridwiseGemm,
30 bool HasMainKBlockLoop,
31 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
32 index_t MinimumOccupancy = 1,
34__global__ void
35#if CK_USE_LAUNCH_BOUNDS
36__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
37#endif
38 kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
39{
40#if defined(__gfx9__) || defined(__gfx12__)
41 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
42 {
43 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
44
45 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
46 karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg, karg.p_workspace_);
47 }
48#else
49 ignore = karg;
50#endif // end of if (defined(__gfx9__))
51}
52
53template <typename GridwiseGemm,
54 bool HasMainKBlockLoop,
55 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
56 index_t MinimumOccupancy = 1,
58__global__ void
59#if CK_USE_LAUNCH_BOUNDS
60__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
61#endif
62 kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
63{
64#if defined(__gfx9__) || defined(__gfx12__)
65 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
66 {
67 // Pass two lds pointer is the key to tell compiler that ds_read/write
68 // operate on different lds chunk at same time without order dependecy
69 __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
70 __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
71
72 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
73 karg.p_a_grid,
74 karg.p_b_grid,
75 karg.p_c_grid,
76 p_shared_0,
77 p_shared_1,
78 karg,
79 karg.p_workspace_);
80 }
81#else
82 ignore = karg;
83#endif // end of if (defined(__gfx9__))
84}
85
86template <typename ALayout,
87 typename BLayout,
88 typename CLayout,
89 typename ADataType,
90 typename BDataType,
91 typename AccDataType,
92 typename CShuffleDataType,
93 typename CDataType,
94 typename AElementwiseOperation,
95 typename BElementwiseOperation,
96 typename CElementwiseOperation,
98 index_t BlockSize,
99 index_t MPerBlock,
100 index_t NPerBlock,
101 index_t KPerBlock,
102 index_t AK1Value,
103 index_t BK1Value,
104 index_t MPerXdl,
105 index_t NPerXdl,
106 index_t MXdlPerWave,
107 index_t NXdlPerWave,
108 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
109 typename ABlockTransferThreadClusterArrangeOrder,
110 typename ABlockTransferSrcAccessOrder,
111 index_t ABlockTransferSrcVectorDim,
112 index_t ABlockTransferSrcScalarPerVector,
113 index_t ABlockTransferDstScalarPerVector_AK1,
114 bool AThreadTransferSrcResetCoordinateAfterRun,
115 index_t ABlockLdsExtraM,
116 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
117 typename BBlockTransferThreadClusterArrangeOrder,
118 typename BBlockTransferSrcAccessOrder,
119 index_t BBlockTransferSrcVectorDim,
120 index_t BBlockTransferSrcScalarPerVector,
121 index_t BBlockTransferDstScalarPerVector_BK1,
122 bool BThreadTransferSrcResetCoordinateAfterRun,
123 index_t BBlockLdsExtraN,
124 index_t CShuffleMXdlPerWavePerShuffle,
125 index_t CShuffleNXdlPerWavePerShuffle,
126 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
127 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
130 typename ComputeTypeA = CDataType,
131 typename ComputeTypeB = ComputeTypeA>
133{
134 static constexpr auto I0 = Number<0>{};
135 static constexpr auto I1 = Number<1>{};
136 static constexpr auto I2 = Number<2>{};
137 static constexpr auto I3 = Number<3>{};
138 static constexpr auto I4 = Number<4>{};
139 static constexpr auto I5 = Number<5>{};
140 static constexpr auto I6 = Number<6>{};
141 static constexpr auto I7 = Number<7>{};
142
143 // K1 should be Number<...>
144 static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
145 static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
146 static constexpr auto AK1Number = Number<AK1Value>{};
147 static constexpr auto BK1Number = Number<BK1Value>{};
148
149 static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
150 static constexpr bool is_single_rate_mfma =
152 lcm_AK1_BK1 <= 4) ||
155 lcm_AK1_BK1 < 32))
156 ? true
157 : false;
158 static constexpr auto is_scale_mfma = false;
159 static constexpr index_t KPack =
161 MfmaSelector<ComputeTypeA,
162 MPerXdl,
163 NPerXdl,
164 ComputeTypeA,
166 is_scale_mfma>::selected_mfma.k_per_blk);
167
169 __host__ static auto CalculateMPadded(index_t M)
170 {
171 return math::integer_least_multiple(M, MPerBlock);
172 }
173
174 __host__ static auto CalculateNPadded(index_t N)
175 {
176 return math::integer_least_multiple(N, NPerBlock);
177 }
178
179 __host__ static auto CalculateKPadded(index_t K)
180 {
181 return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
182 }
183
184 __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
185 {
186 auto K_t = K_Batch * KPerBlock;
187 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
188 }
189
190 __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
191 {
192 auto K_t = K_Batch * KPerBlock;
193 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
194 }
195
196 __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
197 {
198 auto K_t = K_Batch * KPerBlock;
199 return (K + K_t - 1) / K_t * KPerBlock;
200 }
201
202 __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
203 {
204 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
205 auto K_t = K_Batch * KReadVec;
206 return (K + K_t - 1) / K_t * KReadVec;
207 }
208
209 __host__ static auto CalculateMBlock(index_t M)
210 {
211 return math::integer_divide_ceil(M, MPerBlock);
212 }
213
214 __host__ static auto CalculateNBlock(index_t N)
215 {
216 return math::integer_divide_ceil(N, NPerBlock);
217 }
218
219 template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
220 __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
221 {
222 constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
223 constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
224
226 TileDesc_K0_MN_K1{},
232 }
233
234 __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
235 index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
236 {
237 const auto a_grid_desc_mraw_kraw = [&]() {
239 {
240 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
241 }
243 {
244 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
245 }
246 }();
247
248 // Pad both M and K to be multiples of the block sizes
249 const auto a_grid_desc_m_k =
250 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
252 make_right_pad_transform(K, KPad - K)),
255
256 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
257 a_grid_desc_m_k,
262
263 return a_grid_desc_ak0_m_ak1;
264#if 0
265 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
266
267 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
268 GemmSpec == GemmSpecialization::MNKPadding)
269 {
270 // pad both M and K
271 const auto a_grid_desc_m_k =
272 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
274 make_right_pad_transform(K, KPad - K)),
277
278 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
279 a_grid_desc_m_k,
284
285 return a_grid_desc_ak0_m_ak1;
286 }
287 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
288 GemmSpec == GemmSpecialization::MNPadding)
289 {
290 // pad M, but not K
291 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
292 a_grid_desc_mraw_kraw,
294 make_right_pad_transform(M, MPad - M)),
297
298 return a_grid_desc_ak0_m_ak1;
299 }
300 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
301 GemmSpec == GemmSpecialization::NKPadding)
302 {
303 // pad K, but not M
304 const auto a_grid_desc_m_k = transform_tensor_descriptor(
305 a_grid_desc_mraw_kraw,
309
310 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
311 a_grid_desc_m_k,
316
317 return a_grid_desc_ak0_m_ak1;
318 }
319 else
320 {
321 // not pad M or K
322 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
323 a_grid_desc_mraw_kraw,
328
329 return a_grid_desc_ak0_m_ak1;
330 }
331#endif
332 }
333
334 __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
335 index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
336 {
337 const auto b_grid_desc_nraw_kraw = [&]() {
339 {
340 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
341 }
343 {
344 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
345 }
346 }();
347
348 // Pad both N and K to be multiples of the block sizes
349 const auto b_grid_desc_n_k =
350 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
352 make_right_pad_transform(K, KPad - K)),
355
356 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
357 b_grid_desc_n_k,
362
363 return b_grid_desc_bk0_n_bk1;
364#if 0
365 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
366
367 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
368 GemmSpec == GemmSpecialization::MNKPadding)
369 {
370 // pad both N and K
371 const auto b_grid_desc_n_k =
372 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
374 make_right_pad_transform(K, KPad - K)),
377
378 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
379 b_grid_desc_n_k,
384
385 return b_grid_desc_bk0_n_bk1;
386 }
387 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
388 GemmSpec == GemmSpecialization::MNPadding)
389 {
390 // pad N, but not K
391 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
392 b_grid_desc_nraw_kraw,
394 make_right_pad_transform(N, NPad - N)),
397
398 return b_grid_desc_bk0_n_bk1;
399 }
400 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
401 GemmSpec == GemmSpecialization::MKPadding)
402 {
403 // pad K, but not N
404 const auto b_grid_desc_n_k = transform_tensor_descriptor(
405 b_grid_desc_nraw_kraw,
409
410 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
411 b_grid_desc_n_k,
416
417 return b_grid_desc_bk0_n_bk1;
418 }
419 else
420 {
421 // not pad N or K
422 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
423 b_grid_desc_nraw_kraw,
428
429 return b_grid_desc_bk0_n_bk1;
430 }
431#endif
432 }
433
434 template <typename ABlockDesc_AK0_M_AK1>
435 __host__ __device__ static constexpr auto
436 MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
437 {
438 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
439
440 return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
441 }
442
443 template <typename BBlockDesc_BK0_N_BK1>
444 __host__ __device__ static constexpr auto
445 MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
446 {
447 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
448
449 return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
450 }
451
452 __host__ __device__ static auto
454 {
455 const auto c_grid_desc_mraw_nraw = [&]() {
457 {
458 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
459 }
461 {
462 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
463 }
464 }();
465
466 // Pad both M and N to be multiples of the block sizes
467 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
469 make_right_pad_transform(N, NPad - N)),
472#if 0
473 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
474
475 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
476 GemmSpec == GemmSpecialization::MNKPadding)
477 {
478 // pad M and N
479 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
481 make_right_pad_transform(N, NPad - N)),
484 }
485 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
486 GemmSpec == GemmSpecialization::MKPadding)
487 {
488 // pad M, but not N
490 c_grid_desc_mraw_nraw,
494 }
495 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
496 GemmSpec == GemmSpecialization::NKPadding)
497 {
498 // pad N, but not M
500 c_grid_desc_mraw_nraw,
504 }
505 else
506 {
507 // not pad M or N
508 return c_grid_desc_mraw_nraw;
509 }
510#endif
511 }
512
513 struct Problem
514 {
515 __host__ Problem(index_t M_,
516 index_t N_,
517 index_t K_,
518 index_t StrideA_,
519 index_t StrideB_,
520 index_t StrideC_,
521 index_t Streamk_sel_,
522 index_t Grid_size_,
523 StreamKReductionStrategy reduction_strategy_)
524 : M{M_},
525 N{N_},
526 K{K_},
527 StrideA{StrideA_},
528 StrideB{StrideB_},
529 StrideC{StrideC_},
530 Streamk_sel{Streamk_sel_},
531 Grid_size{Grid_size_},
532 reduction_strategy{reduction_strategy_}, // Initialize the member variable
535 KRead{CalculateKRead(K_, 1)},
537 AK0{CalculateAK0Padded(K_, 1)},
538 BK0{CalculateBK0Padded(K_, 1)},
541
542 {
543 }
544
545 __host__ void Print() const
546 {
547 std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
548 << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
549 << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
550 << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
551 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
552 << "NBlock: " << NBlock << ", " << "Stream-K Selection:" << Streamk_sel
553 << ", " << "Grid size:" << Grid_size << ", " << "Reduction Strategy:"
555 : "Reduction")
556 << "}" << std::endl;
557 }
558
576 };
577
578 // Argument
580 {
581 __host__ Argument(const ADataType* p_a_grid_,
582 const BDataType* p_b_grid_,
583 CDataType* p_c_grid_,
584 index_t M_,
585 index_t N_,
586 index_t K_,
587 index_t StrideA_,
588 index_t StrideB_,
589 index_t StrideC_,
590 index_t Streamk_sel_,
591 index_t Grid_size_,
592 StreamKReductionStrategy reduction_strategy_)
593 : Problem{M_,
594 N_,
595 K_,
596 StrideA_,
597 StrideB_,
598 StrideC_,
599 Streamk_sel_,
600 Grid_size_,
601 reduction_strategy_},
602 p_a_grid{p_a_grid_},
603 p_b_grid{p_b_grid_},
604 p_c_grid{p_c_grid_},
606 N_,
608 Grid_size_,
609 Streamk_sel_,
610 reduction_strategy_)
611
612 {
613 }
614
615 const ADataType* p_a_grid;
616 const BDataType* p_b_grid;
617 CDataType* p_c_grid;
619 NPerBlock,
620 KPerBlock,
622 8,
623 4>
625 };
626
628 {
629 __device__ SplitKBatchOffset(Problem& problem, unsigned int kbatch_id, unsigned int orig_K)
630 {
632 {
633 a_k_split_offset = kbatch_id * problem.KRead;
634 }
636 {
637 a_k_split_offset = kbatch_id * problem.KRead * problem.M;
638 }
639
641 {
642 b_k_split_offset = kbatch_id * problem.KRead * problem.N;
643 }
645 {
646 b_k_split_offset = kbatch_id * problem.KRead;
647 }
648
649 if(kbatch_id < static_cast<uint32_t>(problem.KBatch - 1))
650 {
651 problem.K = problem.KRead;
652 }
653 else
654 {
655 problem.K = orig_K - problem.KRead * (problem.KBatch - 1);
656 }
657 }
658
661 };
662
663 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
664 {
665 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
666 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
667 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
668 // A matrix in LDS memory, dst of blockwise copy
669 if constexpr(ABlockLdsExtraM)
670 {
674 }
675 // xor tensor transformation request more unnecessary vgpr usage, would cause register spill
676 // in some cases.
678 {
679 constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1
680 ? 1
681 : 32 * 4 / KPerBlock / sizeof(ADataType);
682 constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
684 AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
686
687 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
688 a_lds_block_desc,
694
695 constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
696 a_lds_block_desc_permuted,
702
703 constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
704 a_lds_block_desc_ak0_mldslayer_m_ak1,
711
712 return a_lds_block_desc_ak0_m_ak1;
713 }
714 else // ColumnMajor A
715 {
716 // kfold and mpair dimension is not always required.
717 // more dimension in merge_transform increase the difficulty of generating immarg offset
718 // for compiler.
719 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
720 constexpr auto M1 = MPerBlock / M0;
721
722 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
723 constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
724 constexpr auto KThreadRead = WaveSize / MPerXdl;
725 constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
726
727 constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
728 ? 1
729 : 128 / (AK1Number * M0 * sizeof(ADataType));
730 constexpr auto KThreadReadPerm =
731 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
732 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
733 : KThreadRead;
734
735 // 1<=mpair<=n0
736 constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
737 ? 1
738 : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
739 ? M0
740 : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
741
742 constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
746 Number<kfold * M0 / mpair>{},
748 AK1Number));
749
750 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
751 a_lds_block_desc,
756 make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
763
764 constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
765 a_lds_block_desc_permuted,
774 Sequence<1>{},
775 Sequence<2>{},
776 Sequence<3>{},
777 Sequence<4>{},
778 Sequence<5>{}),
780 Sequence<2>{},
783 Sequence<6>{},
784 Sequence<7>{}));
785
786 constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
787 a_lds_block_desc_unmerged,
790 Number<KThreadWrite / kfold / KThreadReadPerm>{},
798
799 return a_lds_block_desc_ak0_m_ak1;
800 }
801 }
802
803 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
804 {
805 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
806 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
807 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
808 // B matrix in LDS memory, dst of blockwise copy
809 if constexpr(BBlockLdsExtraN)
810 {
814 }
816 {
817 // NLdsLayer * K0 as logical Bank
818 constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1
819 ? 1
820 : 32 * 4 / KPerBlock / sizeof(BDataType);
821 ;
822 constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
824 BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
826
827 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
828 b_lds_block_desc,
834
835 constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
836 b_lds_block_desc_permuted,
842
843 constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
844 b_lds_block_desc_bk0_nldslayer_n_bk1,
851
852 return b_lds_block_desc_bk0_n_bk1;
853 }
854 else // RowMajor B
855 {
856 constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
857 constexpr auto N1 = NPerBlock / N0;
858
859 constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
860 constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
861 constexpr auto KThreadRead = WaveSize / NPerXdl;
862 constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
863
864 constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
865 ? 1
866 : 128 / (BK1Number * N0 * sizeof(BDataType));
867 constexpr auto KThreadReadPerm =
868 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
869 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
870 : KThreadRead;
871
872 // 1<=npair<=n0
873 constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
874 ? 1
875 : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
876 ? N0
877 : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
878
879 constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
883 Number<kfold * N0 / npair>{},
885 BK1Number));
886
887 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
888 b_lds_block_desc,
893 make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
900
901 constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
902 b_lds_block_desc_permuted,
911 Sequence<1>{},
912 Sequence<2>{},
913 Sequence<3>{},
914 Sequence<4>{},
915 Sequence<5>{}),
917 Sequence<2>{},
920 Sequence<6>{},
921 Sequence<7>{}));
922
923 constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
924 b_lds_block_desc_unmerged,
927 Number<KThreadWrite / kfold / KThreadReadPerm>{},
935
936 return b_lds_block_desc_bk0_n_bk1;
937 }
938 }
939
941 {
942 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
943 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
944
945 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
949 I1,
951
952 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
953 }
954
955 __host__ __device__ static constexpr auto
957 {
958 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
959 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
960
964 Number<NXdlPerWave / CShuffleNXdlPerWavePerShuffle>{},
966 }
967
970 BlkGemmPipelineVer,
971 BlkGemmPipeSched,
972 BlockSize,
973 ADataType,
974 BDataType,
975 ComputeTypeA,
976 AccDataType,
983 ABlockTransferSrcScalarPerVector,
984 BBlockTransferSrcScalarPerVector,
985 MPerBlock,
986 NPerBlock,
987 KPerBlock,
988 MPerXdl,
989 NPerXdl,
990 MXdlPerWave,
991 NXdlPerWave,
992 KPack>())>;
993
994 __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
995 {
996 // LDS allocation for A and B: be careful of alignment
997 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
998 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
999
1000 // lds max alignment
1001 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1002
1003 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1004 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1005
1006 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1007 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1008
1009 // LDS allocation for C shuffle in LDS
1010 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1012
1013 constexpr auto c_block_size =
1014 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1015
1016 return math::max((a_block_space_size_aligned * sizeof(ADataType) +
1017 b_block_space_size_aligned * sizeof(BDataType)),
1018 c_block_size * sizeof(CShuffleDataType));
1019 }
1020
1022
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
1024 __host__ static constexpr bool CheckValidity(const Argument& karg)
1025 {
1026 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1027 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1028 "Invalid tuning param!");
1029
1030 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1035 {
1036 if(!(karg.M % MPerBlock == 0))
1037 {
1038 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1039 {
1040 std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1041 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1042 << std::endl;
1043 }
1044 return false;
1045 }
1046 }
1047
1048 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1053 {
1054 if(!(karg.N % NPerBlock == 0))
1055 {
1056 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1057 {
1058 std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1059 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1060 << std::endl;
1061 }
1062 return false;
1063 }
1064 }
1065
1066 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1070 {
1071
1072 auto K_t = KPerBlock;
1073 if(!(karg.K % K_t == 0))
1074 {
1075 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1076 {
1077 std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1078 << karg.K << " " << __FILE__ << ":" << __LINE__
1079 << ", in function: " << __func__ << std::endl;
1080 }
1081 return false;
1082 }
1083 }
1084 else
1085 {
1086
1087 if(karg.K <= 0)
1088 {
1089 return false;
1090 }
1091 }
1092
1094 {
1095 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1096 {
1097 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1098 {
1099 std::cout << "Arg K (" << karg.K
1100 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1101 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1102 << __LINE__ << ", in function: " << __func__ << std::endl;
1103 }
1104 return false;
1105 }
1106 }
1107 else
1108 {
1109 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1110 {
1111 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1112 {
1113 std::cout << "Arg M (" << karg.M
1114 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1115 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1116 << __LINE__ << ", in function: " << __func__ << std::endl;
1117 }
1118
1119 return false;
1120 }
1121 }
1122
1124 {
1125 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1126 {
1127 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1128 {
1129 std::cout << "Arg N (" << karg.N
1130 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1131 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1132 << __LINE__ << ", in function: " << __func__ << std::endl;
1133 }
1134 std::cout << "Arg N (" << karg.N
1135 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1136 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1137 << __LINE__ << ", in function: " << __func__ << std::endl;
1138 return false;
1139 }
1140 }
1141 else
1142 {
1143 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1144 {
1145 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1146 {
1147 std::cout << "Arg K (" << karg.K
1148 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1149 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1150 << __LINE__ << ", in function: " << __func__ << std::endl;
1151 }
1152
1153 return false;
1154 }
1155 }
1156
1158 {
1159 if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1160 {
1161 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1162 {
1163 std::cout << "Arg N (" << karg.N
1164 << ") value is not a multiple of "
1165 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1166 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1167 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1168 << std::endl;
1169 }
1170
1171 return false;
1172 }
1173 }
1174 else
1175 {
1176 if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1177 {
1178 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1179 {
1180 std::cout << "Arg M (" << karg.M
1181 << ") value is not a multiple of "
1182 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1183 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1184 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1185 << std::endl;
1186 }
1187
1188 return false;
1189 }
1190 }
1191
1192 // check gridwise gemm pipeline
1193 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1194
1195 if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
1196 {
1197 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1198 {
1199 return false;
1200 }
1201 }
1202
1203 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1204 return true;
1205 }
1206
1207 __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1208 {
1209 const index_t num_loop = K / KPerBlock;
1210
1211 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1212 }
1213
1214 __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1215 {
1216 const index_t num_loop = K / KPerBlock;
1217
1218 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1219 }
1220
1221 template <typename CGridDesc>
1223 const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1224 {
1225 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1226 c_grid_desc_m_n,
1231
1232 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1233 }
1234
1235 __host__ __device__ static constexpr auto GetClusterLengthReduction()
1236 {
1237 // TODO: assume C is row major
1238 // TODO: we always first loop over N, then M
1239 constexpr auto NPerBlockPow2 = math::next_power_of_two<NPerBlock>();
1240 constexpr auto NPerBlockReduction =
1241 NPerBlockPow2 / CShuffleBlockTransferScalarPerVector_NPerBlock;
1242 constexpr auto MPerBlockReduction =
1243 (BlockSize + NPerBlockReduction - 1) / NPerBlockReduction;
1245 }
1246
1247 __host__ __device__ static constexpr auto GetPartialAccBlockDescriptor()
1248 {
1249 const auto c_partial_acc_block_m_n = [&]() {
1251 {
1252 return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
1253 make_tuple(NPerBlock, I1));
1254 }
1256 {
1257 return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
1258 make_tuple(I1, MPerBlock));
1259 }
1260 }();
1261 return c_partial_acc_block_m_n;
1262 }
1264 NPerBlock,
1265 KPerBlock,
1267 8,
1268 4>;
1269
// Stream-K GEMM driver for one workgroup (single LDS buffer variant).
//
// Logical block ids start at get_block_1d_id() and advance by gridDim.x until
// block_2_ctile_map_streamk.get_grid_dims() is reached. Each logical block is
// classified as:
//  - stream-K ("sk") block:   block_idx < sk_num_blocks
//  - data-parallel ("dp") block: dp_start_block_idx <= block_idx <
//                               reduction_start_block_idx
//  - reduction block (StreamKReductionStrategy::Reduction only):
//                               block_idx >= reduction_start_block_idx
// sk and dp blocks run the GEMM over their K-iteration range [iter_start, iter_end)
// and differ only in where the result is written; reduction blocks sum the
// partial tiles that sk blocks wrote to the workspace.
//
// p_a_grid / p_b_grid / p_c_grid : global A, B and C buffers.
// p_shared    : LDS scratch, used for the A/B tiles and then reused as the C
//               shuffle buffer.
// problem     : sizes, strides and stream-K parameters (Grid_size, Streamk_sel,
//               reduction_strategy).
// p_workspace : AccDataType partial-accumulator tiles, followed by the uint32_t
//               semaphores at byte offset
//               get_workspace_size_for_acc(sizeof(AccDataType)).
1270 template <bool HasMainKBlockLoop,
1271 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1272 TailNumber TailNum = TailNumber::Odd>
1273 __device__ static void Run(const ADataType* p_a_grid,
1274 const BDataType* p_b_grid,
1275 CDataType* p_c_grid,
1276 void* p_shared,
1277 Problem& problem,
1278 void* p_workspace)
1279 {
1280 const AElementwiseOperation a_element_op{};
1281 const BElementwiseOperation b_element_op{};
1282 const CElementwiseOperation c_element_op{};
1283
1284 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1285 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1286 const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1287 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1288
1289 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1290 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1291
1292 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1293 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1294
// Stream-K work decomposition: maps logical block ids to K-iteration ranges
// and output tiles.
1295 Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
1296 problem.N,
1297 AK0Number * problem.KPadded,
1298 problem.Grid_size,
1299 problem.Streamk_sel,
1300 problem.reduction_strategy);
// Per-logical-block state, reassigned on every iteration of the block loop.
1301 uint32_t iter_start, iter_end;
1302 bool is_sk_block, is_dp_block, is_reduction_block;
1303 index_t num_k_block_main_loop;
1304 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
1305 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1306 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1308 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1310 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1311
// The per-tile semaphore counters (indexed by tile id) follow the
// partial-accumulator area in the workspace.
1312 uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
1313 reinterpret_cast<char*>(p_workspace) +
1314 block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
1315
// Walk the logical block ids handled by this workgroup: start at the hardware
// block id and step by the number of launched blocks.
1316 for(auto block_idx = get_block_1d_id();
1317 block_idx < block_2_ctile_map_streamk.get_grid_dims();
1318 block_idx += gridDim.x)
1319 {
1320
1321 is_sk_block =
1322 static_cast<uint32_t>(block_idx) < block_2_ctile_map_streamk.sk_num_blocks;
1323 is_dp_block =
1324 static_cast<uint32_t>(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx &&
1325 static_cast<uint32_t>(block_idx) <
1326 block_2_ctile_map_streamk.reduction_start_block_idx;
1327
// [iter_start, iter_end): the K iterations owned by this logical block.
1328 block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
1329 num_k_block_main_loop = iter_end - iter_start;
1330
// Reduction strategy only: logical blocks at or after reduction_start_block_idx
// do no GEMM work. Each one owns one C tile (reduction_idx). It waits until the
// stream-K blocks have published every partial tile for it, sums those partial
// tiles from p_workspace, writes the result to C with InMemoryDataOperationEnum::Set,
// and then moves on to the next logical block.
1331 if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
1332 {
1333 is_reduction_block = static_cast<uint32_t>(block_idx) >=
1334 block_2_ctile_map_streamk.reduction_start_block_idx;
1335 if(is_reduction_block)
1336 {
1337 // descriptors
1338 constexpr auto cluster_length_reduce = GetClusterLengthReduction();
1339 constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
// Each thread takes its own (m, n) coordinate in the reduce cluster. The
// cluster is sized from BlockSize (see GetClusterLengthReduction), so the index
// must be the thread-local id. Using block_idx here would give every thread of
// the workgroup the same coordinate: all threads would reduce and store the same
// vector, and the rest of the tile would never be written.
1340 const auto reduce_thread_cluster_idx =
1341 reduce_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
1342 const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
1343 const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
1344
1345 constexpr auto MReduceIters = math::integer_divide_ceil(
1346 Number<MPerBlock>{}, cluster_length_reduce.At(I0));
1347 constexpr auto NReduceIters = math::integer_divide_ceil(
1349 cluster_length_reduce.At(I1) *
1351
1352 constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
1354 constexpr auto acc_thread_buf_store_desc =
1357
1358 constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
1359
// Window steps: forward one N-cluster stride, rewind the N-cluster strides taken
// over NReduceIters - 1 steps, and advance one M-cluster stride.
1360 constexpr auto partial_acc_load_step_n =
1362 cluster_length_reduce.At(I1) *
1363 CShuffleBlockTransferScalarPerVector_NPerBlock);
1364 constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
1365 0,
1366 -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
1367 CShuffleBlockTransferScalarPerVector_NPerBlock);
1368 constexpr auto partial_acc_load_step_m =
1369 make_multi_index(cluster_length_reduce.At(I0), 0);
1370
1371 constexpr auto partial_acc_store_step_n =
1373 0,
1374 0,
1375 cluster_length_reduce.At(I1) *
1376 CShuffleBlockTransferScalarPerVector_NPerBlock);
1377 constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
1378 0,
1379 0,
1380 0,
1381 -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
1382 CShuffleBlockTransferScalarPerVector_NPerBlock);
1383 constexpr auto partial_acc_store_step_m =
1384 make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
1385
1387 AccDataType,
1388 CShuffleBlockTransferScalarPerVector_NPerBlock,
1389 true>
1390 parcial_acc_buf;
1392 AccDataType,
1393 CShuffleBlockTransferScalarPerVector_NPerBlock,
1394 true>
1395 acc_buf;
1396
1397 // start to compute
1398 auto reduction_idx =
1399 block_idx - block_2_ctile_map_streamk.reduction_start_block_idx;
1400 auto spatial_idx = block_2_ctile_map_streamk.tile_to_spatial(
1401 reduction_idx, problem.M, problem.N);
1402
1403 workgroup_barrier wg_barrier(p_semaphore);
1404
// Partial tiles for this C tile occupy workspace slots
// [tile_acc_offset_start, tile_acc_offset_end).
1405 uint32_t tile_acc_offset_start =
1406 block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx);
1407 uint32_t tile_acc_offset_end =
1408 block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx +
1409 1);
1410 __syncthreads();
1411
// Per-thread loader: one ScalarPerVector-wide row slice of a partial tile.
1412 auto acc_load = ThreadwiseTensorSliceTransfer_v2<
1413 AccDataType, // SrcData,
1414 AccDataType, // DstData,
1415 decltype(c_partial_acc_block_m_n), // SrcDesc,
1416 decltype(acc_thread_buf_load_desc), // DstDesc,
1417 Sequence<1,
1418 CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
1419 Sequence<0, 1>, // DimAccessOrder,
1420 1, // SrcVectorDim,
1421 CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector,
1422 1, // SrcScalarStrideInVector,
1423 false // SrcResetCoordinateAfterRun,
1424 >{c_partial_acc_block_m_n,
1425 make_multi_index(thread_m_cluster_id,
1426 thread_n_cluster_id *
1427 CShuffleBlockTransferScalarPerVector_NPerBlock)};
1428
// Per-thread storer: writes the summed vector into this tile of C (Set).
1429 auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
1430 AccDataType, // SrcData,
1431 CDataType, // DstData,
1432 decltype(acc_thread_buf_store_desc), // SrcDesc,
1433 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
1434 CElementwiseOperation, // ElementwiseOperation,
1435 Sequence<1,
1436 1,
1437 1,
1438 CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
1439 Sequence<0, 1, 2, 3>, // DimAccessOrder,
1440 3, // DstVectorDim,
1441 CShuffleBlockTransferScalarPerVector_NPerBlock, // DstScalarPerVector,
1442 InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
1443 1, // DstScalarStrideInVector,
1444 false // DstResetCoordinateAfterRun,
1445 >{c_grid_desc_mblock_mperblock_nblock_nperblock,
1446 make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
1447 thread_m_cluster_id,
1448 __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
1449 thread_n_cluster_id *
1450 CShuffleBlockTransferScalarPerVector_NPerBlock),
1451 CElementwiseOperation{}};
1452
// Wait for all partial tiles of this C tile. wait_eq presumably returns once
// p_semaphore[reduction_idx] equals the number of partial slots; the sk blocks
// increment that counter after writing their partials. Thread 0 then resets
// the counter to 0.
1453 wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
1454
1455 if(threadIdx.x == 0)
1456 {
1457 p_semaphore[reduction_idx] = 0;
1458 }
1459 using Accumulation = ck::detail::
1460 AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, AccDataType>;
1461
// Thread (m, n) covers rows m + k * cluster_length_reduce[0] (k < MReduceIters).
// For each row it covers column vectors starting at
// (n + j * cluster_length_reduce[1]) * ScalarPerVector (j < NReduceIters).
1462 for(int i_m = 0; i_m < MReduceIters; i_m++)
1463 {
1464 static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
1465 acc_buf.Clear();
// Sum this vector over every partial tile of the C tile.
1466 for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
1467 {
1468 auto c_partial_acc_buf =
1471 reinterpret_cast<AccDataType*>(p_workspace) +
1472 i * c_partial_acc_block_m_n.GetElementSpaceSize(),
1473 c_partial_acc_block_m_n.GetElementSpaceSize());
1474
1475 acc_load.Run(c_partial_acc_block_m_n,
1476 c_partial_acc_buf,
1477 acc_thread_buf_load_desc,
1478 make_tuple(I0, I0),
1479 parcial_acc_buf);
1480
1482 [&](auto i_vec) {
1483 constexpr auto offset =
1484 acc_thread_buf_load_desc.CalculateOffset(
1485 make_tuple(0, i_vec));
1486 Accumulation::Calculate(acc_buf(Number<offset>{}),
1487 parcial_acc_buf[Number<offset>{}]);
1488 });
1489 }
1490
// Threads whose initial column lies past NPerBlock do not store.
// NOTE(review): the check uses the thread's starting column only, and there is
// no matching row check against MPerBlock; confirm the cluster shape makes both
// unnecessary.
1491 if(thread_n_cluster_id *
1492 CShuffleBlockTransferScalarPerVector_NPerBlock <
1493 NPerBlock)
1494 {
1495 acc_store.Run(acc_thread_buf_store_desc,
1496 make_tuple(I0, I0, I0, I0),
1497 acc_buf,
1498 c_grid_desc_mblock_mperblock_nblock_nperblock,
1499 c_grid_buf);
1500 }
// Step the N window forward, or rewind it after the last N step.
1501 if constexpr(NReduceIters != 1)
1502 {
1503 if constexpr(i_n_reduce != (NReduceIters - 1))
1504 {
1505 acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
1506 partial_acc_load_step_n);
1507 acc_store.MoveDstSliceWindow(
1508 c_grid_desc_mblock_mperblock_nblock_nperblock,
1509 partial_acc_store_step_n);
1510 }
1511 else
1512 {
1513 acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
1514 partial_acc_load_step_n_reverse);
1515 acc_store.MoveDstSliceWindow(
1516 c_grid_desc_mblock_mperblock_nblock_nperblock,
1517 partial_acc_store_step_n_reverse);
1518 }
1519 }
1520 });
// Advance both windows to this thread's next row.
1521 {
1522 acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
1523 partial_acc_load_step_m);
1524 acc_store.MoveDstSliceWindow(
1525 c_grid_desc_mblock_mperblock_nblock_nperblock,
1526 partial_acc_store_step_m);
1527 }
1528 }
1529
// Reduction blocks do no GEMM work; go to the next logical block.
1530 continue;
1531 }
1532 }
1533
1534 // offset for last acc buffer of this block
// (element offset, in AccDataType units, of this block's last partial-accumulator
// slot; each slot is MPerBlock * NPerBlock elements)
1535 uint32_t block_acc_offset =
1536 (block_2_ctile_map_streamk.get_acc_buffer_offset_from_block(block_idx + 1) - 1) *
1537 MPerBlock * NPerBlock;
// Process this block's K-iteration range one tile at a time. Start with the tile
// that contains iteration iter_end - 1 and walk backwards; each pass consumes
// current_iter_length iterations.
1538 while(true)
1539 {
1540 uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
1541 block_2_ctile_map_streamk.get_current_iter_length(
1542 iter_start, iter_end, num_k_block_main_loop));
1543 uint32_t tile_idx, iter_offset;
1544 block_2_ctile_map_streamk.get_tile_idx_with_offset(
1545 iter_end - 1, tile_idx, iter_offset);
// iter_offset becomes the first of the current_iter_length iterations that end
// at iter_end - 1 (presumably relative to the start of tile tile_idx).
1546 iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
1547
1548 auto block_work_idx =
1549 block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);
1550
1551 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1552 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1553
1554 // HACK: this forces m/n_block_data_idx_on_grid into SGPR
1555 const index_t m_block_data_idx_on_grid =
1556 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1557
1558 const index_t n_block_data_idx_on_grid =
1559 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1560
// Starting K0 coordinate of this chunk, iter_offset K blocks of AK0Number each.
// NOTE(review): this A-derived index is also used as B's starting K0 below, while
// B's per-step advance is KPerBlock / BK1Number; confirm AK1 == BK1 is guaranteed.
1561 const index_t k0_block_data_idx_on_grid =
1562 __builtin_amdgcn_readfirstlane(iter_offset * AK0Number);
1563
1564 // lds max alignment
1565 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1566
1567 // A matrix in LDS memory, dst of blockwise copy
1568 constexpr auto a_block_desc_ak0_m_ak1 =
1570
1571 // B matrix in LDS memory, dst of blockwise copy
1572 constexpr auto b_block_desc_bk0_n_bk1 =
1574
1575 // A matrix blockwise copy
1576 auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
1578 AElementwiseOperation,
1582 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1583 ABlockTransferThreadClusterArrangeOrder,
1584 ADataType,
1585 ADataType,
1586 decltype(a_grid_desc_ak0_m_ak1),
1587 decltype(a_block_desc_ak0_m_ak1),
1588 ABlockTransferSrcAccessOrder,
1590 ABlockTransferSrcVectorDim,
1591 2,
1592 ABlockTransferSrcScalarPerVector,
1593 ABlockTransferDstScalarPerVector_AK1,
1594 1,
1595 1,
1596 AThreadTransferSrcResetCoordinateAfterRun,
1597 true,
1598 BlockwiseGemmPipe::GlobalBufferNum>(
1599 a_grid_desc_ak0_m_ak1,
1600 make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0),
1601 a_element_op,
1602 a_block_desc_ak0_m_ak1,
1603 make_multi_index(0, 0, 0),
1605
1606 // B matrix blockwise copy
1607 auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
1609 BElementwiseOperation,
1613 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1614 BBlockTransferThreadClusterArrangeOrder,
1615 BDataType,
1616 BDataType,
1617 decltype(b_grid_desc_bk0_n_bk1),
1618 decltype(b_block_desc_bk0_n_bk1),
1619 BBlockTransferSrcAccessOrder,
1621 BBlockTransferSrcVectorDim,
1622 2,
1623 BBlockTransferSrcScalarPerVector,
1624 BBlockTransferDstScalarPerVector_BK1,
1625 1,
1626 1,
1627 BThreadTransferSrcResetCoordinateAfterRun,
1628 true,
1629 BlockwiseGemmPipe::GlobalBufferNum>(
1630 b_grid_desc_bk0_n_bk1,
1631 make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0),
1632 b_element_op,
1633 b_block_desc_bk0_n_bk1,
1634 make_multi_index(0, 0, 0),
1636
1637 // LDS allocation for A and B: be careful of alignment
1638 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1639 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1640
1641 // Cast after lds
// (A tile at the start of p_shared; B tile right after A's aligned extent)
1643 static_cast<ADataType*>(p_shared),
1644 a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1645
1647 static_cast<BDataType*>(p_shared) +
1648 a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType),
1649 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1650
// Per-step K advance of the A/B source windows, in K0 units.
1651 constexpr auto a_block_slice_copy_step =
1652 make_multi_index(KPerBlock / AK1Number, 0, 0);
1653 constexpr auto b_block_slice_copy_step =
1654 make_multi_index(KPerBlock / BK1Number, 0, 0);
1655
1656 // Blockwise GEMM pipeline
1657 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1658 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1659 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1660
// NOTE(review): this overwrites the per-block iteration count with the number of
// KPerBlock steps spanned by the whole A descriptor, not current_iter_length.
// The same variable is also the cap passed to get_current_iter_length on the next
// while-loop pass. Confirm the pipeline is meant to run over this range.
1661 num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1662 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1663 KPerBlock);
1664
1665 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1666 a_grid_desc_ak0_m_ak1,
1667 a_block_desc_ak0_m_ak1,
1668 a_blockwise_copy,
1669 a_grid_buf,
1670 a_block_buf,
1671 a_block_slice_copy_step,
1672 b_grid_desc_bk0_n_bk1,
1673 b_block_desc_bk0_n_bk1,
1674 b_blockwise_copy,
1675 b_grid_buf,
1676 b_block_buf,
1677 b_block_slice_copy_step,
1678 c_thread_buf,
1679 num_k_block_main_loop);
1680
1681 // shuffle C and write out
// Per-thread accumulators are staged through LDS (p_shared reused as the CShuffle
// buffer) and written out in num_access steps. The destination depends on the
// block kind:
//  - dp block: C, via c_shuffle_block_copy_lds_to_global
//  - sk block, Atomic strategy: C, via the same LDS-to-global copy
//  - sk block, Reduction strategy: this block's partial-accumulator slot in
//    p_workspace at block_acc_offset, then the tile's semaphore is incremented.
1682 {
1683 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1684 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1685 "wrong!");
1686
1687 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1688 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1689
1690 // TODO: hacky, fix it!
1691 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1692 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1693
1694 // TODO: hacky, fix it!
1695 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1696 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1697 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1698
1699 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1700 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1701 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1702 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1703 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1704 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1705 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1706 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1707
1708 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1710
1711 constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
1713
// LDS view used for the C shuffle (aliases the A/B tiles' LDS).
1714 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1715 static_cast<CShuffleDataType*>(p_shared),
1716 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
1717 .GetElementSpaceSize());
1718
// This block's partial-accumulator slot in the workspace (used by sk blocks
// under the Reduction strategy).
1719 auto c_partial_acc_buf =
1721 reinterpret_cast<AccDataType*>(p_workspace) + block_acc_offset,
1722 c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
1723 .GetElementSpaceSize());
1724
1725 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1727 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1728 make_tuple(
1731 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per
1732 // shuffle
1733 M1, // M1 = MWave
1734 M2, // M2 * M3 * M4 = MPerXdl
1735 M3,
1736 M4)),
1739 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per
1740 // shuffle
1741 N1, // N1 = NWave
1742 N2))), // N2 = NPerXdl
1746 Sequence<>{},
1748
1749 // calculate origin of thread output tensor on global memory
1750 // blockwise GEMM c matrix starting index
1751 const auto c_thread_mtx_on_block =
1752 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1753
1754 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1755 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1756
1757 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1759 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1762
1763 const auto m_thread_data_on_block_idx =
1764 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1765 make_multi_index(m_thread_data_on_block));
1766
1767 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1772
1773 const auto n_thread_data_on_block_idx =
1774 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1775 make_multi_index(n_thread_data_on_block));
1776
1777 // shuffle: threadwise copy C from VGPR to LDS
1778 auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
1779 AccDataType,
1780 CShuffleDataType,
1781 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1782 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1784 Sequence<CShuffleMXdlPerWavePerShuffle,
1785 CShuffleNXdlPerWavePerShuffle,
1786 I1,
1787 I1,
1788 M2,
1789 I1,
1790 M4,
1791 I1>,
1793 7,
1794 1,
1796 1,
1797 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1799 0,
1800 m_thread_data_on_block_idx[I1],
1801 n_thread_data_on_block_idx[I1],
1802 m_thread_data_on_block_idx[I2],
1803 m_thread_data_on_block_idx[I3],
1804 m_thread_data_on_block_idx[I4],
1805 n_thread_data_on_block_idx[I2]),
1807
1808 // shuffle: blockwise copy C from LDS to global
1809 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2<
1810 ThisThreadBlock, // ThreadGroup
1811 CElementwiseOperation, // ElementwiseOperation,
1812 // CGlobalMemoryDataOperation, // DstInMemOp,
1813 Sequence<1,
1814 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1815 1,
1816 CShuffleNXdlPerWavePerShuffle * NWave *
1817 NPerXdl>, // BlockSliceLengths,
1818 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1819 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1820 CShuffleDataType, // typename SrcData,
1821 CDataType, // typename DstData,
1822 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1823 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1824 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1825 3, // index_t VectorDim,
1826 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1827 false, // bool ThreadTransferSrcResetCoordinateAfterRun,
1828 false> // bool ThreadTransferDstResetCoordinateAfterRun>
1829 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1830 make_multi_index(0, 0, 0, 0),
1831 c_grid_desc_mblock_mperblock_nblock_nperblock,
1832 make_multi_index(block_m_id, 0, block_n_id, 0),
1833 c_element_op};
1834 // LDS to global partial acc
1835 auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
1836 ThisThreadBlock, // index_t BlockSize,
1837 CElementwiseOperation, // ElementwiseOperation,
1838 // InMemoryDataOperationEnum::Set, // DstInMemOp,
1839 Sequence<1,
1840 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1841 1,
1842 CShuffleNXdlPerWavePerShuffle * NWave *
1843 NPerXdl>, // BlockSliceLengths,
1844 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1845 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1846 CShuffleDataType, // typename SrcData,
1847 AccDataType, // typename DstData,
1848 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1849 decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
1850 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1851 3, // index_t VectorDim,
1852 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1853 false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be
1854 // false, otherwise has scratch
1855 false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be
1856 // false, otherwise has scratch
1857 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1858 make_multi_index(0, 0, 0, 0),
1859 c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
1860 make_multi_index(0, 0, 0, 0),
1861 c_element_op};
1862 // space filling curve for threadwise C in VGPR
1863 constexpr auto sfc_c_vgpr =
1866 Sequence<CShuffleMXdlPerWavePerShuffle,
1867 CShuffleNXdlPerWavePerShuffle,
1868 1,
1869 1,
1870 M2,
1871 1,
1872 M4,
1873 1>>{};
1874
1875 // space filling curve for shuffled blockwise C in global mem
1876 constexpr auto sfc_c_global = SpaceFillingCurve<
1879 Sequence<1,
1880 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1881 1,
1882 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1883
1884 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1885
1886 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1887
1888 static_for<0, num_access, 1>{}([&](auto access_id) {
1889 // make sure it's safe to write to LDS
1891
1892 // each thread writes its data from VGPR to LDS
1893 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1894 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1895 c_thread_buf,
1896 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1897 c_shuffle_block_buf);
1898
1899 // make sure it's safe to read from LDS
1901 c_shuffle_block_copy_lds_to_global.SetSrcSliceOrigin(
1902 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1903 make_tuple(0, 0, 0, 0));
1904
1905 if(is_dp_block)
1906 {
1907 // each block copies its data from LDS to global
1908 c_shuffle_block_copy_lds_to_global
1909 .template Run<decltype(c_shuffle_block_buf),
1910 decltype(c_grid_buf),
1912 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1913 c_shuffle_block_buf,
1914 c_grid_desc_mblock_mperblock_nblock_nperblock,
1915 c_grid_buf);
1916 }
1917 else if(is_sk_block)
1918 {
1919 if(problem.reduction_strategy == StreamKReductionStrategy::Atomic)
1920 {
1921 // each block copies its data from LDS to global
1922 c_shuffle_block_copy_lds_to_global
1923 .template Run<decltype(c_shuffle_block_buf),
1924 decltype(c_grid_buf),
1926 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1927 c_shuffle_block_buf,
1928 c_grid_desc_mblock_mperblock_nblock_nperblock,
1929 c_grid_buf);
1930 }
1931 else if(problem.reduction_strategy ==
1933 {
1934 // constexpr offset
1935 c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
1936 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1937 make_tuple(0, 0, 0, 0));
1938
// NOTE(review): this destination origin is the same constant for every
// access_id, and only c_shuffle_block_copy_lds_to_global's window is advanced
// below. Confirm against c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
// (defined on a line not shown here) that the index is in range and that
// successive shuffle steps do not overwrite each other when num_access > 1.
1939 c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
1940 c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
1941 make_tuple(MXdlPerWave, 0, NXdlPerWave, 0));
1942
1943 c_block_copy_lds_to_partial_acc
1944 .template Run<decltype(c_shuffle_block_buf),
1945 decltype(c_partial_acc_buf),
1947 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1948 c_shuffle_block_buf,
1949 c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
1950 c_partial_acc_buf);
1951 }
1952 }
1953
1954 if constexpr(access_id < num_access - 1)
1955 {
1956 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1957
1958 // move on C
1959 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1960 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1961 }
1962 });
1963
// Publish this block's partial for tile_idx: reduction blocks wait on this
// per-tile counter.
1964 if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
1965 {
1966 if(is_sk_block)
1967 {
1968 // increase the counter for this tile
1969 workgroup_barrier wg_barrier(p_semaphore);
1970 wg_barrier.inc(tile_idx);
1971 }
1972 }
1973 } // shuffle c and write-out end
1974
1975 // exit condition
// Drop the iterations just processed. If any remain, the next pass handles the
// preceding tile of this block's K-iteration range.
1976 iter_end -= current_iter_length;
1977 if(iter_end <= iter_start)
1978 break;
// Under the Reduction strategy each tile handled by this block has its own
// MPerBlock * NPerBlock partial-accumulator slot. Step back one slot to match
// the backwards walk over tiles.
1979 if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
1980 {
1981 block_acc_offset -= MPerBlock * NPerBlock;
1982 }
1983 // make sure next loop LDS is ready for use
1985 } // while loop
1986
1987 } // for loop
1988 }
1989
1990 template <bool HasMainKBlockLoop,
1991 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1992 TailNumber TailNum = TailNumber::Odd>
1993 __device__ static void Run_2Lds(const ADataType* p_a_grid,
1994 const BDataType* p_b_grid,
1995 CDataType* p_c_grid,
1996 void* p_shared_0,
1997 void* p_shared_1,
1998 Problem& problem,
1999 void* p_workspace)
2000 {
2001
2002 const AElementwiseOperation a_element_op{};
2003 const BElementwiseOperation b_element_op{};
2004 const CElementwiseOperation c_element_op{};
2005
2006 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2007 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
2008 const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
2009 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2010
2011 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2012 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2013 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2014 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
2015
2016 uint32_t iter_start, iter_end;
2017 bool is_sk_block, is_dp_block, is_reduction_block;
2018 index_t num_k_block_main_loop;
2019
2020 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
2021 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
2022
2023 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2025 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2026
2028 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2029
2030 Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
2031 problem.N,
2032 AK0Number * problem.KPadded,
2033 problem.Grid_size,
2034 problem.Streamk_sel,
2035 problem.reduction_strategy);
2036 for(auto block_idx = get_block_1d_id();
2037 block_idx < block_2_ctile_map_streamk.get_grid_dims();
2038 block_idx += gridDim.x)
2039 {
2040 is_sk_block =
2041 static_cast<uint32_t>(block_idx) < block_2_ctile_map_streamk.sk_num_blocks;
2042 is_dp_block =
2043 static_cast<uint32_t>(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx &&
2044 static_cast<uint32_t>(block_idx) <
2045 block_2_ctile_map_streamk.reduction_start_block_idx;
2046
2047 block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
2048 num_k_block_main_loop = iter_end - iter_start;
2049
2050 uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
2051 reinterpret_cast<char*>(p_workspace) +
2052 block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
2053
2054 if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
2055 {
2056 is_reduction_block = static_cast<uint32_t>(block_idx) >=
2057 block_2_ctile_map_streamk.reduction_start_block_idx;
2058 if(is_reduction_block)
2059 {
2060 // descriptors
2061 constexpr auto cluster_length_reduce = GetClusterLengthReduction();
2062 constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
2063 const auto reduce_thread_cluster_idx =
2064 reduce_desc.CalculateBottomIndex(make_multi_index(block_idx));
2065 const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
2066 const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
2067
2068 constexpr auto MReduceIters = math::integer_divide_ceil(
2069 Number<MPerBlock>{}, cluster_length_reduce.At(I0));
2070 constexpr auto NReduceIters = math::integer_divide_ceil(
2072 cluster_length_reduce.At(I1) *
2074
2075 constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
2077 constexpr auto acc_thread_buf_store_desc =
2080
2081 constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
2082
2083 constexpr auto partial_acc_load_step_n =
2085 cluster_length_reduce.At(I1) *
2086 CShuffleBlockTransferScalarPerVector_NPerBlock);
2087 constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
2088 0,
2089 -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
2090 CShuffleBlockTransferScalarPerVector_NPerBlock);
2091 constexpr auto partial_acc_load_step_m =
2092 make_multi_index(cluster_length_reduce.At(I0), 0);
2093
2094 constexpr auto partial_acc_store_step_n =
2096 0,
2097 0,
2098 cluster_length_reduce.At(I1) *
2099 CShuffleBlockTransferScalarPerVector_NPerBlock);
2100 constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
2101 0,
2102 0,
2103 0,
2104 -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
2105 CShuffleBlockTransferScalarPerVector_NPerBlock);
2106 constexpr auto partial_acc_store_step_m =
2107 make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
2108
2110 AccDataType,
2111 CShuffleBlockTransferScalarPerVector_NPerBlock,
2112 true>
2113 parcial_acc_buf;
2115 AccDataType,
2116 CShuffleBlockTransferScalarPerVector_NPerBlock,
2117 true>
2118 acc_buf;
2119
2120 // start to compute
2121 auto reduction_idx =
2122 block_idx - block_2_ctile_map_streamk.reduction_start_block_idx;
2123 auto spatial_idx = block_2_ctile_map_streamk.tile_to_spatial(
2124 reduction_idx, problem.M, problem.N);
2125
2126 workgroup_barrier wg_barrier(p_semaphore);
2127
2128 uint32_t tile_acc_offset_start =
2129 block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx);
2130 uint32_t tile_acc_offset_end =
2131 block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx +
2132 1);
2133
2134 uint32_t expected_count = tile_acc_offset_end - tile_acc_offset_start;
2135
2136 if(threadIdx.x == 0)
2137 {
2138 p_semaphore[reduction_idx] = 0;
2139 }
2140
2141 __syncthreads();
2142
2143 auto acc_load = ThreadwiseTensorSliceTransfer_v2<
2144 AccDataType, // SrcData,
2145 AccDataType, // DstData,
2146 decltype(c_partial_acc_block_m_n), // SrcDesc,
2147 decltype(acc_thread_buf_load_desc), // DstDesc,
2148 Sequence<1,
2149 CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
2150 Sequence<0, 1>, // DimAccessOrder,
2151 1, // SrcVectorDim,
2152 CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector,
2153 1, // SrcScalarStrideInVector,
2154 false // SrcResetCoordinateAfterRun,
2155 >{c_partial_acc_block_m_n,
2156 make_multi_index(thread_m_cluster_id,
2157 thread_n_cluster_id *
2158 CShuffleBlockTransferScalarPerVector_NPerBlock)};
2159
2160 auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
2161 AccDataType, // SrcData,
2162 CDataType, // DstData,
2163 decltype(acc_thread_buf_store_desc), // SrcDesc,
2164 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
2165 CElementwiseOperation, // ElementwiseOperation,
2166 Sequence<1,
2167 1,
2168 1,
2169 CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
2170 Sequence<0, 1, 2, 3>, // DimAccessOrder,
2171 3, // DstVectorDim,
2172 CShuffleBlockTransferScalarPerVector_NPerBlock, // DstScalarPerVector,
2173 InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
2174 1, // DstScalarStrideInVector,
2175 false // DstResetCoordinateAfterRun,
2176 >{c_grid_desc_mblock_mperblock_nblock_nperblock,
2177 make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
2178 thread_m_cluster_id,
2179 __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
2180 thread_n_cluster_id *
2181 CShuffleBlockTransferScalarPerVector_NPerBlock),
2182 CElementwiseOperation{}};
2183
2184#if 0
2185 if(threadIdx.x == 0) {
2186 printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(blockIdx.x),
2187 reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
2188 __builtin_amdgcn_readfirstlane(spatial_idx[I0]),
2189 __builtin_amdgcn_readfirstlane(spatial_idx[I1]));
2190 }
2191#endif
2192 if(threadIdx.x == 0)
2193 {
2194 atomicAdd(&p_semaphore[reduction_idx], 1);
2195 }
2196
2197 wg_barrier.wait_eq(p_semaphore[reduction_idx], expected_count);
2198 using Accumulation = ck::detail::
2199 AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, AccDataType>;
2200
2201 for(int i_m = 0; i_m < MReduceIters; i_m++)
2202 {
2203 static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
2204 acc_buf.Clear();
2205 for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
2206 {
2207 auto c_partial_acc_buf =
2210 reinterpret_cast<AccDataType*>(p_workspace) +
2211 i * c_partial_acc_block_m_n.GetElementSpaceSize(),
2212 c_partial_acc_block_m_n.GetElementSpaceSize());
2213
2214 acc_load.Run(c_partial_acc_block_m_n,
2215 c_partial_acc_buf,
2216 acc_thread_buf_load_desc,
2217 make_tuple(I0, I0),
2218 parcial_acc_buf);
2219
2221 [&](auto i_vec) {
2222 constexpr auto offset =
2223 acc_thread_buf_load_desc.CalculateOffset(
2224 make_tuple(0, i_vec));
2225 Accumulation::Calculate(acc_buf(Number<offset>{}),
2226 parcial_acc_buf[Number<offset>{}]);
2227 });
2228 }
2229
2230 if(thread_n_cluster_id *
2231 CShuffleBlockTransferScalarPerVector_NPerBlock <
2232 NPerBlock)
2233 {
2234 acc_store.Run(acc_thread_buf_store_desc,
2235 make_tuple(I0, I0, I0, I0),
2236 acc_buf,
2237 c_grid_desc_mblock_mperblock_nblock_nperblock,
2238 c_grid_buf);
2239 }
2240 if constexpr(NReduceIters != 1)
2241 {
2242 if constexpr(i_n_reduce != (NReduceIters - 1))
2243 {
2244 acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
2245 partial_acc_load_step_n);
2246 acc_store.MoveDstSliceWindow(
2247 c_grid_desc_mblock_mperblock_nblock_nperblock,
2248 partial_acc_store_step_n);
2249 }
2250 else
2251 {
2252 acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
2253 partial_acc_load_step_n_reverse);
2254 acc_store.MoveDstSliceWindow(
2255 c_grid_desc_mblock_mperblock_nblock_nperblock,
2256 partial_acc_store_step_n_reverse);
2257 }
2258 }
2259 });
2260 {
2261 acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
2262 partial_acc_load_step_m);
2263 acc_store.MoveDstSliceWindow(
2264 c_grid_desc_mblock_mperblock_nblock_nperblock,
2265 partial_acc_store_step_m);
2266 }
2267 }
2268
2269 continue;
2270 }
2271 }
2272
2273 // offset for last acc buffer of this block
2274 uint32_t block_acc_offset =
2275 (block_2_ctile_map_streamk.get_acc_buffer_offset_from_block(block_idx + 1) - 1) *
2276 MPerBlock * NPerBlock;
2277 while(true)
2278 {
2279
2280 uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
2281 block_2_ctile_map_streamk.get_current_iter_length(
2282 iter_start, iter_end, num_k_block_main_loop));
2283 uint32_t tile_idx, iter_offset;
2284 block_2_ctile_map_streamk.get_tile_idx_with_offset(
2285 iter_end - 1, tile_idx, iter_offset);
2286 iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
2287
2288 auto block_work_idx =
2289 block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);
2290
2291 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
2292 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
2293
2294 // HACK: this force m/n_block_data_idx_on_grid into SGPR
2295 const index_t m_block_data_idx_on_grid =
2296 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
2297
2298 const index_t n_block_data_idx_on_grid =
2299 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
2300 const index_t k0_block_data_idx_on_grid =
2301 __builtin_amdgcn_readfirstlane(iter_offset * AK0Number);
2302
2303 // lds max alignment
2304 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
2305
2306 // A matrix in LDS memory, dst of blockwise copy
2307 constexpr auto a_block_desc_ak0_m_ak1 =
2309
2310 // B matrix in LDS memory, dst of blockwise copy
2311 constexpr auto b_block_desc_bk0_n_bk1 =
2313
2314 // A matrix blockwise copy
2315 auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
2317 AElementwiseOperation,
2321 ABlockTransferThreadClusterLengths_AK0_M_AK1,
2322 ABlockTransferThreadClusterArrangeOrder,
2323 ADataType,
2324 ADataType,
2325 decltype(a_grid_desc_ak0_m_ak1),
2326 decltype(a_block_desc_ak0_m_ak1),
2327 ABlockTransferSrcAccessOrder,
2329 ABlockTransferSrcVectorDim,
2330 2,
2331 ABlockTransferSrcScalarPerVector,
2332 ABlockTransferDstScalarPerVector_AK1,
2333 1,
2334 1,
2335 AThreadTransferSrcResetCoordinateAfterRun,
2336 true,
2337 BlockwiseGemmPipe::GlobalBufferNum>(
2338 a_grid_desc_ak0_m_ak1,
2339 make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0),
2340 a_element_op,
2341 a_block_desc_ak0_m_ak1,
2342 make_multi_index(0, 0, 0),
2344
2345 // B matrix blockwise copy
2346 auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
2348 BElementwiseOperation,
2352 BBlockTransferThreadClusterLengths_BK0_N_BK1,
2353 BBlockTransferThreadClusterArrangeOrder,
2354 BDataType,
2355 BDataType,
2356 decltype(b_grid_desc_bk0_n_bk1),
2357 decltype(b_block_desc_bk0_n_bk1),
2358 BBlockTransferSrcAccessOrder,
2360 BBlockTransferSrcVectorDim,
2361 2,
2362 BBlockTransferSrcScalarPerVector,
2363 BBlockTransferDstScalarPerVector_BK1,
2364 1,
2365 1,
2366 BThreadTransferSrcResetCoordinateAfterRun,
2367 true,
2368 BlockwiseGemmPipe::GlobalBufferNum>(
2369 b_grid_desc_bk0_n_bk1,
2370 make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0),
2371 b_element_op,
2372 b_block_desc_bk0_n_bk1,
2373 make_multi_index(0, 0, 0),
2375
2376 // LDS allocation for A and B: be careful of alignment
2377 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
2378 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
2379
2380 auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2381 static_cast<ADataType*>(p_shared_0),
2382 a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2383
2384 auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2385 static_cast<BDataType*>(p_shared_0) +
2386 a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType),
2387 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2388
2389 auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2390 static_cast<ADataType*>(p_shared_1),
2391 a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2392
2393 auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2394 static_cast<BDataType*>(p_shared_1) +
2395 a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType),
2396 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2397
2398 auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
2399 auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2400
2401 constexpr auto a_block_slice_copy_step =
2402 make_multi_index(KPerBlock / AK1Number, 0, 0);
2403 constexpr auto b_block_slice_copy_step =
2404 make_multi_index(KPerBlock / BK1Number, 0, 0);
2405
2406 // Blockwise GEMM pipeline
2407 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2408 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2409 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2410
2411 num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2412 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2413 KPerBlock);
2414
2415 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2416 a_grid_desc_ak0_m_ak1,
2417 a_block_desc_ak0_m_ak1,
2418 a_blockwise_copy,
2419 a_grid_buf,
2420 a_block_bufs,
2421 a_block_slice_copy_step,
2422 b_grid_desc_bk0_n_bk1,
2423 b_block_desc_bk0_n_bk1,
2424 b_blockwise_copy,
2425 b_grid_buf,
2426 b_block_bufs,
2427 b_block_slice_copy_step,
2428 c_thread_buf,
2429 num_k_block_main_loop);
2430
2431 // shuffle C and write out
2432 {
2433 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2434 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2435 "wrong!");
2436
2437 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2438 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
2439
2440 // TODO: hacky, fix it!
2441 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2442 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2443
2444 // TODO: hacky, fix it!
2445 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2446 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2447 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2448
2449 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2450 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2451 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2452 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2453 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2454 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2455 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2456 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2457
2458 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2460
2461 constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
2463
2464 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2465 static_cast<CShuffleDataType*>(p_shared_0),
2466 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
2467 .GetElementSpaceSize());
2468
2469 auto c_partial_acc_buf =
2471 reinterpret_cast<AccDataType*>(p_workspace) + block_acc_offset,
2472 c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
2473 .GetElementSpaceSize());
2474
2475 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2477 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2478 make_tuple(
2481 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per
2482 // shuffle
2483 M1, // M1 = MWave
2484 M2, // M2 * M3 * M4 = MPerXdl
2485 M3,
2486 M4)),
2489 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per
2490 // shuffle
2491 N1, // N1 = NWave
2492 N2))), // N2 = NPerXdl
2496 Sequence<>{},
2498
2499 // calculate origin of thread output tensor on global memory
2500 // blockwise GEMM c matrix starting index
2501 const auto c_thread_mtx_on_block =
2502 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2503
2504 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2505 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2506
2507 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2509 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
2512
2513 const auto m_thread_data_on_block_idx =
2514 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2515 make_multi_index(m_thread_data_on_block));
2516
2517 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2522
2523 const auto n_thread_data_on_block_idx =
2524 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2525 make_multi_index(n_thread_data_on_block));
2526
2527 // shuffle: threadwise copy C from VGPR to LDS
2528 auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
2529 AccDataType,
2530 CShuffleDataType,
2531 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2532 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2534 Sequence<CShuffleMXdlPerWavePerShuffle,
2535 CShuffleNXdlPerWavePerShuffle,
2536 I1,
2537 I1,
2538 M2,
2539 I1,
2540 M4,
2541 I1>,
2543 7,
2544 1,
2546 1,
2547 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2549 0,
2550 m_thread_data_on_block_idx[I1],
2551 n_thread_data_on_block_idx[I1],
2552 m_thread_data_on_block_idx[I2],
2553 m_thread_data_on_block_idx[I3],
2554 m_thread_data_on_block_idx[I4],
2555 n_thread_data_on_block_idx[I2]),
2557 // shuffle: blockwise copy C from LDS to global
2558 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2<
2559 ThisThreadBlock, // ThreadGroup
2560 CElementwiseOperation, // ElementwiseOperation,
2561 // CGlobalMemoryDataOperation, // DstInMemOp,
2562 Sequence<1,
2563 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2564 1,
2565 CShuffleNXdlPerWavePerShuffle * NWave *
2566 NPerXdl>, // BlockSliceLengths,
2567 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
2568 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2569 CShuffleDataType, // typename SrcData,
2570 CDataType, // typename DstData,
2571 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2572 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2573 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
2574 3, // index_t VectorDim,
2575 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
2576 false, // bool ThreadTransferSrcResetCoordinateAfterRun,
2577 false> // bool ThreadTransferDstResetCoordinateAfterRun>
2578 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2579 make_multi_index(0, 0, 0, 0),
2580 c_grid_desc_mblock_mperblock_nblock_nperblock,
2581 make_multi_index(block_m_id, 0, block_n_id, 0),
2582 c_element_op};
2583
2584 // LDS to global partial acc
2585 auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
2586 ThisThreadBlock, // index_t BlockSize,
2587 CElementwiseOperation, // ElementwiseOperation,
2588 // InMemoryDataOperationEnum::Set, // DstInMemOp,
2589 Sequence<1,
2590 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2591 1,
2592 CShuffleNXdlPerWavePerShuffle * NWave *
2593 NPerXdl>, // BlockSliceLengths,
2594 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
2595 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2596 CShuffleDataType, // typename SrcData,
2597 AccDataType, // typename DstData,
2598 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2599 decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
2600 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
2601 3, // index_t VectorDim,
2602 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
2603 false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be
2604 // false, othre wise has scratch
2605 false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be
2606 // false, othre wise has scratch
2607 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2608 make_multi_index(0, 0, 0, 0),
2609 c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
2610 make_multi_index(0, 0, 0, 0),
2611 c_element_op};
2612
2613 // space filling curve for threadwise C in VGPR
2614 constexpr auto sfc_c_vgpr =
2617 Sequence<CShuffleMXdlPerWavePerShuffle,
2618 CShuffleNXdlPerWavePerShuffle,
2619 1,
2620 1,
2621 M2,
2622 1,
2623 M4,
2624 1>>{};
2625
2626 // space filling curve for shuffled blockwise C in global mem
2627 constexpr auto sfc_c_global = SpaceFillingCurve<
2630 Sequence<1,
2631 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2632 1,
2633 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2634
2635 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2636
2637 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
2638
2639 static_for<0, num_access, 1>{}([&](auto access_id) {
2640 // make sure it's safe to write to LDS
2642
2643 // each thread write its data from VGPR to LDS
2644 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2645 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2646 c_thread_buf,
2647 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2648 c_shuffle_block_buf);
2649
2650 // make sure it's safe to read from LDS
2652 c_shuffle_block_copy_lds_to_global.SetSrcSliceOrigin(
2653 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2654 make_tuple(0, 0, 0, 0));
2655
2656 if(is_dp_block)
2657 {
2658 // each block copy its data from LDS to global
2659 c_shuffle_block_copy_lds_to_global
2660 .template Run<decltype(c_shuffle_block_buf),
2661 decltype(c_grid_buf),
2663 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2664 c_shuffle_block_buf,
2665 c_grid_desc_mblock_mperblock_nblock_nperblock,
2666 c_grid_buf);
2667 }
2668 else if(is_sk_block)
2669 {
2670 if(problem.reduction_strategy == StreamKReductionStrategy::Atomic)
2671 {
2672 // each block copy its data from LDS to global
2673 c_shuffle_block_copy_lds_to_global
2674 .template Run<decltype(c_shuffle_block_buf),
2675 decltype(c_grid_buf),
2677 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2678 c_shuffle_block_buf,
2679 c_grid_desc_mblock_mperblock_nblock_nperblock,
2680 c_grid_buf);
2681 }
2682 else if(problem.reduction_strategy ==
2684 {
2685 // constexpr offset
2686 c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
2687 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2688 make_tuple(0, 0, 0, 0));
2689
2690 c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
2691 c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
2692 make_tuple(MXdlPerWave, 0, NXdlPerWave, 0));
2693
2694 c_block_copy_lds_to_partial_acc
2695 .template Run<decltype(c_shuffle_block_buf),
2696 decltype(c_partial_acc_buf),
2698 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2699 c_shuffle_block_buf,
2700 c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
2701 c_partial_acc_buf);
2702 }
2703 }
2704 if constexpr(access_id < num_access - 1)
2705 {
2706 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
2707
2708 // move on C
2709 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
2710 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
2711 }
2712 });
2713 }
2714 // exit condition
2715 iter_end -= current_iter_length;
2716 if(iter_end <= iter_start)
2717 break;
2718 if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
2719 {
2720 block_acc_offset -= MPerBlock * NPerBlock;
2721 }
2722 // make sure next loop LDS is ready for use
2724 }
2725 if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
2726 {
2727 if(is_sk_block)
2728 {
2729 // increase the counter for this tile
2730 workgroup_barrier wg_barrier(p_semaphore);
2731 wg_barrier.inc(0);
2732 }
2733 }
2734 }
2735 }
2736};
2737
2738} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto next_power_of_two()
Definition utility/math.hpp:222
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
StreamKReductionStrategy
Definition block_to_ctile_map.hpp:1011
@ Atomic
Definition block_to_ctile_map.hpp:1012
@ Reduction
Definition block_to_ctile_map.hpp:1013
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ GLC
Definition utility/amd_buffer_addressing.hpp:297
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Global
Definition amd_address_space.hpp:17
@ Vgpr
Definition amd_address_space.hpp:20
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:62
constexpr auto BlockGemmPipeline_Selector()
Definition blockwise_gemm_pipeline_wmma_selector.hpp:32
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
unsigned int uint32_t
Definition stdint.h:126
Definition block_to_ctile_map.hpp:1420
__device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
Definition block_to_ctile_map.hpp:1763
uint32_t dp_start_block_idx
Definition block_to_ctile_map.hpp:1431
__host__ __device__ index_t get_grid_dims() const
Definition block_to_ctile_map.hpp:1599
__device__ void get_block_itr(uint32_t block_idx, uint32_t &iter_start, uint32_t &iter_end) const
Definition block_to_ctile_map.hpp:1617
__device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
Definition block_to_ctile_map.hpp:1737
__device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
Definition block_to_ctile_map.hpp:1658
__host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
Definition block_to_ctile_map.hpp:1687
__device__ uint32_t get_current_iter_length(uint32_t iter_start, uint32_t iter_end, uint32_t total_iter_length) const
Definition block_to_ctile_map.hpp:1639
uint32_t reduction_start_block_idx
Definition block_to_ctile_map.hpp:1432
uint32_t sk_num_blocks
Definition block_to_ctile_map.hpp:1429
__device__ void get_tile_idx_with_offset(uint32_t iter, uint32_t &tile_idx, uint32_t &iter_offset) const
Definition block_to_ctile_map.hpp:1653
const ADataType * p_a_grid
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:615
const BDataType * p_b_grid
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:616
CDataType * p_c_grid
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:617
BlockToCTileMap_GemmStreamK_v2< MPerBlock, NPerBlock, KPerBlock, StreamKReductionStrategy::Atomic, 8, 4 > block_2_ctile_map_streamk
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:624
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t Streamk_sel_, index_t Grid_size_, StreamKReductionStrategy reduction_strategy_)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:581
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:514
index_t StrideA
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:562
index_t AK0
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:572
__host__ void Print() const
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:545
index_t MPadded
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:568
index_t StrideB
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:563
index_t KPadded
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:571
index_t N
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:560
index_t M
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:559
index_t MBlock
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:574
index_t StrideC
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:564
index_t Streamk_sel
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:565
StreamKReductionStrategy reduction_strategy
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:567
index_t BK0
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:573
index_t Grid_size
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:566
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t Streamk_sel_, index_t Grid_size_, StreamKReductionStrategy reduction_strategy_)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:515
index_t K
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:561
index_t NPadded
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:569
index_t NBlock
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:575
index_t KRead
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:570
__device__ SplitKBatchOffset(Problem &problem, unsigned int kbatch_id, unsigned int orig_K)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:629
index_t a_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:659
index_t b_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:660
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:133
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, BDataType, ComputeTypeA, GemmAccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack >())> BlockwiseGemmPipe
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:968
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Definition static_buffer.hpp:16
__host__ __device__ void Clear()
Definition static_buffer.hpp:63
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r1r2.hpp:33
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition reduction_operator.hpp:37
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
Definition utility/workgroup_barrier.hpp:7
__device__ void inc(uint32_t offset)
Definition utility/workgroup_barrier.hpp:62
__device__ void wait_eq(uint32_t offset, uint32_t value)
Definition utility/workgroup_barrier.hpp:29
#define CK_ENV(name)
Definition utility/env.hpp:129