device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp Source File

device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp Source File

Composable Kernel: device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp Source File
device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8#include <tuple>
9
10#include "ck/ck.hpp"
11#include "ck/utility/env.hpp"
17#include "ck/utility/tuple.hpp"
27
28namespace ck {
29namespace tensor_operation {
30namespace device {
31
32template <typename ALayout,
33 typename BLayout,
34 typename DsLayout,
35 typename ELayout,
36 typename ADataType,
37 typename BDataType,
38 typename AccDataType,
39 typename CShuffleDataType,
40 typename DsDataType,
41 typename EDataType,
42 typename AElementwiseOperation,
43 typename BElementwiseOperation,
44 typename CDEElementwiseOperation,
45 GemmSpecialization GemmSpec,
46 ck::index_t NumGemmKPrefetchStage,
47 ck::index_t BlockSize,
48 ck::index_t MPerBlock,
49 ck::index_t NPerBlock,
50 ck::index_t KPerBlock,
51 ck::index_t AK1,
52 ck::index_t BK1,
53 ck::index_t MPerXDL,
54 ck::index_t NPerXDL,
55 ck::index_t MXdlPerWave,
56 ck::index_t NXdlPerWave,
57 typename ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
58 typename ABlockTransferThreadClusterArrangeOrder,
59 typename ABlockTransferSrcAccessOrder,
60 index_t ABlockTransferSrcVectorDim,
61 index_t ABlockTransferSrcScalarPerVector,
62 index_t ABlockTransferDstScalarPerVector_AK1,
63 index_t ABlockLdsExtraM,
64 typename BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
65 typename BBlockTransferThreadClusterArrangeOrder,
66 typename BBlockTransferSrcAccessOrder,
67 index_t BBlockTransferSrcVectorDim,
68 index_t BBlockTransferSrcScalarPerVector,
69 index_t BBlockTransferDstScalarPerVector_BK1,
70 index_t BBlockLdsExtraN,
71 index_t CShuffleMXdlPerWavePerShuffle,
72 index_t CShuffleNXdlPerWavePerShuffle,
73 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
74 index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
77 typename ComputeDataType = EDataType,
78 // TODO: change gridwise_gemm_v2r4r2 to support AK1 & BK1
81 : public DeviceGroupedGemmSplitK<ALayout,
82 BLayout,
83 DsLayout,
84 ELayout,
85 ADataType,
86 BDataType,
87 DsDataType,
88 EDataType,
89 AElementwiseOperation,
90 BElementwiseOperation,
91 CDEElementwiseOperation>
92{
95 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
96 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
97
98 static constexpr index_t NumDTensor = DsDataType::Size();
99
100 static constexpr auto I0 = Number<0>{};
101 static constexpr auto I1 = Number<1>{};
102 static constexpr auto I2 = Number<2>{};
103 static constexpr auto I3 = Number<3>{};
104 // TODO change GridwiseGEMM v2r4r2 to support separate AK1 & BK1
105 static constexpr index_t K0PerBlock = KPerBlock / AK1;
106
108 using WorkspaceDataType = float;
109
110 // First stage GridwiseGEMM kernel.
111 template <index_t NXdlPerWave_>
113 BlockSize,
114 ADataType,
115 BDataType,
116 AccDataType,
118 ALayout,
119 BLayout,
120 ELayout,
121 AElementwiseOperation,
122 BElementwiseOperation,
123 PassThrough, // CElementwiseOperation
124 GemmSpec,
125 NumGemmKPrefetchStage,
126 MPerBlock,
127 NPerBlock,
129 MPerXDL,
130 NPerXDL,
131 AK1,
132 MXdlPerWave,
133 NXdlPerWave_,
134 ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
135 ABlockTransferThreadClusterArrangeOrder,
136 ABlockTransferSrcAccessOrder,
137 ABlockTransferSrcVectorDim,
138 ABlockTransferSrcScalarPerVector,
139 ABlockTransferDstScalarPerVector_AK1,
140 false, // AThreadTransferSrcResetCoordinateAfterRun,
141 ABlockLdsExtraM,
142 BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
143 BBlockTransferThreadClusterArrangeOrder,
144 BBlockTransferSrcAccessOrder,
145 BBlockTransferSrcVectorDim,
146 BBlockTransferSrcScalarPerVector,
147 BBlockTransferDstScalarPerVector_BK1,
148 false, // BThreadTransferSrcResetCoordinateAfterRun,
149 BBlockLdsExtraN,
150 CShuffleMXdlPerWavePerShuffle,
151 CShuffleNXdlPerWavePerShuffle,
152 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
153 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
154 LoopSched,
155 PipelineVer,
156 ComputeDataType>;
159 template <typename ELay>
161 {
162 const auto c_grid_desc_m_n = [&]() {
164 {
165 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1));
166 }
168 {
169 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE));
170 }
171 }();
172
173 if constexpr(GemmSpec == GemmSpecialization::MNPadding)
174 {
175 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
176 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
177
179 c_grid_desc_m_n,
183 }
184 else
185 {
186
188 c_grid_desc_m_n,
192 }
193 }
194
195 static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
196 const std::array<index_t, NumDTensor>& NRaws,
197 const std::array<index_t, NumDTensor>& DsStride)
198 {
199 return generate_tuple(
200 [&](auto i) {
201 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
202
203 return MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
204 },
206 }
207
208 static constexpr auto MakeDsGridPointer()
209 {
210 return generate_tuple(
211 [&](auto i) {
212 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
213
214 return static_cast<const DDataType*>(nullptr);
215 },
217 }
218
219 static constexpr auto MakeElementwiseInputSequence()
220 {
222 [&]([[maybe_unused]] auto i) constexpr {
224 },
226 }
227
230 using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {}));
231 using DsGridPointer = decltype(MakeDsGridPointer());
234
236
238 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
240 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
241
251 CDEElementwiseOperation,
252 BlockSize,
253 MPerBlock,
254 NPerBlock,
255 MPerBlock / ClusterLengthMPerBlock,
256 NPerBlock / ClusterLengthNPerBlock,
260 I1,
261 I1>;
262
263 // Block2CTileMap configuration parameter.
264 static constexpr index_t B2E_M01 = 8;
266 using GemmKernelArgument = typename GridwiseGemm64::Argument;
267
269 {
273
277 index_t block_start,
278 index_t block_end)
279 : karg_{karg},
280 block_2_ctile_map_{b2c_map},
281 block_start_{block_start},
282 block_end_{block_end}
283 {
284 }
285 };
286
287 static constexpr index_t DefaultKBatch = 1;
288
289 // Argument
290 struct Argument : public BaseArgument
291 {
292
293 Argument(std::vector<const void*>& p_As,
294 std::vector<const void*>& p_Bs,
295 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
296 std::vector<void*>& p_Es,
297 std::vector<GemmDesc>& gemm_descs,
298 AElementwiseOperation a_element_op,
299 BElementwiseOperation b_element_op,
300 CDEElementwiseOperation cde_element_op)
301 : Argument(p_As,
302 p_Bs,
303 p_Ds,
304 p_Es,
305 gemm_descs,
306 a_element_op,
307 b_element_op,
308 cde_element_op,
310 {
311 }
312
313 Argument(std::vector<const void*>& p_As,
314 std::vector<const void*>& p_Bs,
315 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
316 std::vector<void*>& p_Es,
317 std::vector<GemmDesc>& gemm_descs,
318 AElementwiseOperation a_element_op,
319 BElementwiseOperation b_element_op,
320 CDEElementwiseOperation cde_element_op,
321 index_t kbatch)
322 : K_BATCH{kbatch},
323 group_count_{0},
325 grid_size_{0},
326 a_element_op_{a_element_op},
327 b_element_op_{b_element_op},
328 cde_element_op_{cde_element_op},
329 p_Ds_{p_Ds}
330 {
331 group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
332
333 if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
336 {
337 throw std::runtime_error("Error! group_count_ != p_As/Bs/Ds/Es size");
338 }
339
345 e_ptrs_.reserve(group_count_);
346
347 for(std::size_t i = 0; i < gemm_descs.size(); ++i)
348 {
349 const index_t M = gemm_descs[i].M_;
350 const index_t N = gemm_descs[i].N_;
351 const index_t K = gemm_descs[i].K_;
352
353 if(M == 0 || N == 0 || K == 0)
354 {
356 continue;
357 }
358
359 const index_t stride_a = gemm_descs[i].stride_A_;
360 const index_t stride_b = gemm_descs[i].stride_B_;
361 const index_t stride_e = gemm_descs[i].stride_C_;
362
363 const index_t m_padded = GridwiseGemm64::CalculateMPadded(M);
364 const index_t n_padded = GridwiseGemm64::CalculateNPadded(N);
367
368 const auto c_grid_desc_m_n =
370
371 DsGridDesc_M_N ds_grid_desc_m_n;
372 DsGridPointer p_ds_grid;
373
374 static_for<0, NumDTensor, 1>{}([&](auto j) {
375 using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
376 using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
377
378 p_ds_grid(j) = static_cast<const DDataType*>(p_Ds[i][j]);
379 ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
380 M, N, gemm_descs[i].stride_Ds_[j]);
381 });
382 const auto local_b2c_tile_map =
383 Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
384 const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
385
386 const index_t block_start = grid_size_;
387 const index_t block_end = grid_size_ + grid_size_grp;
388
389 grid_size_ += grid_size_grp;
390 group_grid_size_.push_back(grid_size_grp);
391 // block-to-e-tile map
392 auto grouped_block_2_ctile_map =
393 GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
394
395 std::array<index_t, NumDTensor> stride_ds;
396
397 static_for<0, NumDTensor, 1>{}([&](auto j) {
398 if(gemm_descs[i].stride_Ds_.size() != NumDTensor)
399 {
400 throw std::runtime_error(
401 "Error! gemm_descs[i].stride_Ds_.size() does not match NumDTensor");
402 }
403
404 stride_ds[j] = gemm_descs[i].stride_Ds_[j];
405 });
406 stride_Ds_.emplace_back(std::move(stride_ds));
407
408 // We first set E pointer to actual operation output, but later on
409 // when workspace will be set, this will be updated to workspace memory.
413 M,
414 N,
415 K,
416 stride_a,
417 stride_b,
418 stride_e,
419 m_padded,
420 n_padded,
421 k_padded,
422 k0_padded,
423 K_BATCH};
424
425 gemm_kernel_args_.emplace_back(
426 std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
427
428 elementwise_c_grid_descs_m_n_.push_back(c_grid_desc_m_n);
429 elementwise_d_grid_descs_m_n_.push_back(ds_grid_desc_m_n);
430 ds_grid_pointer_.push_back(p_ds_grid);
431 // Store a copy of E pointers for elementwise kernel destination
432 e_ptrs_.push_back(p_Es[i]);
433 }
434 }
435
442 {
443 K_BATCH = kbatch;
444 grid_size_ = 0;
445
446 for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
447 {
448 auto& karg = gemm_kernel_args_[i].karg_;
449
450 const index_t k_padded = GridwiseGemm64::CalculateKPadded(karg.K, K_BATCH);
451 const index_t k0_padded = GridwiseGemm64::CalculateK0Padded(karg.K, K_BATCH);
452
453 const auto c_grid_desc_m_n =
454 GridwiseGemm64::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
455
456 const auto local_b2c_tile_map =
457 Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
458 const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
459
460 const index_t block_start = grid_size_;
461 const index_t block_end = grid_size_ + grid_size_grp;
462
463 grid_size_ += grid_size_grp;
464
465 // block-to-e-tile map
466 auto grouped_block_2_ctile_map =
467 GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
468
469 group_grid_size_[i] = grid_size_grp;
470 karg.KPadded = k_padded;
471 karg.K0Padded = k0_padded;
472 karg.k_batch = K_BATCH;
473 gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
474 gemm_kernel_args_[i].block_start_ = block_start;
475 gemm_kernel_args_[i].block_end_ = block_end;
476
477 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
478 {
479 index_t tiles = (block_end - block_start) / K_BATCH;
480 std::cout << "block_start: " << block_start << "\n"
481 << "block_end: " << block_end << "\n"
482 << "tiles: " << tiles << std::endl
483 << std::endl;
484
485 std::cout << "KPadded: " << karg.KPadded << std::endl
486 << "K0Padded: " << karg.K0Padded << std::endl
487 << "KBatch: " << karg.k_batch << std::endl
488 << "grid_size_: " << karg.KPadded << std::endl;
489 }
490 }
491 }
492
494 {
495 // set-up each group E pointer to it's designated workspace memory.
496 WorkspaceDataType* p_workspace = reinterpret_cast<WorkspaceDataType*>(p_workspace_);
497 std::size_t offset = 0;
498
499 for(auto& arg : gemm_kernel_args_)
500 {
501 arg.karg_.p_c_grid = p_workspace + offset;
502 index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
503 offset += tiles * MPerBlock * NPerBlock;
504 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
505 {
506 std::cout << "block_start: " << arg.block_start_ << "\n"
507 << "block_end: " << arg.block_end_ << "\n"
508 << "tiles: " << tiles << "\n"
509 << "offset: " << offset << std::endl;
510 }
511 }
512 }
513
514 std::size_t GetWorkspaceSizeBytes() const
515 {
516 std::size_t size_bytes{0};
517
518 for(const auto& arg : gemm_kernel_args_)
519 {
520 index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
521 size_bytes += tiles * MPerBlock * NPerBlock * sizeof(WorkspaceDataType);
522 }
523 return size_bytes;
524 }
525
526 std::size_t GetWorkspaceSize(std::size_t group) const
527 {
528 const auto& arg = gemm_kernel_args_[group];
529 index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
530 return tiles * MPerBlock * NPerBlock;
531 }
532
533 // private:
538 // Pointer to device memory with GEMM kernel arguments.
540
541 AElementwiseOperation a_element_op_;
542 BElementwiseOperation b_element_op_;
543 CDEElementwiseOperation cde_element_op_;
544
545 std::vector<std::array<const void*, NumDTensor>>& p_Ds_;
546 std::vector<std::array<index_t, NumDTensor>> stride_Ds_;
547 std::vector<GemmTransKernelArg> gemm_kernel_args_;
548 std::vector<index_t> group_grid_size_;
549
550 std::vector<CGridDesc_M_N> elementwise_c_grid_descs_m_n_;
551 std::vector<DsGridDesc_M_N> elementwise_d_grid_descs_m_n_;
552 std::vector<DsGridPointer> ds_grid_pointer_;
553 std::vector<void*> e_ptrs_;
554 };
555
556 // Invoker
557 struct Invoker : public BaseInvoker
558 {
574 template <typename GridwiseGemm>
575 float Run(const Argument& arg,
576 void* dev_gemm_args,
577 void* dev_gemm_workspace,
578 const StreamConfig& stream_config = StreamConfig{})
579 {
580 auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] =
581 CheckArgument<GridwiseGemm>(arg, stream_config);
582
583 if(dev_gemm_args == nullptr)
584 {
585 std::ostringstream err;
586 err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__
587 << ":" << __LINE__ << ", in function: " << __func__;
588 throw std::runtime_error(err.str());
589 }
590
591 if(dev_gemm_workspace == nullptr)
592 {
593 std::ostringstream err;
594 err << "The gemm workspace buffer is not allocated!" << " In " << __FILE__ << ":"
595 << __LINE__ << ", in function: " << __func__;
596 throw std::runtime_error(err.str());
597 }
598
599 float ave_time = 0;
600
601 if(all_have_main_k_block_loop)
602 {
603 ave_time = DispatchKernel<GridwiseGemm, true>(
604 arg, dev_gemm_args, dev_gemm_workspace, stream_config);
605 }
606 else
607 {
608 ave_time = DispatchKernel<GridwiseGemm, false>(
609 arg, dev_gemm_args, dev_gemm_workspace, stream_config);
610 }
611
612 return ave_time;
613 }
614
629 template <typename GridwiseGemm>
630 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
631 {
632 if(arg.p_dev_gemm_kargs_ == nullptr)
633 {
634 std::ostringstream err;
635 err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__
636 << ":" << __LINE__ << ", in function: " << __func__;
637 throw std::runtime_error(err.str());
638 }
639
640 if(arg.p_workspace_ == nullptr)
641 {
642 std::ostringstream err;
643 err << "The gemm workspace buffer is not allocated!" << " In " << __FILE__ << ":"
644 << __LINE__ << ", in function: " << __func__;
645 throw std::runtime_error(err.str());
646 }
647
648 return Run<GridwiseGemm>(arg, arg.p_dev_gemm_kargs_, arg.p_workspace_, stream_config);
649 }
650
652
653 float Run(const BaseArgument* p_arg,
654 const StreamConfig& stream_config = StreamConfig{}) override
655 {
656 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
657 }
658
659 private:
660 template <typename GridwiseGemm>
661 auto CheckArgument(const Argument& arg, const StreamConfig& stream_config) const
662 {
663 bool all_have_kbatch_gt_one, all_have_main_k_block_loop;
664
665 {
666 const auto a_grid_desc_kbatch_ak0_m_ak1 =
667 GridwiseGemm::MakeAGridDescriptor_KBatch_K0_M_K1(
668 arg.gemm_kernel_args_[0].karg_.M,
669 arg.gemm_kernel_args_[0].karg_.MPadded,
670 arg.gemm_kernel_args_[0].karg_.K,
671 arg.gemm_kernel_args_[0].karg_.StrideA,
672 arg.gemm_kernel_args_[0].karg_.k_batch,
673 arg.gemm_kernel_args_[0].karg_.K0Padded,
674 arg.gemm_kernel_args_[0].karg_.KPadded);
675
676 all_have_kbatch_gt_one = arg.K_BATCH > 1;
677 all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(
678 a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
679 a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
680 }
681
682 for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
683 {
684 const auto& gemm_arg = reinterpret_cast<const typename GridwiseGemm::Argument&>(
685 arg.gemm_kernel_args_[i].karg_);
686 if(stream_config.log_level_ > 0)
687 {
688 gemm_arg.Print();
689 }
690
691 if(!GridwiseGemm::CheckValidity(gemm_arg))
692 {
693 std::ostringstream err;
694 err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__
695 << ":" << __LINE__ << ", in function: " << __func__;
696 throw std::runtime_error(err.str());
697 }
698
699 const auto a_grid_desc_kbatch_ak0_m_ak1 =
700 GridwiseGemm::MakeAGridDescriptor_KBatch_K0_M_K1(gemm_arg.M,
701 gemm_arg.MPadded,
702 gemm_arg.K,
703 gemm_arg.StrideA,
704 gemm_arg.k_batch,
705 gemm_arg.K0Padded,
706 gemm_arg.KPadded);
707
708 bool not_all_have_main_k_block_loop_same =
709 all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(
710 a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
711 a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
712 bool not_all_have_kbatch_value_same =
713 all_have_kbatch_gt_one xor (gemm_arg.k_batch > 1);
714
715 if(not_all_have_main_k_block_loop_same)
716 {
717 std::ostringstream err;
718 err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
719 << ":" << __LINE__ << ", in function: " << __func__;
720 throw std::runtime_error(err.str());
721 }
722
723 if(not_all_have_kbatch_value_same)
724 {
725 std::ostringstream err;
726 err << "Not all gemms have same kbatch value (=1 or >1)! " << "group [" << i
727 << "], kbatch: " << gemm_arg.k_batch
728 << ", group [0], kbatch: " << gemm_arg.k_batch << " in " << __FILE__ << ":"
729 << __LINE__ << ", in function: " << __func__;
730 throw std::runtime_error(err.str());
731 }
732 }
733 return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k_block_loop);
734 }
735
736 template <typename GridwiseGemm, bool HasMainKBlockLoop>
737 float DispatchKernel(const Argument& arg,
738 void* dev_gemm_kargs,
739 void* dev_gemm_workspace,
740 const StreamConfig& stream_config) const
741 {
742 const auto gemm_kernel =
744 GemmTransKernelArg,
745 HasMainKBlockLoop,
747 AElementwiseOperation,
748 BElementwiseOperation,
750
751 const auto elementwise_kernel = kernel_elementwise<GridwiseElementwise,
753 ck::Tuple<EGridDesc_M_N>,
755 ck::Tuple<EDataType*>,
757 CDEElementwiseOperation>;
758 return LaunchKernel(gemm_kernel,
759 elementwise_kernel,
760 arg,
761 dev_gemm_kargs,
762 dev_gemm_workspace,
763 stream_config);
764 }
765
766 template <typename KernelFunction, typename KernelFunction2>
767 float LaunchKernel(const KernelFunction& gemm_kernel,
768 const KernelFunction2& elementwise_kernel,
769 const Argument& arg,
770 void* dev_gemm_kargs,
771 [[maybe_unused]] void* dev_gemm_workspace,
772 const StreamConfig& stream_config) const
773 {
774 float time{0.f};
775
777 hipMemcpyAsync(dev_gemm_kargs,
778 arg.gemm_kernel_args_.data(),
779 arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
780 hipMemcpyHostToDevice,
781 stream_config.stream_id_));
782
783 auto preprocess = [&]() {
784 hip_check_error(hipMemsetAsync(
785 dev_gemm_workspace, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_));
786 };
787
788 // GEMM kernel
790 stream_config,
791 preprocess,
792 gemm_kernel,
793 dim3(arg.grid_size_),
794 dim3(BlockSize),
795 0,
797 arg.gemm_kernel_args_.size(),
798 arg.a_element_op_,
799 arg.b_element_op_,
800 PassThrough{});
801
802 // Elementwise kernels
803 for(size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
804 {
806 stream_config,
807 elementwise_kernel,
808 dim3(arg.group_grid_size_[i]),
809 dim3(BlockSize),
810 0,
811 concat_tuple(make_tuple(arg.elementwise_c_grid_descs_m_n_[i]),
812 arg.elementwise_d_grid_descs_m_n_[i]),
813 make_tuple(arg.elementwise_c_grid_descs_m_n_[i]),
814 concat_tuple(make_tuple(arg.gemm_kernel_args_[i].karg_.p_c_grid),
815 arg.ds_grid_pointer_[i]),
816 type_convert<EDataType*>(arg.e_ptrs_[i]),
817 Block2TileMap{arg.elementwise_c_grid_descs_m_n_[i].GetLength(I0),
818 arg.elementwise_c_grid_descs_m_n_[i].GetLength(I1)},
819 arg.cde_element_op_);
820 }
821 return time;
822 }
823 };
824
static constexpr bool IsValidCompilationParameter()
{
    // No compile-time validation of the template parameters is performed yet, so
    // every instantiation is reported as valid.
    // TODO: properly implement this check
    constexpr bool all_parameters_valid = true;
    return all_parameters_valid;
}
830
831 static bool IsSupportedArgument(const Argument& arg)
832 {
834 {
835 return false;
836 }
839 {
840 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
841 {
842 std::cout << "The group count is not equal to sum of skipped groups "
843 "and kernel args size!"
844 << std::endl;
845 }
846 return false;
847 }
848
849 bool supported = true;
850 bool isWave64 = get_warp_size() == 64;
851 for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
852 {
853 const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_;
854 bool group_arg_valid = false;
855 if(isWave64)
856 {
857 if constexpr(NXdlPerWave64 > 0)
858 {
859 group_arg_valid = GridwiseGemm64::CheckValidity(gemm_arg);
860 }
861 }
862 else
863 {
864 if constexpr(NXdlPerWave32 > 0)
865 {
866 group_arg_valid = GridwiseGemm32::CheckValidity(
867 reinterpret_cast<const typename GridwiseGemm32::Argument&>(gemm_arg));
868 }
869 }
870
871 if(not group_arg_valid)
872 {
873 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
874 {
875 std::cout << "[" << __func__ << "] group id: " << i
876 << " has invalid GridwiseGemm settings!" << std::endl;
877 gemm_arg.Print();
878 }
879 }
880 supported = supported && group_arg_valid;
881 }
882 return supported;
883 }
884
885 bool IsSupportedArgument(const BaseArgument* p_arg) override
886 {
887 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
888 }
889
890 static auto MakeArgument(std::vector<const void*>& p_As,
891 std::vector<const void*>& p_Bs,
892 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
893 std::vector<void*>& p_Es,
894 std::vector<GemmDesc> gemm_descs,
895 AElementwiseOperation a_elementwise_op,
896 BElementwiseOperation b_elementwise_op,
897 CDEElementwiseOperation cde_elementwise_op)
898 {
899 return Argument{p_As,
900 p_Bs,
901 p_Ds,
902 p_Es,
903 gemm_descs,
904 a_elementwise_op,
905 b_elementwise_op,
906 cde_elementwise_op};
907 }
908
909 std::unique_ptr<BaseArgument>
910 MakeArgumentPointer(std::vector<const void*>& p_As,
911 std::vector<const void*>& p_Bs,
912 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
913 std::vector<void*>& p_Es,
914 std::vector<GemmDesc>& gemm_descs,
915 AElementwiseOperation a_elementwise_op,
916 BElementwiseOperation b_elementwise_op,
917 CDEElementwiseOperation cde_elementwise_op) override
918 {
919 return std::make_unique<Argument>(p_As,
920 p_Bs,
921 p_Ds,
922 p_Es,
923 gemm_descs,
924 a_elementwise_op,
925 b_elementwise_op,
926 cde_elementwise_op);
927 }
928
929 static auto MakeInvoker() { return Invoker{}; }
930
931 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
932 {
933 return std::make_unique<Invoker>(Invoker{});
934 }
935
936 std::string GetTypeString() const override
937 {
938 auto str = std::stringstream();
939
940 // clang-format off
941 str << "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage"
942 << "<"
943 << std::string(ALayout::name)[0] << ","
944 << std::string(BLayout::name)[0] << ","
945 << std::string(ELayout::name)[0] << ","
946 << BlockSize << ", "
947 << MPerBlock << ", "
948 << NPerBlock << ", "
949 << KPerBlock << ", "
950 << AK1 << ", "
951 << BK1 << ", "
952 << MPerXDL << ", "
953 << NPerXDL << ", "
954 << MXdlPerWave << ", "
955 << NXdlPerWave << ", "
956 << ABlockTransferSrcScalarPerVector << ", "
957 << BBlockTransferSrcScalarPerVector << ", "
958 << CShuffleMXdlPerWavePerShuffle << ", "
959 << CShuffleNXdlPerWavePerShuffle << ", "
960 << getGemmSpecializationString(GemmSpec) << ", "
961 << ">";
962 // clang-format on
963
964 return str.str();
965 }
966
967 void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
968 {
969 auto arg_ptr = dynamic_cast<Argument*>(p_arg);
970 if(arg_ptr)
971 {
972 arg_ptr->p_dev_gemm_kargs_ = p_dev_kernel_args;
973 }
974 else
975 throw std::runtime_error(
976 "The argument pointer is not an object of "
977 "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
978 }
979
980 size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
981 {
982 auto arg = dynamic_cast<const Argument*>(p_arg);
983 if(arg)
984 {
985 return arg->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg);
986 }
987 else
988 throw std::runtime_error(
989 "The argument pointer is not an object of "
990 "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
991 }
992
993 size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
994 {
995 auto arg = dynamic_cast<const Argument*>(p_arg);
996 if(arg)
997 {
998 return arg->GetWorkspaceSizeBytes();
999 }
1000 else
1001 throw std::runtime_error(
1002 "The argument pointer is not an object of "
1003 "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
1004 }
1005
1007 BaseArgument* p_arg,
1008 void* p_workspace,
1009 [[maybe_unused]] const StreamConfig& stream_config = StreamConfig{}) const override
1010 {
1011 auto p_arg_ = dynamic_cast<Argument*>(p_arg);
1012 if(p_arg_)
1013 {
1014 p_arg_->p_workspace_ = p_workspace;
1015 p_arg_->UpdateEPointers();
1016 }
1017 else
1018 throw std::runtime_error(
1019 "The argument pointer is not an object of "
1020 "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
1021 }
1022
1023 [[deprecated]] static void SetKBatchSize(Argument& arg, index_t kbatch)
1024 {
1025 arg.UpdateKBatch(kbatch);
1026 }
1027
1028 void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
1029 {
1030 auto p_arg_ = dynamic_cast<Argument*>(p_arg);
1031 if(p_arg_)
1032 {
1033 p_arg_->UpdateKBatch(kbatch);
1034 }
1035 else
1036 throw std::runtime_error(
1037 "The argument pointer is not an object of "
1038 "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
1039 }
1040};
1041
1042} // namespace device
1043} // namespace tensor_operation
1044} // namespace ck
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:91
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
__global__ void kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op)
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:38
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MNPadding
Definition gemm_specialization.hpp:17
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition amd_address_space.hpp:35
__host__ __device__ constexpr auto concat_tuple(const Tuple< X... > &tx, const Tuple< Y... > &ty)
Definition tuple_helper.hpp:52
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
__global__ void kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op)
Definition gridwise_elementwise_2d.hpp:29
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
hipStream_t stream_id_
Definition ck/stream_config.hpp:11
int log_level_
Definition ck/stream_config.hpp:13
Definition block_to_ctile_map.hpp:541
Definition block_to_ctile_map.hpp:261
Definition gridwise_gemm_xdlops_v2r4r2.hpp:106
Definition block_to_ctile_map.hpp:872
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:117
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition device_base.hpp:197
void * p_workspace_
Definition device_base.hpp:204
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:269
GroupedGemmBlock2ETileMap block_2_ctile_map_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:271
index_t block_start_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:272
index_t block_end_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:272
GemmKernelArgument karg_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:270
GemmTransKernelArg(GemmKernelArgument &&karg, GroupedGemmBlock2ETileMap &&b2c_map, index_t block_start, index_t block_end)
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:275
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:291
std::vector< std::array< const void *, NumDTensor > > & p_Ds_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:545
BElementwiseOperation b_element_op_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:542
CDEElementwiseOperation cde_element_op_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:543
std::size_t GetWorkspaceSize(std::size_t group) const
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:526
index_t grid_size_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:537
Argument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, index_t kbatch)
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:313
index_t skipped_group_count_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:536
void UpdateEPointers()
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:493
Argument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:293
void UpdateKBatch(index_t kbatch)
Set new kbatch value.
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:441
std::vector< std::array< index_t, NumDTensor > > stride_Ds_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:546
index_t group_count_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:535
void * p_dev_gemm_kargs_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:539
std::vector< GemmTransKernelArg > gemm_kernel_args_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:547
std::vector< DsGridPointer > ds_grid_pointer_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:552
std::vector< CGridDesc_M_N > elementwise_c_grid_descs_m_n_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:550
AElementwiseOperation a_element_op_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:541
std::vector< void * > e_ptrs_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:553
index_t K_BATCH
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:534
std::vector< DsGridDesc_M_N > elementwise_d_grid_descs_m_n_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:551
std::vector< index_t > group_grid_size_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:548
std::size_t GetWorkspaceSizeBytes() const
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:514
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:558
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:653
float Run(const Argument &arg, void *dev_gemm_args, void *dev_gemm_workspace, const StreamConfig &stream_config=StreamConfig{})
Launch Grouped Gemm kernel.
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:575
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Launch Grouped Gemm kernel.
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:630
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:92
static constexpr auto NXdlPerWave32
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:96
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:95
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, WorkspaceDataType, ALayout, BLayout, ELayout, AElementwiseOperation, BElementwiseOperation, PassThrough, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, AK1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEShuffleBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeDataType > GridwiseGemmBase
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:112
static constexpr auto MakeDsGridPointer()
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:208
static constexpr index_t NumDTensor
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:98
void SetKBatchSize(BaseArgument *p_arg, index_t kbatch) const override
Sets the k batch size.
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:1028
static auto MakeDsGridDescriptor_M_N(const std::array< index_t, NumDTensor > &MRaws, const std::array< index_t, NumDTensor > &NRaws, const std::array< index_t, NumDTensor > &DsStride)
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:195
static constexpr index_t ClusterLengthMPerBlock
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:237
static constexpr auto I3
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:103
void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args) const override
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:967
decltype(concat_tuple(ck::Tuple< WorkspaceDataType * >{}, DsGridPointer{})) CDDataTypes
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:233
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:885
static constexpr auto MakeElementwiseInputSequence()
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:219
static constexpr auto I2
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:102
typename GridwiseGemm64::Argument GemmKernelArgument
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:266
static auto MakeArgument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > gemm_descs, AElementwiseOperation a_elementwise_op, BElementwiseOperation b_elementwise_op, CDEElementwiseOperation cde_elementwise_op)
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:890
OffsettedBlockToCTileMap< Block2ETileMapKSplit > GroupedGemmBlock2ETileMap
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:265
decltype(MakeElementwiseInputSequence()) ElementwiseInputSequence
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:235
static constexpr index_t K0PerBlock
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:105
static constexpr bool IsValidCompilationParameter()
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:825
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:157
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:931
ck::tensor_operation::element_wise::PassThrough PassThrough
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:107
typename GridwiseGemm64::CGridDesc_M_N EGridDesc_M_N
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:229
decltype(MakeDsGridDescriptor_M_N({}, {}, {})) DsGridDesc_M_N
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:230
size_t GetDeviceKernelArgSize(const BaseArgument *p_arg) const override
Gets the device kernel argument size.
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:980
static constexpr index_t ClusterLengthNPerBlock
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:239
static constexpr index_t DefaultKBatch
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:287
float WorkspaceDataType
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:108
BlockToCTileMap_M00_N0_M01Adapt< MPerBlock, NPerBlock > Block2TileMap
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:244
static void SetKBatchSize(Argument &arg, index_t kbatch)
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:1023
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_elementwise_op, BElementwiseOperation b_elementwise_op, CDEElementwiseOperation cde_elementwise_op) override
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:910
static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:160
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:158
static constexpr auto I0
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:100
std::string GetTypeString() const override
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:936
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage DeviceOp
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:93
decltype(concat_tuple(ck::Tuple< CGridDesc_M_N >{}, DsGridDesc_M_N{})) CDGridDesc_M_N
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:232
static auto MakeInvoker()
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:929
typename GridwiseGemm64::CGridDesc_M_N CGridDesc_M_N
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:228
void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &stream_config=StreamConfig{}) const override
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:1006
static constexpr auto I1
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:101
GridwiseElementwise< CDGridDesc_M_N, ck::Tuple< EGridDesc_M_N >, CDDataTypes, ck::Tuple< EDataType * >, Block2TileMap, CDEElementwiseOperation, BlockSize, MPerBlock, NPerBlock, MPerBlock/ClusterLengthMPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 0, 1 >, ElementwiseInputSequence, ck::Sequence< CDEShuffleBlockTransferScalarPerVector_NPerBlock >, I1, I1 > GridwiseElementwise
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:245
BlockToCTileMap_KSplit_M00_N0_M01Adapt< MPerBlock, NPerBlock, CGridDesc_M_N > Block2ETileMapKSplit
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:242
static constexpr index_t B2E_M01
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:264
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:993
decltype(MakeDsGridPointer()) DsGridPointer
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:231
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:831
Definition device_grouped_gemm_splitk.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129