25template <
typename ALayout,
32 typename CShuffleDataType,
33 typename AElementwiseOperation,
34 typename BElementwiseOperation,
35 typename CElementwiseOperation,
47 typename ABlockTransferThreadClusterLengths_K0_M_K1,
48 typename ABlockTransferThreadClusterArrangeOrder,
49 typename ABlockTransferSrcAccessOrder,
53 bool ABlockLdsAddExtraM,
54 typename BBlockTransferThreadClusterLengths_K0_N_K1,
55 typename BBlockTransferThreadClusterArrangeOrder,
56 typename BBlockTransferSrcAccessOrder,
60 bool BBlockLdsAddExtraN,
61 index_t CShuffleMRepeatPerShuffle,
62 index_t CShuffleNRepeatPerShuffle,
63 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
64 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
73 AElementwiseOperation,
74 BElementwiseOperation,
75 CElementwiseOperation>
87 static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
88 static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
89 static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
90 static constexpr auto MaxVectorLoadA = K1 *
sizeof(ADataType) == 16 ?
true : false;
91 static constexpr auto MaxVectorLoadB = K1 *
sizeof(BDataType) == 16 ?
true : false;
115 const auto a_grid_desc_m_k = [&]() {
118 const auto a_grid_desc_mraw_kraw =
121 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
125 const auto a_grid_desc_mraw_kraw =
128 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
132 const auto M = a_grid_desc_m_k.GetLength(
I0);
133 const auto K = a_grid_desc_m_k.GetLength(
I1);
149 constexpr auto A_KRow = 2;
151 const auto A_KWmma = K /
WmmaK;
153 const auto M0 = M / MPerBlock;
169 const auto b_grid_desc_n_k = [&]() {
172 const auto b_grid_desc_nraw_kraw =
175 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
179 const auto b_grid_desc_nraw_kraw =
182 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
186 const auto N = b_grid_desc_n_k.GetLength(
I0);
187 const auto K = b_grid_desc_n_k.GetLength(
I1);
203 constexpr auto B_KRow = 2;
205 const auto B_KWmma = K /
WmmaK;
207 const auto N0 = N / NPerBlock;
223 const auto c_grid_desc_mraw_nraw = [&]() {
236 return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
256 AElementwiseOperation,
257 BElementwiseOperation,
258 CElementwiseOperation,
267 ABlockTransferThreadClusterLengths_K0_M_K1,
268 ABlockTransferThreadClusterArrangeOrder,
269 ABlockTransferSrcAccessOrder,
270 ABlockTransferSrcVectorDim,
271 ABlockTransferSrcScalarPerVector,
272 ABlockTransferDstScalarPerVector_K1,
276 BBlockTransferThreadClusterLengths_K0_N_K1,
277 BBlockTransferThreadClusterArrangeOrder,
278 BBlockTransferSrcAccessOrder,
279 BBlockTransferSrcVectorDim,
280 BBlockTransferSrcScalarPerVector,
281 BBlockTransferDstScalarPerVector_K1,
285 CShuffleMRepeatPerShuffle,
286 CShuffleNRepeatPerShuffle,
287 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
288 CShuffleBlockTransferScalarPerVector_NPerBlock,
297 const BDataType* p_b_grid,
307 AElementwiseOperation a_element_op,
308 BElementwiseOperation b_element_op,
309 CElementwiseOperation c_element_op)
376 throw std::runtime_error(
377 "wrong! GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting");
383 const auto K = [&]() {
394 auto launch_kernel = [&](
auto has_main_k_block_loop) {
404 AElementwiseOperation,
405 BElementwiseOperation,
406 CElementwiseOperation,
408 has_main_k_block_loop>;
441 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
458 printf(
"DeviceOp err: AccDataType");
464 printf(
"DeviceOp err: Arch");
476 if(arg.
KRaw_ % ABlockTransferSrcScalarPerVector != 0)
484 if(arg.
MRaw_ % ABlockTransferSrcScalarPerVector != 0)
497 if(arg.
KRaw_ % BBlockTransferSrcScalarPerVector != 0)
505 if(arg.
NRaw_ % BBlockTransferSrcScalarPerVector != 0)
519 if(arg.
NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
543 const BDataType* p_b,
551 AElementwiseOperation a_element_op,
552 BElementwiseOperation b_element_op,
553 CElementwiseOperation c_element_op)
583 AElementwiseOperation a_element_op,
584 BElementwiseOperation b_element_op,
585 CElementwiseOperation c_element_op)
override
587 return std::make_unique<Argument>(
static_cast<const ADataType*
>(p_a),
588 static_cast<const BDataType*
>(p_b),
589 static_cast<CDataType*
>(p_c),
606 return std::make_unique<Invoker>(
Invoker{});
612 auto str = std::stringstream();
614 std::map<LoopScheduler, std::string> LoopSchedToString{
617 std::map<PipelineVersion, std::string> PipelineVersionToString{{
PipelineVersion::v1,
"v1"},
621 str <<
"DeviceGemmWmma_CShuffle"
638 << NumPrefetch <<
", "
640 << LoopSchedToString[LoopSched] <<
", "
641 <<
"PipelineVersion: "
642 << PipelineVersionToString[PipelineVer];
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_gemm_wmma(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, CDataType *__restrict__ p_c_grid, const AGridDesc a_grid_desc, const BGridDesc b_grid_desc, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_wmma.hpp:37
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
constexpr bool is_same_v
Definition type.hpp:283
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_wmma.hpp:124
ck::GridwiseGemm_Wmma< BlockSize, ADataType, BDataType, AccDataType, CShuffleDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc, BGridDesc, CGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer >::DefaultBlock2CTileMap remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))> DefaultBlock2CTileMap
Definition gridwise_gemm_wmma.hpp:581
ck::GridwiseGemm_Wmma< BlockSize, ADataType, BDataType, AccDataType, CShuffleDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc, BGridDesc, CGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer >::CheckValidity __host__ static __device__ constexpr bool CheckValidity(const AGridDesc &a_grid_desc, const BGridDesc &b_grid_desc, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_wmma.hpp:413
ck::GridwiseGemm_Wmma< BlockSize, ADataType, BDataType, AccDataType, CShuffleDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc, BGridDesc, CGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer >::MakeDefaultBlock2CTileMap __host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n, index_t, index_t)
Definition gridwise_gemm_wmma.hpp:540
ck::GridwiseGemm_Wmma< BlockSize, ADataType, BDataType, AccDataType, CShuffleDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc, BGridDesc, CGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer >::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_gemm_wmma.hpp:578
ck::GridwiseGemm_Wmma< BlockSize, ADataType, BDataType, AccDataType, CShuffleDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc, BGridDesc, CGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer >::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock __host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_wmma.hpp:521
ck::GridwiseGemm_Wmma< BlockSize, ADataType, BDataType, AccDataType, CShuffleDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc, BGridDesc, CGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer >::CalculateHasMainKBlockLoop __host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_wmma.hpp:513
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition tensor_operation/gpu/device/tensor_layout.hpp:31
Definition tensor_operation/gpu/device/tensor_layout.hpp:26
Definition device_base.hpp:197
Definition device_gemm.hpp:22
Definition device_gemm_wmma.hpp:295
CElementwiseOperation c_element_op_
Definition device_gemm_wmma.hpp:357
AElementwiseOperation a_element_op_
Definition device_gemm_wmma.hpp:355
AGridDesc a_grid_desc_
Definition device_gemm_wmma.hpp:347
index_t M01_
Definition device_gemm_wmma.hpp:353
index_t NRaw_
Definition device_gemm_wmma.hpp:360
index_t N01_
Definition device_gemm_wmma.hpp:354
const ADataType * p_a_grid_
Definition device_gemm_wmma.hpp:344
BGridDesc b_grid_desc_k0_n_k1_
Definition device_gemm_wmma.hpp:348
GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_gemm_wmma.hpp:352
index_t MRaw_
Definition device_gemm_wmma.hpp:359
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t M01, index_t N01, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_wmma.hpp:296
GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock
Definition device_gemm_wmma.hpp:351
index_t KRaw_
Definition device_gemm_wmma.hpp:361
BElementwiseOperation b_element_op_
Definition device_gemm_wmma.hpp:356
CGridDesc_M_N c_grid_desc_m_n_
Definition device_gemm_wmma.hpp:349
CDataType * p_c_grid_
Definition device_gemm_wmma.hpp:346
const BDataType * p_b_grid_
Definition device_gemm_wmma.hpp:345
Definition device_gemm_wmma.hpp:366
DeviceGemmWmma_CShuffle::Argument Argument
Definition device_gemm_wmma.hpp:367
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_wmma.hpp:438
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_wmma.hpp:369
Definition device_gemm_wmma.hpp:76
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
Definition device_gemm_wmma.hpp:221
static constexpr auto K1Number
Definition device_gemm_wmma.hpp:85
static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_gemm_wmma.hpp:113
static auto MakeInvoker()
Definition device_gemm_wmma.hpp:571
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_wmma.hpp:537
static constexpr auto AEnableLds_manu
Definition device_gemm_wmma.hpp:104
static constexpr auto AEnableLds
Definition device_gemm_wmma.hpp:107
static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_gemm_wmma.hpp:167
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_wmma.hpp:451
static constexpr auto BEnableLds
Definition device_gemm_wmma.hpp:108
static constexpr auto I3
Definition device_gemm_wmma.hpp:80
decltype(MakeCGridDescriptor_M_N(1, 1, 1)) CGridDesc_M_N
Definition device_gemm_wmma.hpp:242
static constexpr auto MaxVectorLoadA
Definition device_gemm_wmma.hpp:90
static constexpr auto I1
Definition device_gemm_wmma.hpp:78
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_gemm_wmma.hpp:574
static constexpr auto AEnableLds_auto
Definition device_gemm_wmma.hpp:93
std::string GetTypeString() const override
Definition device_gemm_wmma.hpp:610
static constexpr auto I6
Definition device_gemm_wmma.hpp:83
static constexpr auto I0
Definition device_gemm_wmma.hpp:77
static constexpr auto I2
Definition device_gemm_wmma.hpp:79
static constexpr auto I5
Definition device_gemm_wmma.hpp:82
decltype(MakeBGridDescriptor(1, 1, 1)) BGridDesc
Definition device_gemm_wmma.hpp:241
static constexpr auto MWaves
Definition device_gemm_wmma.hpp:87
static constexpr auto NWaves
Definition device_gemm_wmma.hpp:88
static constexpr auto matrix_padder
Definition device_gemm_wmma.hpp:110
static constexpr auto WmmaK
Definition device_gemm_wmma.hpp:89
decltype(MakeAGridDescriptor(1, 1, 1)) AGridDesc
Definition device_gemm_wmma.hpp:240
static constexpr auto BEnableLds_auto
Definition device_gemm_wmma.hpp:97
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_wmma.hpp:445
GridwiseGemm_Wmma< BlockSize, ADataType, BDataType, AccDataType, CShuffleDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc, BGridDesc, CGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer > GridwiseGemm
Definition device_gemm_wmma.hpp:245
static constexpr auto MaxVectorLoadB
Definition device_gemm_wmma.hpp:91
static constexpr auto BEnableLds_manu
Definition device_gemm_wmma.hpp:105
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_wmma.hpp:542
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_wmma.hpp:604
static constexpr auto I4
Definition device_gemm_wmma.hpp:81
Definition matrix_padder.hpp:180