device_gemm_multiple_d_dl.hpp Source File

device_gemm_multiple_d_dl.hpp Source File#

Composable Kernel: device_gemm_multiple_d_dl.hpp Source File
device_gemm_multiple_d_dl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
18
19namespace ck {
20
// Device kernel (`__global__`) for the DL multiple-D GEMM; Invoker::Run instantiates it
// under the name kernel_gemm_dl_multiple_d.
//
// On the targets selected by the preprocessor test below (gfx906 / gfx9 / gfx103 / gfx11 /
// gfx12) the kernel declares a `__shared__` scratch array and hands every pointer,
// element-wise operator, grid descriptor and the block-to-tile map to GridwiseGemm::Run.
// On any other target each parameter is only assigned to `ignore`, so the kernel does no work.
//
// NOTE(review): the launch-bounds attribute, the kernel-name line and the last arguments of
// the GridwiseGemm::Run call are not visible in this listing. HasMainKBlockLoop and
// HasDoubleTailKBlockLoop are presumably forwarded to GridwiseGemm::Run on those lines - confirm.
21template <typename GridwiseGemm,
22 typename ABDataType,
23 typename DsPointer,
24 typename EDataType,
25 typename AElementwiseOperation,
26 typename BElementwiseOperation,
27 typename CDEElementwiseOperation,
28 typename AGridDesc_K0_M0_M1_K1,
29 typename BGridDesc_K0_N0_N1_K1,
30 typename DsGridDesc_M0_M10_M11_N0_N10_N11,
31 typename CGridDesc_M0_M10_M11_N0_N10_N11,
32 typename Block2CTileMap,
33 bool HasMainKBlockLoop,
34 bool HasDoubleTailKBlockLoop>
35__global__ void
36#if CK_USE_LAUNCH_BOUNDS
38#endif
40 const ABDataType* __restrict__ p_a_grid,
41 const ABDataType* __restrict__ p_b_grid,
42 DsPointer p_ds_grid,
43 EDataType* __restrict__ p_e_grid,
44 const AElementwiseOperation a_element_op,
45 const BElementwiseOperation b_element_op,
46 const CDEElementwiseOperation cde_element_op,
47 const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
48 const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
49 const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
50 const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
51 const Block2CTileMap block_2_ctile_map)
52{
53#if(defined(__gfx906__) || defined(__gfx9__) || defined(__gfx103__) || defined(__gfx11__) || \
54 defined(__gfx12__))
55
// The scratch array is sized in ABDataType elements: the byte count reported by
// GridwiseGemm divided by sizeof(ABDataType).
56 constexpr index_t shared_block_size =
57 GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
58
59 __shared__ ABDataType p_shared[shared_block_size];
60
61 GridwiseGemm::Run(p_a_grid,
62 p_b_grid,
63 p_ds_grid,
64 p_e_grid,
65 p_shared,
66 a_element_op,
67 b_element_op,
68 cde_element_op,
69 a_grid_desc_k0_m0_m1_k1,
70 b_grid_desc_k0_n0_n1_k1,
71 ds_grid_desc_m0_m10_m11_n0_n10_n11,
72 e_grid_desc_m0_m10_m11_n0_n10_n11,
73 block_2_ctile_map,
76#else
// Unsupported target: touch every parameter through `ignore` and do nothing else
// (presumably to avoid unused-parameter diagnostics).
77 ignore = p_a_grid;
78 ignore = p_b_grid;
79 ignore = p_ds_grid;
80 ignore = p_e_grid;
81 ignore = a_element_op;
82 ignore = b_element_op;
83 ignore = cde_element_op;
84 ignore = a_grid_desc_k0_m0_m1_k1;
85 ignore = b_grid_desc_k0_n0_n1_k1;
86 ignore = ds_grid_desc_m0_m10_m11_n0_n10_n11;
87 ignore = e_grid_desc_m0_m10_m11_n0_n10_n11;
88 ignore = block_2_ctile_map;
89#endif
90}
91} // namespace ck
92
93namespace ck {
94namespace tensor_operation {
95namespace device {
96
97template <typename ALayout,
98 typename BLayout,
99 typename DsLayout,
100 typename ELayout,
101 typename ADataType,
102 typename BDataType,
103 typename AccDataType,
104 typename DsDataType,
105 typename EDataType,
106 typename AElementwiseOperation,
107 typename BElementwiseOperation,
108 typename CDEElementwiseOperation,
109 GemmSpecialization GemmSpec,
110 index_t BlockSize,
111 index_t MPerBlock,
112 index_t NPerBlock,
113 index_t K0PerBlock,
114 index_t K1,
115 index_t M1PerThread,
116 index_t N1PerThread,
117 index_t KPerThread,
118 typename M1N1ThreadClusterM1Xs,
119 typename M1N1ThreadClusterN1Xs,
120 typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
121 typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
122 typename ABlockTransferThreadClusterArrangeOrder,
123 typename ABlockTransferSrcAccessOrder,
124 typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
125 typename ABlockTransferSrcVectorTensorContiguousDimOrder,
126 typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
127 typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
128 typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
129 typename BBlockTransferThreadClusterArrangeOrder,
130 typename BBlockTransferSrcAccessOrder,
131 typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
132 typename BBlockTransferSrcVectorTensorContiguousDimOrder,
133 typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
134 typename CThreadTransferSrcDstAccessOrder,
135 index_t CThreadTransferSrcDstVectorDim,
136 index_t CThreadTransferDstScalarPerVector,
140 bool> = false>
142 BLayout,
143 DsLayout,
144 ELayout,
145 ADataType,
146 BDataType,
147 DsDataType,
148 EDataType,
149 AElementwiseOperation,
150 BElementwiseOperation,
151 CDEElementwiseOperation>
152
153{
155 static constexpr index_t NumDTensor = DsDataType::Size();
156
157 static constexpr auto I0 = Number<0>{};
158 static constexpr auto I1 = Number<1>{};
159 static constexpr auto I2 = Number<2>{};
160 static constexpr auto I3 = Number<3>{};
161 static constexpr auto I4 = Number<4>{};
162 static constexpr auto I5 = Number<5>{};
163
164 static constexpr auto K1Number = Number<K1>{};
165
167 {
168 assert(K % K1 == 0);
169
170 const index_t K0 = K / K1;
171
172 const auto a_grid_desc_m_k = [&]() {
174 {
175 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
176 }
178 {
179 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
180 }
181 }();
182
183 if constexpr(GemmSpec == GemmSpecialization::MNPadding)
184 {
185 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
186
188 a_grid_desc_m_k,
190 make_right_pad_transform(M, PadM)),
193 }
194 else
195 {
197 a_grid_desc_m_k,
202 }
203 }
204
206 {
207 assert(K % K1 == 0);
208
209 const index_t K0 = K / K1;
210
211 const auto b_grid_desc_k_n = [&]() {
213 {
214 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
215 }
217 {
218 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
219 }
220 }();
221
222 if constexpr(GemmSpec == GemmSpecialization::MNPadding)
223 {
224 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
225
227 b_grid_desc_k_n,
229 make_right_pad_transform(N, PadN)),
232 }
233 else
234 {
236 b_grid_desc_k_n,
241 }
242 }
243
244 template <typename ELay>
246 {
247 const auto c_grid_desc_m_n = [&]() {
249 {
250 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1));
251 }
253 {
254 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE));
255 }
256 }();
257
258 if constexpr(GemmSpec == GemmSpecialization::MNPadding)
259 {
260 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
261 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
262
264 c_grid_desc_m_n,
268 }
269 else
270 {
271
273 c_grid_desc_m_n,
277 }
278 }
279
// Builds the descriptors for all D tensors as a tuple.
// For each index i the layout type is taken from DsLayout at position i, and the descriptor
// is produced by MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]), i.e. the
// D tensors reuse the E-descriptor builder with their own layout, sizes and stride.
// NOTE(review): the tuple-length argument of generate_tuple is not visible in this listing;
// presumably Number<NumDTensor>{} - confirm.
280 static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
281 const std::array<index_t, NumDTensor>& NRaws,
282 const std::array<index_t, NumDTensor>& DsStride)
283 {
284 return generate_tuple(
285 [&](auto i) {
286 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
287
288 return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
289 },
291 }
292
295 using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {}));
297
298 // GridwiseGemm
301 ADataType,
302 AccDataType,
303 DsDataType,
304 EDataType,
305 AElementwiseOperation,
306 BElementwiseOperation,
307 CDEElementwiseOperation,
312 MPerBlock,
313 NPerBlock,
314 K0PerBlock,
315 K1,
316 M1PerThread,
317 N1PerThread,
318 KPerThread,
319 M1N1ThreadClusterM1Xs,
320 M1N1ThreadClusterN1Xs,
321 ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
322 ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
323 ABlockTransferThreadClusterArrangeOrder,
324 ABlockTransferSrcAccessOrder,
325 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
326 ABlockTransferSrcVectorTensorContiguousDimOrder,
327 ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
328 BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
329 BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
330 BBlockTransferThreadClusterArrangeOrder,
331 BBlockTransferSrcAccessOrder,
332 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
333 BBlockTransferSrcVectorTensorContiguousDimOrder,
334 BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
335 CThreadTransferSrcDstAccessOrder,
336 CThreadTransferSrcDstVectorDim,
337 CThreadTransferDstScalarPerVector>;
338
349
350 // Argument
// Launch argument bundle for this device operation.
// The constructor receives type-erased tensor pointers plus sizes, strides and the three
// element-wise operators, and stores the pointers with their concrete element types.
// NOTE(review): the initialisers and member declarations for the grid descriptors and the
// block-to-tile map are not visible in this listing, so they are not described here.
351 struct Argument : public BaseArgument
352 {
353 Argument(const void* p_a_grid,
354 const void* p_b_grid,
355 std::array<const void*, NumDTensor> p_ds_grid,
356 void* p_e_grid,
357 index_t M,
358 index_t N,
359 index_t K,
360 index_t StrideA,
361 index_t StrideB,
362 std::array<index_t, NumDTensor> StrideDs,
363 index_t StrideE,
364 AElementwiseOperation a_element_op,
365 BElementwiseOperation b_element_op,
366 CDEElementwiseOperation cde_element_op)
// A, B and E pointers are cast from void* to their element types here; the D pointer
// tuple starts value-initialised and is filled per tensor in the constructor body.
367 : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
368 p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
369 p_ds_grid_{},
370 p_e_grid_{static_cast<EDataType*>(p_e_grid)},
375 a_element_op_{a_element_op},
376 b_element_op_{b_element_op},
377 cde_element_op_{cde_element_op}
378 {
// Compile-time loop over the D tensors: each one has its own data type (and layout),
// so the cast target is looked up per index.
383 static_for<0, NumDTensor, 1>{}([&](auto i) {
384 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
385 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
386
387 // D pointer
388 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
389
390 // D desc
393 });
396
399 {
404
407
410
412 }
413 }
414
415 // private:
416 const ADataType* p_a_grid_;
417 const BDataType* p_b_grid_;
419 EDataType* p_e_grid_;
420
425
430
432
433 // TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
434 AElementwiseOperation a_element_op_;
435 BElementwiseOperation b_element_op_;
436 CDEElementwiseOperation cde_element_op_;
437 };
438
439 // Invoker
440 struct Invoker : public BaseInvoker
441 {
443
// Typed launch path: logs descriptor lengths, rejects invalid settings with an exception,
// then launches the kernel instantiation that matches the two K-loop flags and returns the
// float produced by launch_and_time_kernel.
// Throws std::runtime_error ("wrong! GridwiseGemmDlMultipleD_km_kn_mn has invalid setting").
// NOTE(review): the condition guarding the throw, the definition of grid_size and several
// template/call arguments are not visible in this listing.
444 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
445 {
446 {
// Unconditional debug output to std::cout.
// NOTE(review): the first two labels name the k0_m0_m1_k1 / k0_n0_n1_k1 descriptors, but
// the lengths printed are read from a_grid_desc_k0_m_k1_ / b_grid_desc_k0_n_k1_.
447 std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
448 << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
449 << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
450 << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
451
452 std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{"
453 << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
454 << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
455 << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
456
457 std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
458 << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
459 }
460
// Validity guard (condition line not visible here; presumably a GridwiseGemm check - confirm).
463 {
464 throw std::runtime_error(
465 "wrong! GridwiseGemmDlMultipleD_km_kn_mn has invalid setting");
466 }
467
// Tail of the grid_size computation: it takes the two lengths of the E descriptor.
469 arg.e_grid_desc_m_n_.GetLength(I0), arg.e_grid_desc_m_n_.GetLength(I1));
470
// Generic lambda taking the two loop flags as integral_constant-style tags so that their
// `.value` can be used as compile-time template arguments of the kernel.
471 auto launch_kernel = [&](auto has_main_k_block_loop,
472 auto has_double_tail_k_block_loop) {
473 constexpr bool has_main_loop = has_main_k_block_loop.value;
474 constexpr bool has_double_loop = has_double_tail_k_block_loop.value;
475
476 const auto kernel =
477 kernel_gemm_dl_multiple_d<GridwiseGemm,
478 ADataType,
480 EDataType,
481 AElementwiseOperation,
482 BElementwiseOperation,
483 CDEElementwiseOperation,
489 has_main_loop,
490 has_double_loop>;
491
// One-dimensional grid and block; the literal 0 is the lds_byte argument of
// launch_and_time_kernel.
492 return launch_and_time_kernel(stream_config,
493 kernel,
494 dim3(grid_size),
495 dim3(BlockSize),
496 0,
497 arg.p_a_grid_,
498 arg.p_b_grid_,
499 arg.p_ds_grid_,
500 arg.p_e_grid_,
501 arg.a_element_op_,
502 arg.b_element_op_,
503 arg.cde_element_op_,
509 };
510
// The loop flags are runtime values derived from K0 (length 0 of the A descriptor); the
// if/else ladder below maps each of the four combinations to its compile-time tags.
511 const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
512 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
513 const bool has_double_tail_k_block_loop =
515
516 if(has_main_k_block_loop && has_double_tail_k_block_loop)
517 {
518 return launch_kernel(integral_constant<bool, true>{},
519 integral_constant<bool, true>{});
520 }
521 else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
522 {
523 return launch_kernel(integral_constant<bool, true>{},
524 integral_constant<bool, false>{});
525 }
526 else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
527 {
528 return launch_kernel(integral_constant<bool, false>{},
529 integral_constant<bool, true>{});
530 }
531 else
532 {
533 return launch_kernel(integral_constant<bool, false>{},
534 integral_constant<bool, false>{});
535 }
536 }
537
538 // polymorphic
539 float Run(const BaseArgument* p_arg,
540 const StreamConfig& stream_config = StreamConfig{}) override
541 {
542 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
543 }
544 };
545
/// Compile-time query on the template configuration of this device operation.
/// No validation is performed yet: every configuration is reported as valid.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool accepts_every_configuration = true;
    return accepts_every_configuration;
}
551
// Typed support check for an Argument.
// Returns false when the device test in the `if` fails. The visible part of that test
// accepts a device named "gfx906" or any device for which ck::is_xdl_supported() is true.
// NOTE(review): the rest of the condition and the body of the accepted branch are not
// visible in this listing, so the result for supported devices is not described here.
552 static bool IsSupportedArgument(const Argument& arg)
553 {
554 if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
556 {
559 }
560 else
561 {
562 return false;
563 }
564 }
565
566 // polymorphic
567 bool IsSupportedArgument(const BaseArgument* p_arg) override
568 {
569 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
570 }
571
572 static auto MakeArgument(const void* p_a,
573 const void* p_b,
574 std::array<const void*, NumDTensor> p_ds,
575 void* p_e,
576 index_t M,
577 index_t N,
578 index_t K,
579 index_t StrideA,
580 index_t StrideB,
581 std::array<ck::index_t, NumDTensor> StrideDs,
582 index_t StrideE,
583 AElementwiseOperation a_element_op,
584 BElementwiseOperation b_element_op,
585 CDEElementwiseOperation cde_element_op)
586 {
587 return Argument{p_a,
588 p_b,
589 p_ds,
590 p_e,
591 M,
592 N,
593 K,
594 StrideA,
595 StrideB,
596 StrideDs,
597 StrideE,
598 a_element_op,
599 b_element_op,
600 cde_element_op};
601 }
602
603 static auto MakeInvoker() { return Invoker{}; }
604
605 // polymorphic
606 std::unique_ptr<BaseArgument>
607 MakeArgumentPointer(const void* p_a,
608 const void* p_b,
609 std::array<const void*, NumDTensor> p_ds,
610 void* p_e,
611 index_t M,
612 index_t N,
613 index_t K,
614 index_t StrideA,
615 index_t StrideB,
616 std::array<ck::index_t, NumDTensor> StrideDs,
617 index_t StrideE,
618 AElementwiseOperation a_element_op,
619 BElementwiseOperation b_element_op,
620 CDEElementwiseOperation cde_element_op) override
621 {
622 return std::make_unique<Argument>(p_a,
623 p_b,
624 p_ds,
625 p_e,
626 M,
627 N,
628 K,
629 StrideA,
630 StrideB,
631 StrideDs,
632 StrideE,
633 a_element_op,
634 b_element_op,
635 cde_element_op);
636 }
637
638 // polymorphic
639 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
640 {
641 return std::make_unique<Invoker>(Invoker{});
642 }
643
644 // polymorphic
645 std::string GetTypeString() const override
646 {
647 auto str = std::stringstream();
648
649 // clang-format off
650 str << "DeviceGemmMultipleD_Dl"
651 << "<"
652 << BlockSize << ", "
653 << MPerBlock << ", "
654 << NPerBlock << ", "
655 << K0PerBlock << ", "
656 << K1 << ", "
657 << M1PerThread << ", "
658 << N1PerThread << ", "
659 << KPerThread
660 << ">";
661 // clang-format on
662
663 return str.str();
664 }
665};
666
667} // namespace device
668} // namespace tensor_operation
669} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MNPadding
Definition gemm_specialization.hpp:17
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
bool is_xdl_supported()
Definition host_utility/device_prop.hpp:68
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
__global__ void kernel_gemm_dl_multiple_d(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map)
Definition device_gemm_multiple_d_dl.hpp:39
bool is_gfx103_supported()
Definition host_utility/device_prop.hpp:120
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_dl_multiple_d.hpp:60
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition utility/integral_constant.hpp:20
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_gemm_multiple_d_dl.hpp:352
GridwiseGemm::DsGridPointer p_ds_grid_
Definition device_gemm_multiple_d_dl.hpp:418
const BDataType * p_b_grid_
Definition device_gemm_multiple_d_dl.hpp:417
BElementwiseOperation b_element_op_
Definition device_gemm_multiple_d_dl.hpp:435
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_
Definition device_gemm_multiple_d_dl.hpp:427
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_
Definition device_gemm_multiple_d_dl.hpp:426
CDEElementwiseOperation cde_element_op_
Definition device_gemm_multiple_d_dl.hpp:436
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_
Definition device_gemm_multiple_d_dl.hpp:422
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_
Definition device_gemm_multiple_d_dl.hpp:421
EDataType * p_e_grid_
Definition device_gemm_multiple_d_dl.hpp:419
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, void *p_e_grid, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_gemm_multiple_d_dl.hpp:353
DefaultBlock2CTileMap block_2_ctile_map_
Definition device_gemm_multiple_d_dl.hpp:431
EGridDesc_M_N e_grid_desc_m_n_
Definition device_gemm_multiple_d_dl.hpp:424
const ADataType * p_a_grid_
Definition device_gemm_multiple_d_dl.hpp:416
DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_
Definition device_gemm_multiple_d_dl.hpp:428
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_gemm_multiple_d_dl.hpp:423
AElementwiseOperation a_element_op_
Definition device_gemm_multiple_d_dl.hpp:434
EGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_
Definition device_gemm_multiple_d_dl.hpp:429
Definition device_gemm_multiple_d_dl.hpp:441
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_multiple_d_dl.hpp:444
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_multiple_d_dl.hpp:539
DeviceGemmMultipleD_Dl::Argument Argument
Definition device_gemm_multiple_d_dl.hpp:442
Definition device_gemm_multiple_d_dl.hpp:153
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{})) DefaultBlock2CTileMap
Definition device_gemm_multiple_d_dl.hpp:347
std::string GetTypeString() const override
Definition device_gemm_multiple_d_dl.hpp:645
decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)) AGridDesc_K0_M_K1
Definition device_gemm_multiple_d_dl.hpp:293
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_multiple_d_dl.hpp:546
static auto MakeDsGridDescriptor_M_N(const std::array< index_t, NumDTensor > &MRaws, const std::array< index_t, NumDTensor > &NRaws, const std::array< index_t, NumDTensor > &DsStride)
Definition device_gemm_multiple_d_dl.hpp:280
static constexpr auto I4
Definition device_gemm_multiple_d_dl.hpp:161
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d_dl.hpp:155
static constexpr auto K1Number
Definition device_gemm_multiple_d_dl.hpp:164
decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)) BGridDesc_K0_N_K1
Definition device_gemm_multiple_d_dl.hpp:294
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_multiple_d_dl.hpp:552
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_gemm_multiple_d_dl.hpp:572
GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector > GridwiseGemm
Definition device_gemm_multiple_d_dl.hpp:299
static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
Definition device_gemm_multiple_d_dl.hpp:245
static constexpr auto I3
Definition device_gemm_multiple_d_dl.hpp:160
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})) AGridDesc_K0_M0_M1_K1
Definition device_gemm_multiple_d_dl.hpp:339
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_d_dl.hpp:639
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
Definition device_gemm_multiple_d_dl.hpp:205
DeviceGemmMultipleD_Dl DeviceOp
Definition device_gemm_multiple_d_dl.hpp:154
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})) BGridDesc_K0_N0_N1_K1
Definition device_gemm_multiple_d_dl.hpp:341
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_d_dl.hpp:567
decltype(MakeDsGridDescriptor_M_N({}, {}, {})) DsGridDesc_M_N
Definition device_gemm_multiple_d_dl.hpp:295
static constexpr auto I0
Definition device_gemm_multiple_d_dl.hpp:157
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
Definition device_gemm_multiple_d_dl.hpp:166
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_gemm_multiple_d_dl.hpp:607
static constexpr auto I2
Definition device_gemm_multiple_d_dl.hpp:159
decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{})) DsGridDesc_M0_M10_M11_N0_N10_N11
Definition device_gemm_multiple_d_dl.hpp:343
decltype(MakeEGridDescriptor_M_N< ELayout >(1, 1, 1)) EGridDesc_M_N
Definition device_gemm_multiple_d_dl.hpp:296
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{})) EGridDesc_M0_M10_M11_N0_N10_N11
Definition device_gemm_multiple_d_dl.hpp:345
static constexpr auto I1
Definition device_gemm_multiple_d_dl.hpp:158
static auto MakeInvoker()
Definition device_gemm_multiple_d_dl.hpp:603
static constexpr auto I5
Definition device_gemm_multiple_d_dl.hpp:162
Definition device_gemm_multiple_d.hpp:36