blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp Source File

blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp Source File
blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
9
10namespace ck {
// Compile-time selector for the B-preshuffle MX MoE blockwise GEMM pipeline.
//
// The body dispatches with `if constexpr` on two template parameters:
//   * BlkGemmPipelineVer == v1, GUFusion == true  -> returns nullptr
//   * BlkGemmPipelineVer == v1, GUFusion == false -> a `{}`-initialized pipeline object
//   * BlkGemmPipelineVer == v3, GUFusion == true  -> a `{}`-initialized pipeline object
//   * BlkGemmPipelineVer == v3, GUFusion == false -> a `{}`-initialized pipeline object
//   * any other version                           -> prints a message to std::cerr
//
// Of the template arguments visible in this listing, every one except
// BlkGemmPipelineVer, ComputeDataType, AccDataType and GUFusion is passed on to
// the selected pipeline template.
//
// NOTE(review): this text is a rendered listing; listing lines 37, 49, 77 and 102
// are absent, so the function signature and the names of the three pipeline
// templates are not visible here. The page's cross-reference index mentions the
// headers ..._mx_moe_v1.hpp, ..._mx_moe_v3.hpp and ..._mx_moe_gufusion_v3.hpp;
// presumably those define the three pipelines - confirm against the real header.
11template <BlockGemmPipelineVersion BlkGemmPipelineVer,
12 BlockGemmPipelineScheduler BlkGemmPipeSche,
13 index_t ThreadBlockSize,
14 index_t ScaleBlockSize,
15 typename ADataType,
16 typename AScaleDataType,
17 typename BDataType,
18 typename BScaleDataType,
19 typename ComputeDataType, // TODO: remove this as in this pipeline ADataType and BDataType
20 // must be used for compute
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MPerXDL,
32 index_t NPerXDL,
33 index_t MRepeat,
34 index_t NRepeat,
35 index_t KPack,
36 bool GUFusion = false>
// NOTE(review): listing line 37 (the function signature) is missing from this
// rendering. The cross-reference index on this page gives it as
// `constexpr auto BlockGemmMXBPreshufflePipeline_Selector()`.
38{
39
40 // Hardware MX GEMM pipeline
41 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
42 {
43 if constexpr(GUFusion)
44 {
// v1 combined with GUFusion yields nullptr instead of a pipeline object.
// NOTE(review): presumably this combination is unsupported - confirm with callers.
45 return nullptr;
46 }
47 else
48 {
// NOTE(review): listing line 49 is missing here; presumably it is the
// `return <v1 pipeline template><` that opens this argument list.
50 BlkGemmPipeSche,
51 ThreadBlockSize,
52 ScaleBlockSize,
53 ADataType,
54 AScaleDataType,
55 BDataType,
56 BScaleDataType,
57 ATileDesc,
58 BTileDesc,
59 AMmaTileDesc,
60 BMmaTileDesc,
61 ABlockTransferSrcScalarPerVector,
62 BBlockTransferSrcScalarPerVector,
63 MPerBlock,
64 NPerBlock,
65 KPerBlock,
66 MPerXDL,
67 NPerXDL,
68 MRepeat,
69 NRepeat,
70 KPack>{};
71 }
72 }
73 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
74 {
75 if constexpr(GUFusion)
76 {
// NOTE(review): listing line 77 is missing here; presumably it is the
// `return <GUFusion v3 pipeline template><` that opens this argument list.
78 BlkGemmPipeSche,
79 ThreadBlockSize,
80 ScaleBlockSize,
81 ADataType,
82 AScaleDataType,
83 BDataType,
84 BScaleDataType,
85 ATileDesc,
86 BTileDesc,
87 AMmaTileDesc,
88 BMmaTileDesc,
89 ABlockTransferSrcScalarPerVector,
90 BBlockTransferSrcScalarPerVector,
91 MPerBlock,
92 NPerBlock,
93 KPerBlock,
94 MPerXDL,
95 NPerXDL,
96 MRepeat,
97 NRepeat,
98 KPack>{};
99 }
100 else
101 {
// NOTE(review): listing line 102 is missing here; presumably it is the
// `return <v3 pipeline template><` that opens this argument list.
103 BlkGemmPipeSche,
104 ThreadBlockSize,
105 ScaleBlockSize,
106 ADataType,
107 AScaleDataType,
108 BDataType,
109 BScaleDataType,
110 ATileDesc,
111 BTileDesc,
112 AMmaTileDesc,
113 BMmaTileDesc,
114 ABlockTransferSrcScalarPerVector,
115 BBlockTransferSrcScalarPerVector,
116 MPerBlock,
117 NPerBlock,
118 KPerBlock,
119 MPerXDL,
120 NPerXDL,
121 MRepeat,
122 NRepeat,
123 KPack>{};
124 }
125 }
126 else
127 {
// Any pipeline version other than v1 or v3 lands here: the only visible action
// is a message written to std::cerr, and no return statement is visible.
// NOTE(review): no include for std::cerr is visible in this listing, and a
// compile-time check (e.g. static_assert) may be a better fit for a function
// that the index describes as constexpr - confirm before changing.
128 std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl;
129 }
130}
131
132} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
constexpr auto BlockGemmMXBPreshufflePipeline_Selector()
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp:37
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v1
Definition blkgemmpipe_scheduler.hpp:14
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp:38
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp:38
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp:38