block_gemm_asmem_bsmem_creg_v1_default_policy.hpp Source File

block_gemm_asmem_bsmem_creg_v1_default_policy.hpp Source File

Composable Kernel: block_gemm_asmem_bsmem_creg_v1_default_policy.hpp Source File
block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10namespace ck_tile {
11
12// Default policy for BlockGemmASmemBSmemCRegV1
13// The default policy class itself should not be templated; put the template on its member functions instead.
15{
16 template <typename Problem>
18 {
19#if defined(__gfx950__)
20 constexpr bool is_a_load_tr = std::is_same_v<remove_cvref_t<typename Problem::ALayout>,
22 constexpr bool is_b_load_tr = std::is_same_v<remove_cvref_t<typename Problem::BLayout>,
24#else
25 constexpr bool is_a_load_tr = false;
26 constexpr bool is_b_load_tr = false;
27#endif
28 constexpr auto wg_attr_num_access = (is_a_load_tr || is_b_load_tr)
31
32 if constexpr(((std::is_same_v<typename Problem::ADataType, half_t> &&
33 std::is_same_v<typename Problem::BDataType, half_t>) ||
34 (std::is_same_v<typename Problem::ADataType, bf16_t> &&
35 std::is_same_v<typename Problem::BDataType, bf16_t>)) &&
36 std::is_same_v<typename Problem::CDataType, float>)
37 {
38 if constexpr(get_warp_size() == 64)
39 {
40 using WG = WarpGemmDispatcher<typename Problem::ADataType,
41 typename Problem::BDataType,
42 typename Problem::CDataType,
43 32,
44 32,
45 16,
46 true,
47 false,
48 false,
49 wg_attr_num_access>;
50 return make_tuple(WG{}, 4, 1);
51 }
52 else
53 {
54 using WG = WarpGemmDispatcher<typename Problem::ADataType,
55 typename Problem::BDataType,
56 typename Problem::CDataType,
57 16,
58 16,
59 16,
60 true,
61 false,
62 false,
63 wg_attr_num_access>;
64 return make_tuple(WG{}, 4, 1);
65 }
66 }
67 else
68 {
69 static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
70 }
71 }
72};
73
74} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
@ Single
Definition warp_gemm_attribute_mfma.hpp:14
@ Double
Definition warp_gemm_attribute_mfma.hpp:15
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_gemm_asmem_bsmem_creg_v1_default_policy.hpp:15
static CK_TILE_HOST_DEVICE constexpr auto GetWarpGemmMWarpNWarp()
Definition block_gemm_asmem_bsmem_creg_v1_default_policy.hpp:17
Definition tile/ops/common/tensor_layout.hpp:22
Definition tile/ops/common/tensor_layout.hpp:17