38template <
typename GridwiseGemm,
42 typename AElementwiseOperation,
43 typename BElementwiseOperation,
44 typename CDEElementwiseOperation,
45 typename AGridDesc_K0_M0_M1_K1,
46 typename BGridDesc_K0_N0_N1_K1,
47 typename DsGridDesc_M0_M10_M11_N0_N10_N11,
48 typename CGridDesc_M0_M10_M11_N0_N10_N11,
49 typename ComputePtrOffsetOfBatch,
50 typename Block2CTileMap,
51 bool HasMainKBlockLoop,
52 bool HasDoubleTailKBlockLoop>
54#if CK_USE_LAUNCH_BOUNDS
58 const ABDataType* __restrict__ p_a_grid,
59 const ABDataType* __restrict__ p_b_grid,
61 EDataType* __restrict__ p_e_grid,
63 const AElementwiseOperation a_element_op,
64 const BElementwiseOperation b_element_op,
65 const CDEElementwiseOperation cde_element_op,
66 const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
67 const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
68 const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
69 const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
70 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
71 const Block2CTileMap block_2_ctile_map)
73#if(defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || \
74 defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__))
76 const index_t num_blocks_per_batch =
77 __builtin_amdgcn_readfirstlane(
get_grid_size() / batch_count);
80 const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
81 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
82 const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
83 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
84 const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
85 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
87 const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
89 __shared__
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
91 DsPointer p_ds_grid_grp;
93 static constexpr index_t NumDTensor = DsGridDesc_M0_M10_M11_N0_N10_N11::Size();
96 [&](
auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
98 GridwiseGemm::Run(p_a_grid + a_batch_offset,
99 p_b_grid + b_batch_offset,
101 p_e_grid + e_batch_offset,
106 a_grid_desc_k0_m0_m1_k1,
107 b_grid_desc_k0_n0_n1_k1,
108 ds_grid_desc_m0_m10_m11_n0_n10_n11,
109 e_grid_desc_m0_m10_m11_n0_n10_n11,
122 ignore = a_grid_desc_k0_m0_m1_k1;
123 ignore = b_grid_desc_k0_n0_n1_k1;
124 ignore = ds_grid_desc_m0_m10_m11_n0_n10_n11;
125 ignore = e_grid_desc_m0_m10_m11_n0_n10_n11;
126 ignore = compute_ptr_offset_of_batch;
127 ignore = block_2_ctile_map;
132template <
typename ALayout,
138 typename AccDataType,
141 typename AElementwiseOperation,
142 typename BElementwiseOperation,
143 typename CDEElementwiseOperation,
153 typename M1N1ThreadClusterM1Xs,
154 typename M1N1ThreadClusterN1Xs,
155 typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
156 typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
157 typename ABlockTransferThreadClusterArrangeOrder,
158 typename ABlockTransferSrcAccessOrder,
159 typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
160 typename ABlockTransferSrcVectorTensorContiguousDimOrder,
161 typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
162 typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
163 typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
164 typename BBlockTransferThreadClusterArrangeOrder,
165 typename BBlockTransferSrcAccessOrder,
166 typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
167 typename BBlockTransferSrcVectorTensorContiguousDimOrder,
168 typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
169 typename CThreadTransferSrcDstAccessOrder,
170 index_t CThreadTransferSrcDstVectorDim,
171 index_t CThreadTransferDstScalarPerVector,
184 AElementwiseOperation,
185 BElementwiseOperation,
186 CDEElementwiseOperation>
205 const auto a_grid_desc_m_k = [&]() {
218 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
242 const auto b_grid_desc_k_n = [&]() {
255 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
275 template <
typename ELay>
278 const auto c_grid_desc_m_n = [&]() {
291 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
292 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
312 const std::array<index_t, NumDTensor>& NRaws,
313 const std::array<index_t, NumDTensor>& DsStride)
333 std::array<ck::index_t, NumDTensor> BatchStrideDs,
335 : BatchStrideA_(BatchStrideA),
336 BatchStrideB_(BatchStrideB),
337 BatchStrideDs_(BatchStrideDs),
338 BatchStrideE_(BatchStrideE)
344 return g_idx *
static_cast<long_index_t>(BatchStrideA_);
349 return g_idx *
static_cast<long_index_t>(BatchStrideB_);
354 std::array<long_index_t, NumDTensor> ds_offset;
356 ds_offset[i] = g_idx *
static_cast<long_index_t>(BatchStrideDs_[i]);
363 return g_idx *
static_cast<long_index_t>(BatchStrideE_);
369 std::array<ck::index_t, NumDTensor> BatchStrideDs_;
380 AElementwiseOperation,
381 BElementwiseOperation,
382 CDEElementwiseOperation,
394 M1N1ThreadClusterM1Xs,
395 M1N1ThreadClusterN1Xs,
396 ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
397 ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
398 ABlockTransferThreadClusterArrangeOrder,
399 ABlockTransferSrcAccessOrder,
400 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
401 ABlockTransferSrcVectorTensorContiguousDimOrder,
402 ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
403 BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
404 BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
405 BBlockTransferThreadClusterArrangeOrder,
406 BBlockTransferSrcAccessOrder,
407 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
408 BBlockTransferSrcVectorTensorContiguousDimOrder,
409 BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
410 CThreadTransferSrcDstAccessOrder,
411 CThreadTransferSrcDstVectorDim,
412 CThreadTransferDstScalarPerVector>;
429 const void* p_b_grid,
430 std::array<const void*, NumDTensor> p_ds_grid,
438 std::array<index_t, NumDTensor> StrideDs,
442 const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
444 AElementwiseOperation a_element_op,
445 BElementwiseOperation b_element_op,
446 CDEElementwiseOperation cde_element_op)
447 :
p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
448 p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
450 p_e_grid_{static_cast<EDataType*>(p_e_grid)},
471 p_ds_grid_(i) =
static_cast<const DDataType*
>(p_ds_grid[i]);
538 std::cout <<
"arg.a_grid_desc_k0_m0_m1_k1_{"
543 std::cout <<
"arg.b_grid_desc_k0_n0_n1_k1_{"
555 throw std::runtime_error(
556 "wrong! GridwiseGemmDlMultipleD_km_kn_mn has invalid setting");
564 auto launch_kernel = [&](
auto has_main_k_block_loop,
565 auto has_double_tail_k_block_loop) {
566 constexpr bool has_main_loop = has_main_k_block_loop.value;
567 constexpr bool has_double_loop = has_double_tail_k_block_loop.value;
574 AElementwiseOperation,
575 BElementwiseOperation,
576 CDEElementwiseOperation,
581 ComputePtrOffsetOfStridedBatch,
609 const bool has_double_tail_k_block_loop =
612 if(has_main_k_block_loop && has_double_tail_k_block_loop)
615 integral_constant<bool, true>{});
617 else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
620 integral_constant<bool, false>{});
622 else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
625 integral_constant<bool, true>{});
630 integral_constant<bool, false>{});
638 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
654 pass = pass && arg.
K_ % K1 == 0;
676 std::array<const void*, NumDTensor> p_ds,
684 std::array<ck::index_t, NumDTensor> StrideDs,
688 const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
690 AElementwiseOperation a_element_op,
691 BElementwiseOperation b_element_op,
692 CDEElementwiseOperation cde_element_op)
718 std::unique_ptr<BaseArgument>
721 const std::array<const void*, NumDTensor>& p_ds,
729 const std::array<ck::index_t, NumDTensor>& StrideDs,
733 const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
735 AElementwiseOperation a_element_op,
736 BElementwiseOperation b_element_op,
737 CDEElementwiseOperation cde_element_op)
override
739 return std::make_unique<Argument>(p_a,
763 return std::make_unique<Invoker>(
Invoker{});
769 auto str = std::stringstream();
772 str <<
"DeviceBatchedGemmMultipleD_Dl"
777 << K0PerBlock <<
", "
779 << M1PerThread <<
", "
780 << N1PerThread <<
", "
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
__global__ void kernel_gemm_dl_multiple_d(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const index_t batch_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2CTileMap block_2_ctile_map)
Definition device_batched_gemm_multiple_d_dl.hpp:57
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MNPadding
Definition gemm_specialization.hpp:17
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
bool is_xdl_supported()
Definition host_utility/device_prop.hpp:68
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
__global__ void kernel_gemm_dl_multiple_d(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map)
Definition device_gemm_multiple_d_dl.hpp:39
bool is_gfx103_supported()
Definition host_utility/device_prop.hpp:120
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_dl_multiple_d.hpp:60
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeDefaultBlock2CTileMap __host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const EGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_multiple_d.hpp:242
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeBGridDescriptor_K0_N0_N1_K1 __host__ static __device__ constexpr auto MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1)
Definition gridwise_gemm_dl_multiple_d.hpp:178
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeAGridDescriptor_K0_M0_M1_K1 __host__ static __device__ constexpr auto MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1)
Definition gridwise_gemm_dl_multiple_d.hpp:158
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::DsGridPointer decltype(MakeDsGridPointer()) DsGridPointer
Definition gridwise_gemm_dl_multiple_d.hpp:253
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CalculateGridSize __host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition gridwise_gemm_dl_multiple_d.hpp:136
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CalculateHasDoubleTailKBlockLoop __host__ static __device__ constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
Definition gridwise_gemm_dl_multiple_d.hpp:150
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CalculateHasMainKBlockLoop __host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K0)
Definition gridwise_gemm_dl_multiple_d.hpp:143
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CheckValidity __host__ static __device__ constexpr bool CheckValidity(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const EGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_multiple_d.hpp:110
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11 __host__ static __device__ constexpr auto MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(const DsGridDesc_M_N &ds_grid_desc_m_n)
Definition gridwise_gemm_dl_multiple_d.hpp:234
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11 __host__ static __device__ constexpr auto MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N_ &c_grid_desc_m_n)
Definition gridwise_gemm_dl_multiple_d.hpp:200
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition utility/integral_constant.hpp:20
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_batched_gemm_multi_d.hpp:27
Definition device_batched_gemm_multiple_d_dl.hpp:427
EDataType * p_e_grid_
Definition device_batched_gemm_multiple_d_dl.hpp:502
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_
Definition device_batched_gemm_multiple_d_dl.hpp:515
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_
Definition device_batched_gemm_multiple_d_dl.hpp:520
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_
Definition device_batched_gemm_multiple_d_dl.hpp:509
BElementwiseOperation b_element_op_
Definition device_batched_gemm_multiple_d_dl.hpp:526
CDEElementwiseOperation cde_element_op_
Definition device_batched_gemm_multiple_d_dl.hpp:527
DefaultBlock2CTileMap block_2_ctile_map_
Definition device_batched_gemm_multiple_d_dl.hpp:522
EGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_
Definition device_batched_gemm_multiple_d_dl.hpp:517
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, void *p_e_grid, index_t M, index_t N, index_t K, index_t Batch, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, index_t BatchStrideA, index_t BatchStrideB, const std::array< ck::index_t, NumDTensor > &BatchStrideDs, index_t BatchStrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_batched_gemm_multiple_d_dl.hpp:428
EGridDesc_M_N e_grid_desc_m_n_
Definition device_batched_gemm_multiple_d_dl.hpp:512
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_
Definition device_batched_gemm_multiple_d_dl.hpp:514
const ADataType * p_a_grid_
Definition device_batched_gemm_multiple_d_dl.hpp:499
DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_
Definition device_batched_gemm_multiple_d_dl.hpp:516
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_
Definition device_batched_gemm_multiple_d_dl.hpp:510
const BDataType * p_b_grid_
Definition device_batched_gemm_multiple_d_dl.hpp:500
index_t K_
Definition device_batched_gemm_multiple_d_dl.hpp:504
index_t Batch_
Definition device_batched_gemm_multiple_d_dl.hpp:507
AElementwiseOperation a_element_op_
Definition device_batched_gemm_multiple_d_dl.hpp:525
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_batched_gemm_multiple_d_dl.hpp:511
GridwiseGemm::DsGridPointer p_ds_grid_
Definition device_batched_gemm_multiple_d_dl.hpp:501
Definition device_batched_gemm_multiple_d_dl.hpp:330
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB, std::array< ck::index_t, NumDTensor > BatchStrideDs, index_t BatchStrideE)
Definition device_batched_gemm_multiple_d_dl.hpp:331
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
Definition device_batched_gemm_multiple_d_dl.hpp:347
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
Definition device_batched_gemm_multiple_d_dl.hpp:361
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
Definition device_batched_gemm_multiple_d_dl.hpp:352
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
Definition device_batched_gemm_multiple_d_dl.hpp:342
Definition device_batched_gemm_multiple_d_dl.hpp:532
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_multiple_d_dl.hpp:535
DeviceBatchedGemmMultipleD_Dl::Argument Argument
Definition device_batched_gemm_multiple_d_dl.hpp:533
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_multiple_d_dl.hpp:635
Definition device_batched_gemm_multiple_d_dl.hpp:188
static constexpr index_t NumDTensor
Definition device_batched_gemm_multiple_d_dl.hpp:190
static constexpr auto I5
Definition device_batched_gemm_multiple_d_dl.hpp:197
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{})) DefaultBlock2CTileMap
Definition device_batched_gemm_multiple_d_dl.hpp:422
GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector > GridwiseGemm
Definition device_batched_gemm_multiple_d_dl.hpp:374
static constexpr bool IsValidCompilationParameter()
Definition device_batched_gemm_multiple_d_dl.hpp:642
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_multiple_d_dl.hpp:761
static constexpr auto I3
Definition device_batched_gemm_multiple_d_dl.hpp:195
static constexpr auto I0
Definition device_batched_gemm_multiple_d_dl.hpp:192
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
Definition device_batched_gemm_multiple_d_dl.hpp:201
static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
Definition device_batched_gemm_multiple_d_dl.hpp:276
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, index_t M, index_t N, index_t K, index_t Batch, index_t StrideA, index_t StrideB, const std::array< ck::index_t, NumDTensor > &StrideDs, index_t StrideE, index_t BatchStrideA, index_t BatchStrideB, const std::array< ck::index_t, NumDTensor > &BatchStrideDs, index_t BatchStrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_batched_gemm_multiple_d_dl.hpp:719
static bool IsSupportedArgument(const Argument &arg)
Definition device_batched_gemm_multiple_d_dl.hpp:648
decltype(MakeDsGridDescriptor_M_N({}, {}, {})) DsGridDesc_M_N
Definition device_batched_gemm_multiple_d_dl.hpp:326
static auto MakeInvoker()
Definition device_batched_gemm_multiple_d_dl.hpp:715
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
Definition device_batched_gemm_multiple_d_dl.hpp:238
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_multiple_d_dl.hpp:669
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})) AGridDesc_K0_M0_M1_K1
Definition device_batched_gemm_multiple_d_dl.hpp:414
decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)) BGridDesc_K0_N_K1
Definition device_batched_gemm_multiple_d_dl.hpp:325
static constexpr auto K1Number
Definition device_batched_gemm_multiple_d_dl.hpp:199
decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)) AGridDesc_K0_M_K1
Definition device_batched_gemm_multiple_d_dl.hpp:324
decltype(MakeEGridDescriptor_M_N< ELayout >(1, 1, 1)) EGridDesc_M_N
Definition device_batched_gemm_multiple_d_dl.hpp:327
static constexpr auto I4
Definition device_batched_gemm_multiple_d_dl.hpp:196
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})) BGridDesc_K0_N0_N1_K1
Definition device_batched_gemm_multiple_d_dl.hpp:416
std::string GetTypeString() const override
Definition device_batched_gemm_multiple_d_dl.hpp:767
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{})) EGridDesc_M0_M10_M11_N0_N10_N11
Definition device_batched_gemm_multiple_d_dl.hpp:420
static constexpr auto I2
Definition device_batched_gemm_multiple_d_dl.hpp:194
static constexpr auto I1
Definition device_batched_gemm_multiple_d_dl.hpp:193
decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{})) DsGridDesc_M0_M10_M11_N0_N10_N11
Definition device_batched_gemm_multiple_d_dl.hpp:418
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t M, index_t N, index_t K, index_t Batch, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideE, index_t BatchStrideA, index_t BatchStrideB, const std::array< ck::index_t, NumDTensor > &BatchStrideDs, index_t BatchStrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_batched_gemm_multiple_d_dl.hpp:674
DeviceBatchedGemmMultipleD_Dl DeviceOp
Definition device_batched_gemm_multiple_d_dl.hpp:189
static auto MakeDsGridDescriptor_M_N(const std::array< index_t, NumDTensor > &MRaws, const std::array< index_t, NumDTensor > &NRaws, const std::array< index_t, NumDTensor > &DsStride)
Definition device_batched_gemm_multiple_d_dl.hpp:311