blockwise_gemm_pipeline_xdlops_mx_selector.hpp Source File

blockwise_gemm_pipeline_xdlops_mx_selector.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_mx_selector.hpp Source File
blockwise_gemm_pipeline_xdlops_mx_selector.hpp
Go to the documentation of this file.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
8
9namespace ck {
10template <BlockGemmPipelineVersion BlkGemmPipelineVer,
11 BlockGemmPipelineScheduler BlkGemmPipeSche,
12 index_t ThreadBlockSize,
13 index_t ScaleBlockSize,
14 typename ADataType,
15 typename AScaleDataType,
16 typename BDataType,
17 typename BScaleDataType,
18 typename ComputeDataType, // TODO: remove this as in this pipeline ADataType and BDataType
19 // must be used for compute
20 typename AccDataType,
21 typename ATileDesc,
22 typename BTileDesc,
23 typename AMmaTileDesc,
24 typename BMmaTileDesc,
25 index_t ABlockTransferSrcScalarPerVector,
26 index_t BBlockTransferSrcScalarPerVector,
27 index_t MPerBlock,
28 index_t NPerBlock,
29 index_t KPerBlock,
30 index_t MPerXDL,
31 index_t NPerXDL,
32 index_t MRepeat,
33 index_t NRepeat,
34 index_t KPack>
36{
37
38 // Hardware MX GEMM pipeline
39 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
40 {
41 return BlockwiseGemmXdlops_pipeline_v1_mx<BlkGemmPipeSche,
42 ThreadBlockSize,
43 ScaleBlockSize,
44 ADataType,
45 AScaleDataType,
46 BDataType,
47 BScaleDataType,
48 ATileDesc,
49 BTileDesc,
50 AMmaTileDesc,
51 BMmaTileDesc,
52 ABlockTransferSrcScalarPerVector,
53 BBlockTransferSrcScalarPerVector,
54 MPerBlock,
55 NPerBlock,
56 KPerBlock,
57 MPerXDL,
58 NPerXDL,
59 MRepeat,
60 NRepeat,
61 KPack>{};
62 }
63 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
64 {
65 return BlockwiseGemmXdlops_pipeline_v3_mx<BlkGemmPipeSche,
66 ThreadBlockSize,
67 ScaleBlockSize,
68 ADataType,
69 AScaleDataType,
70 BDataType,
71 BScaleDataType,
72 ATileDesc,
73 BTileDesc,
74 AMmaTileDesc,
75 BMmaTileDesc,
76 ABlockTransferSrcScalarPerVector,
77 BBlockTransferSrcScalarPerVector,
78 MPerBlock,
79 NPerBlock,
80 KPerBlock,
81 MPerXDL,
82 NPerXDL,
83 MRepeat,
84 NRepeat,
85 KPack>{};
86 }
87 else
88 {
89 std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl;
90 }
91}
92
93} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
constexpr auto BlockGemmMXPipeline_Selector()
Definition blockwise_gemm_pipeline_xdlops_mx_moe_selector.hpp:36
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v1
Definition blkgemmpipe_scheduler.hpp:14
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
Definition blockwise_gemm_pipeline_xdlops_v1_mx.hpp:38
Definition blockwise_gemm_pipeline_xdlops_v3_mx.hpp:38