33template <index_t NumDTensor, index_t NumRTensor>
34struct ComputePtrOffsetOfStridedBatch
36 ComputePtrOffsetOfStridedBatch() =
default;
38 ComputePtrOffsetOfStridedBatch(
index_t BatchStrideA,
40 Array<ck::index_t, NumDTensor> BatchStrideDs,
42 Array<ck::index_t, NumRTensor> BatchStrideRs)
43 : BatchStrideA_(BatchStrideA),
44 BatchStrideB_(BatchStrideB),
45 BatchStrideDs_(BatchStrideDs),
46 BatchStrideE_(BatchStrideE),
47 BatchStrideRs_(BatchStrideRs)
61 __host__ __device__
constexpr auto GetDsPtrOffset(
index_t g_idx)
const
63 Array<long_index_t, NumDTensor> ds_offset;
64 static_for<0, NumDTensor, 1>{}(
65 [&](
auto i) { ds_offset(i) = g_idx *
static_cast<long_index_t>(BatchStrideDs_[i]); });
74 __host__ __device__
constexpr auto GetRsPtrOffset(
index_t g_idx)
const
76 Array<long_index_t, NumRTensor> rs_offset;
77 static_for<0, NumRTensor, 1>{}(
78 [&](
auto i) { rs_offset(i) = g_idx *
static_cast<long_index_t>(BatchStrideRs_[i]); });
84 Array<ck::index_t, NumDTensor> BatchStrideDs_;
86 Array<ck::index_t, NumRTensor> BatchStrideRs_;
114template <
typename GridwiseGemm,
119 typename AElementwiseOperation,
120 typename BElementwiseOperation,
121 typename CDEElementwiseOperation,
122 typename QsElementwiseOperation,
123 typename RsElementwiseOperation,
124 typename AGridDesc_AK0_M_AK1,
125 typename BGridDesc_BK0_N_BK1,
126 typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
127 typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
128 typename RsGridDescriptor_MBlock_MPerBlock,
129 typename Block2ETileMap,
130 typename ComputePtrOffsetOfBatch,
131 bool HasMainKBlockLoop>
133#if CK_USE_LAUNCH_BOUNDS
136 kernel_batch_gemm_multiple_d_xdl_cshuffle(
137 const ABDataType* __restrict__ p_a_grid,
138 const ABDataType* __restrict__ p_b_grid,
140 EDataType* __restrict__ p_e_grid,
142 const AElementwiseOperation a_element_op,
143 const BElementwiseOperation b_element_op,
144 const CDEElementwiseOperation cde_element_op,
145 const QsElementwiseOperation qs_element_op,
146 const RsElementwiseOperation rs_element_op,
148 const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
149 const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
150 const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
151 ds_grid_desc_mblock_mperblock_nblock_nperblock,
152 const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
153 e_grid_desc_mblock_mperblock_nblock_nperblock_,
154 const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock,
155 const Block2ETileMap block_2_ctile_map,
156 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
158#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
159 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
161 const index_t num_blocks_per_batch =
162 __builtin_amdgcn_readfirstlane(
get_grid_size() / batch_count);
164 __builtin_amdgcn_readfirstlane(
get_block_1d_id() / num_blocks_per_batch);
167 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
169 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
171 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
173 const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
174 const auto rs_batch_offset = compute_ptr_offset_of_batch.GetRsPtrOffset(g_idx);
176 __shared__
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
178 DsPointer p_ds_grid_grp;
180 static constexpr index_t NumDTensor =
181 DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
183 static_for<0, NumDTensor, 1>{}(
184 [&](
auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
186 RsPointer p_rs_grid_grp;
188 static constexpr index_t NumRTensor = RsGridDescriptor_MBlock_MPerBlock::Size();
190 static_for<0, NumRTensor, 1>{}(
191 [&](
auto i) { p_rs_grid_grp(i) = p_rs_grid[i] + rs_batch_offset[i]; });
193 GridwiseGemm::template Run<HasMainKBlockLoop>(
194 p_a_grid + a_batch_offset,
195 p_b_grid + b_batch_offset,
197 p_e_grid + e_batch_offset,
207 ds_grid_desc_mblock_mperblock_nblock_nperblock,
208 e_grid_desc_mblock_mperblock_nblock_nperblock_,
209 rs_grid_desc_mblock_mperblock,
219 ignore = a_grid_desc_k0_m_k1;
220 ignore = b_grid_desc_k0_n_k1;
221 ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
222 ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
223 ignore = rs_grid_desc_mblock_mperblock;
229 ignore = compute_ptr_offset_of_batch;
230 ignore = block_2_ctile_map;
243 typename AccDataType,
244 typename CShuffleDataType,
247 typename ReduceAccDataType,
249 typename AElementwiseOperation,
250 typename BElementwiseOperation,
251 typename CDEElementwiseOperation,
252 typename QsElementwiseOperation,
253 typename RsElementwiseOperation,
254 typename ThreadReduceOperations,
255 typename RsGlobalMemoryDataOperation,
269 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
270 typename ABlockTransferThreadClusterArrangeOrder,
271 typename ABlockTransferSrcAccessOrder,
272 index_t ABlockTransferSrcVectorDim,
273 index_t ABlockTransferSrcScalarPerVector,
274 index_t ABlockTransferDstScalarPerVector_AK1,
276 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
277 typename BBlockTransferThreadClusterArrangeOrder,
278 typename BBlockTransferSrcAccessOrder,
279 index_t BBlockTransferSrcVectorDim,
280 index_t BBlockTransferSrcScalarPerVector,
281 index_t BBlockTransferDstScalarPerVector_BK1,
283 index_t CShuffleMXdlPerWavePerShuffle,
284 index_t CShuffleNXdlPerWavePerShuffle,
285 typename CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
286 index_t CDEBlockTransferScalarPerVector_NPerBlock,
287 index_t RThreadTransferDstScalarPerVector_MPerBlock,
300 AElementwiseOperation,
301 BElementwiseOperation,
302 CDEElementwiseOperation,
303 RsElementwiseOperation,
304 QsElementwiseOperation>
324 template <
typename ALay>
327 const auto in_gemmmraw_gemmkraw_desc =
328 conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
330 const auto in_gemmm_gemmk_desc =
333 return in_gemmm_gemmk_desc;
336 template <
typename BLay>
339 const auto wei_gemmnraw_gemmkraw_desc =
340 conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
342 const auto wei_gemmn_gemmk_desc =
343 matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
345 return wei_gemmn_gemmk_desc;
348 template <
typename ELay>
351 const auto out_gemmmraw_gemmnraw_desc =
352 conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
354 const auto out_gemmm_gemmn_desc =
355 matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
357 return out_gemmm_gemmn_desc;
360 template <
typename Descriptor>
364 const auto MPad = M - MRaw;
384 template <
typename RLay,
385 typename std::enable_if<is_same_v<RLay, tensor_layout::convolution::GNW> ||
391 const std::array<index_t, NDimSpatial + 2>& )
393 const index_t N = r_g_n_wos_lengths[1];
397 r_g_n_wos_lengths.begin() + 2, NDimSpatial, 1, std::multiplies<>());
// Builds the raw (unpadded) 1-D grid descriptor for an R tensor given explicit
// strides; enabled for G_NW (further enable_if alternatives are not visible here).
// Reads N from r_g_n_wos_lengths[1] and the product of the NDimSpatial spatial
// lengths starting at index 2 (after the g and n entries, per the r_g_n_wos naming).
404 template <
typename RLay,
405 typename std::enable_if<is_same_v<RLay, tensor_layout::convolution::G_NW> ||
413 const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_strides)
415 const index_t N = r_g_n_wos_lengths[1];
// NOTE(review): r_g_n_wos_strides is declared as std::array<index_t, NDimSpatial + 2>,
// so its valid indices are 0 .. NDimSpatial + 1; indexing [NDimSpatial + 2] reads one
// element past the end (undefined behavior -- std::array::operator[] is unchecked).
// The innermost spatial stride is presumably r_g_n_wos_strides[NDimSpatial + 1];
// confirm against the R layout and fix.
417 const index_t WoStride = r_g_n_wos_strides[NDimSpatial + 2];
421 r_g_n_wos_lengths.begin() + 2, NDimSpatial, 1, std::multiplies<>());
423 const auto r_grid_desc_mraw =
439 template <index_t NXdlPerWave_>
448 AElementwiseOperation,
449 BElementwiseOperation,
450 CDEElementwiseOperation,
451 QsElementwiseOperation,
452 RsElementwiseOperation,
453 ThreadReduceOperations,
455 RsGlobalMemoryDataOperation,
460 NumGemmKPrefetchStage,
471 ABlockTransferThreadClusterLengths_AK0_M_AK1,
472 ABlockTransferThreadClusterArrangeOrder,
473 ABlockTransferSrcAccessOrder,
474 ABlockTransferSrcVectorDim,
475 ABlockTransferSrcScalarPerVector,
476 ABlockTransferDstScalarPerVector_AK1,
479 BBlockTransferThreadClusterLengths_BK0_N_BK1,
480 BBlockTransferThreadClusterArrangeOrder,
481 BBlockTransferSrcAccessOrder,
482 BBlockTransferSrcVectorDim,
483 BBlockTransferSrcScalarPerVector,
484 BBlockTransferDstScalarPerVector_BK1,
487 CShuffleMXdlPerWavePerShuffle,
488 CShuffleNXdlPerWavePerShuffle,
489 CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
490 CDEBlockTransferScalarPerVector_NPerBlock,
491 RThreadTransferDstScalarPerVector_MPerBlock,
510 const std::array<const void*, NumDTensor>& p_ds,
512 std::array<void*, NumRTensor> p_rs,
513 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
514 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
515 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
516 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
517 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>&
518 ds_g_n_k_wos_lengths,
519 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>&
520 ds_g_n_k_wos_strides,
521 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
522 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
523 const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_lengths,
524 const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_strides,
525 const std::array<index_t, NDimSpatial>& conv_filter_strides,
526 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
527 const std::array<index_t, NDimSpatial>& input_left_pads,
528 const std::array<index_t, NDimSpatial>& input_right_pads,
529 const AElementwiseOperation& a_element_op,
530 const BElementwiseOperation& b_element_op,
531 const CDEElementwiseOperation& cde_element_op,
532 const QsElementwiseOperation& qs_element_op,
533 const RsElementwiseOperation& rs_element_op)
534 :
p_a_grid_{static_cast<const ADataType*>(p_a)},
535 p_b_grid_{static_cast<const BDataType*>(p_b)},
546 conv_filter_dilations,
592 p_ds_grid_(i) =
static_cast<const DDataType*
>(p_ds[i]);
601 ds_g_n_k_wos_lengths[i],
602 ds_g_n_k_wos_strides[i],
604 conv_filter_dilations,
618 p_rs_grid_(i) =
static_cast<RDataType*
>(p_rs[i]);
628 [&](
auto i) { std::cout <<
"Ds[M, N]: " <<
ds_grid_desc_m_n_[i] << std::endl; });
685 template <
typename Gr
idwiseGemm>
694 throw std::runtime_error(
695 "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting");
699 typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
701 ds_grid_desc_mblock_mperblock_nblock_nperblock = {};
705 rs_grid_desc_mblock_mperblock = {};
707 auto e_grid_desc_mblock_mperblock_nblock_nperblock =
713 ds_grid_desc_mblock_mperblock_nblock_nperblock(i) =
719 static_for<0, NumRTensor, 1>{}([&](
auto i) {
720 rs_grid_desc_mblock_mperblock(i) =
732 constexpr bool has_main_loop = has_main_k_block_loop.value;
734 const auto kernel = kernel_batch_gemm_multiple_d_xdl_cshuffle<
737 typename GridwiseGemm::DsGridPointer,
739 typename GridwiseGemm::RsGridPointer,
740 AElementwiseOperation,
741 BElementwiseOperation,
742 CDEElementwiseOperation,
743 QsElementwiseOperation,
744 RsElementwiseOperation,
748 typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
750 typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
752 typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock,
755 ComputePtrOffsetOfStridedBatch<NumDTensor, NumRTensor>,
776 ds_grid_desc_mblock_mperblock_nblock_nperblock,
777 e_grid_desc_mblock_mperblock_nblock_nperblock,
778 rs_grid_desc_mblock_mperblock,
783 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
798 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
839 if constexpr(ConvForwardSpecialization ==
843 for(
index_t i = 0; i < NDimSpatial; ++i)
856 else if constexpr(ConvForwardSpecialization ==
860 for(
index_t i = 0; i < NDimSpatial; ++i)
883 if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
904 if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
927 if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
952 if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
1006 const std::array<const void*, NumDTensor>& p_ds,
1008 std::array<void*, NumRTensor> p_rs,
1009 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1010 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1011 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1012 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1013 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>& ds_g_n_k_wos_lengths,
1014 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>& ds_g_n_k_wos_strides,
1015 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1016 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1017 const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_lengths,
1018 const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_strides,
1019 const std::array<index_t, NDimSpatial>& conv_filter_strides,
1020 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
1021 const std::array<index_t, NDimSpatial>& input_left_pads,
1022 const std::array<index_t, NDimSpatial>& input_right_pads,
1023 const AElementwiseOperation& a_element_op,
1024 const BElementwiseOperation& b_element_op,
1025 const CDEElementwiseOperation& cde_element_op,
1026 const QsElementwiseOperation& qs_element_op,
1027 const RsElementwiseOperation& rs_element_op)
1034 a_g_n_c_wis_lengths,
1035 a_g_n_c_wis_strides,
1038 ds_g_n_k_wos_lengths,
1039 ds_g_n_k_wos_strides,
1040 e_g_n_k_wos_lengths,
1041 e_g_n_k_wos_strides,
1044 conv_filter_strides,
1045 conv_filter_dilations,
1060 const std::array<const void*, NumDTensor>& p_ds,
1062 std::array<void*, NumRTensor> p_rs,
1063 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1064 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1065 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1066 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1067 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>& ds_g_n_k_wos_lengths,
1068 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>& ds_g_n_k_wos_strides,
1069 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1070 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1071 const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_lengths,
1072 const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_strides,
1073 const std::array<index_t, NDimSpatial>& conv_filter_strides,
1074 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
1075 const std::array<index_t, NDimSpatial>& input_left_pads,
1076 const std::array<index_t, NDimSpatial>& input_right_pads,
1077 const AElementwiseOperation& a_element_op,
1078 const BElementwiseOperation& b_element_op,
1079 const CDEElementwiseOperation& cde_element_op,
1080 const QsElementwiseOperation& qs_element_op,
1081 const RsElementwiseOperation& rs_element_op)
override
1083 return std::make_unique<Argument>(p_a,
1088 a_g_n_c_wis_lengths,
1089 a_g_n_c_wis_strides,
1092 ds_g_n_k_wos_lengths,
1093 ds_g_n_k_wos_strides,
1094 e_g_n_k_wos_lengths,
1095 e_g_n_k_wos_strides,
1098 conv_filter_strides,
1099 conv_filter_dilations,
1111 return std::make_unique<Invoker>(
Invoker{});
// Builds a human-readable identifier for this device op: a name prefix followed by
// its tuning parameters (block size, per-block tile sizes, XDL waves, vector widths,
// CShuffle settings) streamed into a std::stringstream.
// NOTE(review): the prefix below is "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle",
// but this op is defined in device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp;
// looks like a copy-paste from the MultipleABD op -- confirm whether intentional, since
// profilers/tests may key on this string.
1116 auto str = std::stringstream();
1119 str <<
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
1121 << BlockSize <<
", "
1122 << MPerBlock <<
", "
1123 << NPerBlock <<
", "
1124 << KPerBlock <<
", "
1128 << MXdlPerWave <<
", "
1129 << NXdlPerWave <<
", "
1130 << ABlockTransferSrcScalarPerVector <<
", "
1131 << BBlockTransferSrcScalarPerVector <<
", "
1132 << CShuffleMXdlPerWavePerShuffle <<
", "
1133 << CShuffleNXdlPerWavePerShuffle
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition tensor_operation/gpu/device/tensor_layout.hpp:42
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
ConvolutionForwardSpecialization
Definition convolution_forward_specialization.hpp:15
@ Filter1x1Stride1Pad0
Definition convolution_forward_specialization.hpp:18
@ Filter1x1Pad0
Definition convolution_forward_specialization.hpp:17
std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization &s)
Definition convolution_forward_specialization.hpp:24
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
bool is_lds_direct_load_supported()
Definition host_utility/device_prop.hpp:101
__device__ index_t get_grid_size()
Definition get_id.hpp:49
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition amd_wave_read_first_lane.hpp:100
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
auto accumulate_n(ForwardIterator first, Size count, T init, BinaryOperation op) -> decltype(std::accumulate(first, std::next(first, count), init, op))
Definition library/utility/numeric.hpp:11
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:74
ck::GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation, ThreadReduceOperations, InMemoryDataOperationEnum::Set, RsGlobalMemoryDataOperation, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, RGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched >::RsGridPointer decltype(MakeTsGridPointer< RsDataType, false >()) RsGridPointer
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:317
ck::GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation, ThreadReduceOperations, InMemoryDataOperationEnum::Set, RsGlobalMemoryDataOperation, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, RGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched >::MakeRGridDescriptor_MBlock_MPerBlock __host__ static __device__ constexpr auto MakeRGridDescriptor_MBlock_MPerBlock(const RGridDesc_M &r_grid_desc_m)
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:279
ck::GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation, ThreadReduceOperations, InMemoryDataOperationEnum::Set, RsGlobalMemoryDataOperation, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, RGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched >::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock __host__ static __device__ constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N &e_grid_desc_m_n)
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:260
ck::GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation, ThreadReduceOperations, InMemoryDataOperationEnum::Set, RsGlobalMemoryDataOperation, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, RGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched >::RGridDescriptor_MBlock_MPerBlock remove_cvref_t< decltype(MakeRGridDescriptor_MBlock_MPerBlock(RGridDesc_M{}))> RGridDescriptor_MBlock_MPerBlock
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:310
ck::GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation, ThreadReduceOperations, InMemoryDataOperationEnum::Set, RsGlobalMemoryDataOperation, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, RGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched >::DefaultBlock2ETileMap remove_cvref_t< decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> DefaultBlock2ETileMap
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:313
ck::GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation, ThreadReduceOperations, InMemoryDataOperationEnum::Set, RsGlobalMemoryDataOperation, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, RGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched >::MakeDefaultAGridDescriptor_AK0_M_AK1 __host__ static __device__ constexpr auto MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k)
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:174
ck::GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation, ThreadReduceOperations, InMemoryDataOperationEnum::Set, RsGlobalMemoryDataOperation, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, RGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched >::DsGridPointer decltype(MakeTsGridPointer< DsDataType, true >()) DsGridPointer
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:316
ck::GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation, ThreadReduceOperations, InMemoryDataOperationEnum::Set, RsGlobalMemoryDataOperation, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, RGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched >::CheckValidity __host__ static __device__ constexpr bool CheckValidity(const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const EGridDesc_M_N &e_grid_desc_m_n, const RGridDesc_M &r_grid_desc_m, const Block2ETileMap &block_2_etile_map)
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:208
ck::GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation, ThreadReduceOperations, InMemoryDataOperationEnum::Set, RsGlobalMemoryDataOperation, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, RGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched >::MakeDefaultBGridDescriptor_BK0_N_BK1 __host__ static __device__ constexpr auto MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k)
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:190
Definition multi_index_transform.hpp:196
Definition multi_index_transform.hpp:284
Definition utility/sequence.hpp:43
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:507
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:651
void Print() const
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:623
const ADataType * p_a_grid_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:634
Argument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, std::array< void *, NumRTensor > p_rs, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+2 > &r_g_n_wos_lengths, const std::array< index_t, NDimSpatial+2 > &r_g_n_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op, const QsElementwiseOperation &qs_element_op, const RsElementwiseOperation &rs_element_op)
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:508
std::array< index_t, NDimSpatial+3 > e_g_n_k_wos_strides_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:673
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_strides_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:669
std::array< index_t, NDimSpatial > conv_filter_dilations_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:675
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_strides_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:667
std::array< index_t, NDimSpatial > input_left_pads_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:676
GridwiseGemm64::DsGridPointer p_ds_grid_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:636
BGridDesc_N_K b_grid_desc_n_k_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:644
std::array< index_t, NDimSpatial > input_right_pads_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:677
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_lengths_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:668
RGridDesc_M r_grid_desc_m_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:647
ComputePtrOffsetOfStridedBatch< NumDTensor, NumRTensor > compute_ptr_offset_of_batch_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:656
QsElementwiseOperation qs_element_op_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:662
std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_lengths_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:670
EGridDesc_M_N ds_grid_desc_m_n_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:645
EGridDesc_M_N e_grid_desc_m_n_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:646
EDataType * p_e_grid_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:637
CDEElementwiseOperation cde_element_op_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:661
std::array< index_t, NDimSpatial+3 > e_g_n_k_wos_lengths_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:672
ConvToGemmFwdTransformer conv_to_gemm_transformer_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:640
std::array< index_t, NDimSpatial > conv_filter_strides_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:674
GridwiseGemm64::RsGridPointer p_rs_grid_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:638
BElementwiseOperation b_element_op_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:660
Block2ETileMap block_2_etile_map_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:654
AElementwiseOperation a_element_op_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:659
const BDataType * p_b_grid_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:635
AGridDesc_M_K a_grid_desc_m_k_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:643
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_lengths_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:666
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:650
RsElementwiseOperation rs_element_op_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:663
std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_strides_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:671
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:682
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:686
DeviceOp::Argument Argument
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:683
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:795
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:305
DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle DeviceOp
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:306
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1( AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:496
static auto MakeRGridDescriptor_M(const std::array< index_t, NDimSpatial+2 > &r_g_n_wos_lengths, const std::array< index_t, NDimSpatial+2 > &)
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:390
GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation, ThreadReduceOperations, InMemoryDataOperationEnum::Set, RsGlobalMemoryDataOperation, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, RGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched > GridwiseGemmBase
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:440
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:998
static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:325
static constexpr auto I1
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:315
TransformConvFwdToGemm< NDimSpatial, ConvForwardSpecialization > ConvToGemmFwdTransformer
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:319
static auto MakeRGridDescriptor_M(const std::array< index_t, NDimSpatial+2 > &r_g_n_wos_lengths, const std::array< index_t, NDimSpatial+2 > &r_g_n_wos_strides)
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:412
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:493
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1( BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:499
static constexpr auto I3
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:317
remove_cvref_t< decltype(MakeEGridDescriptor_M_N< DELayout >(dummy_conv_to_gemm_transformer))> EGridDesc_M_N
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:434
static constexpr index_t NumDTensor
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:311
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, std::array< void *, NumRTensor > p_rs, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+2 > &r_g_n_wos_lengths, const std::array< index_t, NDimSpatial+2 > &r_g_n_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op, const QsElementwiseOperation &qs_element_op, const RsElementwiseOperation &rs_element_op) override
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:1057
static constexpr index_t NumRTensor
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:312
static constexpr auto NXdlPerWave32
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:309
static constexpr ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:429
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:802
static constexpr auto I2
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:316
static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:337
remove_cvref_t< decltype(MakeBGridDescriptor_N_K< BLayout >(dummy_conv_to_gemm_transformer))> BGridDesc_N_K
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:432
static constexpr auto matrix_padder
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:321
remove_cvref_t< decltype(MakeAGridDescriptor_M_K< ALayout >(dummy_conv_to_gemm_transformer))> AGridDesc_M_K
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:430
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:308
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:494
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:349
static auto MakeInvoker()
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:1055
remove_cvref_t< decltype(MakeRGridDescriptor_M< RLayout >({}, {}))> RGridDesc_M
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:436
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:1109
static constexpr auto I0
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:314
typename GridwiseGemm64::DefaultBlock2ETileMap Block2ETileMap
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:503
static auto MakeArgument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, std::array< void *, NumRTensor > p_rs, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+2 > &r_g_n_wos_lengths, const std::array< index_t, NDimSpatial+2 > &r_g_n_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op, const QsElementwiseOperation &qs_element_op, const RsElementwiseOperation &rs_element_op)
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:1003
static auto GetPaddedRGridDescriptor(Descriptor descriptor, index_t MRaw)
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:361
std::string GetTypeString() const override
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:1114
Definition device_grouped_conv_fwd_multiple_d_multiple_r.hpp:42
Definition matrix_padder.hpp:180