device_gemm_xdl_skip_b_lds.hpp Source File

device_gemm_xdl_skip_b_lds.hpp Source File

Composable Kernel: device_gemm_xdl_skip_b_lds.hpp Source File
device_gemm_xdl_skip_b_lds.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
10#include "ck/utility/env.hpp"
17
20
21namespace ck {
22namespace tensor_operation {
23namespace device {
24
25template <typename ADataType,
26 typename BDataType,
27 typename CDataType,
28 typename AccDataType,
29 typename ALayout,
30 typename BLayout,
31 typename CLayout,
32 typename AElementwiseOperation,
33 typename BElementwiseOperation,
34 typename CElementwiseOperation,
35 GemmSpecialization GemmSpec,
36 ck::index_t BlockSize,
37 ck::index_t MPerBlock,
38 ck::index_t NPerBlock,
39 ck::index_t K0PerBlock,
40 ck::index_t K1,
41 ck::index_t MPerXDL,
42 ck::index_t NPerXDL,
43 ck::index_t MXdlPerWave,
44 ck::index_t NXdlPerWave,
45 typename ABlockTransferThreadClusterLengths_K0_M_K1,
46 typename ABlockTransferThreadClusterArrangeOrder,
47 typename ABlockTransferSrcAccessOrder,
48 ck::index_t ABlockTransferSrcVectorDim,
49 ck::index_t ABlockTransferSrcScalarPerVector,
50 ck::index_t ABlockTransferDstScalarPerVector_K1,
51 bool ABlockLdsAddExtraM,
52 ck::index_t BBlockTransferSrcScalarPerVector,
53 ck::index_t BBlockBufferSize,
54 ck::index_t CThreadTransferSrcDstVectorDim,
55 ck::index_t CThreadTransferDstScalarPerVector>
56struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
57 BLayout,
58 CLayout,
59 ADataType,
60 BDataType,
61 CDataType,
62 AElementwiseOperation,
63 BElementwiseOperation,
64 CElementwiseOperation>
65{
67 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
68 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
69
70 static constexpr auto I0 = Number<0>{};
71 static constexpr auto I1 = Number<1>{};
72 static constexpr auto I2 = Number<2>{};
73
74 static constexpr auto K1Number = Number<K1>{};
75 static_assert(BBlockBufferSize >= 2);
76
78 {
// Builds the A-operand grid descriptor in (K0, M, K1) form from an M x K matrix with
// stride StrideA. K must be divisible by K1; this is only checked by assert, so it is
// not enforced here in NDEBUG builds.
79 assert(K % K1 == 0);
80
81 const index_t K0 = K / K1;
82
// Plain M x K descriptor. The layout branches are not shown in this listing; presumably
// they pick row- vs column-major strides from ALayout, mirroring the B descriptor
// below — TODO confirm.
83 const auto a_grid_desc_m_k = [&]() {
85 {
87 }
89 {
91 }
92 }();
93
94 if constexpr(GemmSpec == GemmSpecialization::MNPadding)
95 {
// Smallest PadM >= 0 such that M + PadM is a multiple of MPerBlock (0 when it already is).
96 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
97
// M is right-padded by PadM; the remaining transform arguments (presumably splitting K
// into (K0, K1)) are not shown in this listing.
99 a_grid_desc_m_k,
101 make_right_pad_transform(M, PadM)),
104 }
105 else
106 {
// Non-padded specializations: no right-pad transform is visible on this path.
108 a_grid_desc_m_k,
113 }
114 }
115
117 {
// Builds the B-operand grid descriptor in (K0, N, K1) form from a K x N matrix with
// stride StrideB. As for A, K % K1 == 0 is only checked by assert.
118 assert(K % K1 == 0);
119
120 const index_t K0 = K / K1;
121
// K x N descriptor. First branch: strides (StrideB, 1), so N is the contiguous
// dimension. Second branch: strides (1, StrideB), so K is contiguous. The branch
// conditions are not shown; presumably they test BLayout — TODO confirm.
122 const auto b_grid_desc_k_n = [&]() {
124 {
125 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
126 }
128 {
129 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
130 }
131 }();
132
133 if constexpr(GemmSpec == GemmSpecialization::MNPadding)
134 {
// Smallest PadN >= 0 such that N + PadN is a multiple of NPerBlock.
135 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
136
// N is right-padded by PadN; the other transform arguments are not shown in this listing.
138 b_grid_desc_k_n,
140 make_right_pad_transform(N, PadN)),
143 }
144 else
145 {
// Non-padded specializations: no right-pad transform is visible on this path.
147 b_grid_desc_k_n,
152 }
153 }
154
156 {
// Builds the C (output) M x N grid descriptor with stride StrideC. First branch:
// strides (StrideC, 1), N contiguous. Second branch: strides (1, StrideC), M contiguous.
// The branch conditions are not shown; presumably they test CLayout — TODO confirm.
157 const auto c_grid_desc_m_n = [&]() {
159 {
160 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
161 }
163 {
164 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
165 }
166 }();
167
168 if constexpr(GemmSpec == GemmSpecialization::MNPadding)
169 {
// Round M up to a multiple of MPerBlock and N up to a multiple of NPerBlock.
170 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
171 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
172
// Presumably right-pads M by PadM and N by PadN; the transform list is not shown in
// this listing — TODO confirm.
174 c_grid_desc_m_n,
178 }
179 else
180 {
181
// Non-padded specializations: the transform applied here is not shown in this listing.
183 c_grid_desc_m_n,
187 }
188 }
189
// Descriptor type deduced from the factory; the (1, 1, 1) arguments are placeholders
// used only inside decltype.
192 using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
193
194 // GridwiseGemm
// GridwiseGemmBase<NXdlPerWave_>: per this file's symbol index, an instantiation of
// GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 with InMemoryDataOperationEnum::Set.
// It is templated on NXdlPerWave_ so two variants can be formed:
//   GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>
//   GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>
// Only ADataType is passed for the A/B operand type (see the TODO below), so BDataType
// presumably has to match ADataType.
195 template <index_t NXdlPerWave_>
197 BlockSize,
198 ADataType, // TODO: distinguish A/B datatype
199 AccDataType,
200 CDataType,
205 AElementwiseOperation,
206 BElementwiseOperation,
207 CElementwiseOperation,
208 MPerBlock,
209 NPerBlock,
210 K0PerBlock,
211 MPerXDL,
212 NPerXDL,
213 K1,
214 MXdlPerWave,
215 NXdlPerWave_,
216 ABlockTransferThreadClusterLengths_K0_M_K1,
217 ABlockTransferThreadClusterArrangeOrder,
218 ABlockTransferSrcAccessOrder,
219 ABlockTransferSrcVectorDim,
220 ABlockTransferSrcScalarPerVector,
221 ABlockTransferDstScalarPerVector_K1,
222 false, // AThreadTransferSrcResetCoordinateAfterRun,
223 ABlockLdsAddExtraM,
224 BBlockTransferSrcScalarPerVector,
225 false, // BThreadTransferSrcResetCoordinateAfterRun,
226 BBlockBufferSize,
227 Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
228 CThreadTransferSrcDstVectorDim,
229 CThreadTransferDstScalarPerVector>;
232
233 // Argument
// Runtime argument bundle for one GEMM call. It holds:
//  - the raw A/B/C pointers;
//  - the grid descriptors a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_ and c_grid_desc_m_n_
//    (declared on lines not shown in this listing);
//  - the M01_/N01_ parameters forwarded to MakeDefaultBlock2CTileMap;
//  - the element-wise operators.
// Pointers are stored as given; nothing visible in this struct frees them.
234 struct Argument : public BaseArgument
235 {
236 Argument(const ADataType* p_a_grid,
237 const BDataType* p_b_grid,
238 CDataType* p_c_grid,
239 index_t M,
240 index_t N,
241 index_t K,
242 index_t StrideA,
243 index_t StrideB,
244 index_t StrideC,
245 index_t M01,
246 index_t N01,
247 AElementwiseOperation a_element_op,
248 BElementwiseOperation b_element_op,
249 CElementwiseOperation c_element_op)
250 : p_a_grid_{p_a_grid},
251 p_b_grid_{p_b_grid},
252 p_c_grid_{p_c_grid},
256 M01_{M01},
257 N01_{N01},
258 a_element_op_{a_element_op},
259 b_element_op_{b_element_op},
260 c_element_op_{c_element_op}
261 {
// Constructor body not shown in this listing; presumably it builds the grid descriptors
// from M/N/K and the strides via the Make*GridDescriptor* helpers — TODO confirm.
267 }
268
269 // private:
// The access specifier above is commented out: Invoker::RunImp reads these members directly.
270 const ADataType* p_a_grid_;
271 const BDataType* p_b_grid_;
272 CDataType* p_c_grid_;
278 AElementwiseOperation a_element_op_;
279 BElementwiseOperation b_element_op_;
280 CElementwiseOperation c_element_op_;
281 };
282
283 // Invoker
284 struct Invoker : public BaseInvoker
285 {
287
// Launches the skip-B-LDS GEMM kernel for one gridwise configuration and returns the time
// reported by launch_and_time_kernel. GridwiseGemm is presumably GridwiseGemm64 or
// GridwiseGemm32, selected by the Run wrapper generated from INVOKER_RUN_IMPL — TODO confirm.
288 template <typename GridwiseGemm>
289 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
290 {
// Optional diagnostics: when CK_LOGGING is enabled in the environment, print the lengths
// of the A (K0, M, K1), B (K0, N, K1) and C (M, N) descriptors.
291 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
292 {
293 std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
294 << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
295 << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
296
297 std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
298 << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
299 << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
300
301 std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
302 << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
303 }
304
// Reject configurations the gridwise GEMM cannot handle before building the tile map
// (the middle CheckValidity arguments are not shown in this listing).
305 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
308 arg.M01_,
309 arg.N01_))
310 {
311 throw std::runtime_error(
312 "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
313 }
// Map workgroups to C tiles and size the grid from the C descriptor.
314 auto block_2_ctile_map =
315 GridwiseGemm::MakeDefaultBlock2CTileMap(arg.c_grid_desc_m_n_, arg.M01_, arg.N01_);
316 const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
317
318 const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
319
// The trailing bool template argument of the kernel is chosen from this runtime check, so
// both the "main K0 loop" and "no main K0 loop" kernels are instantiated.
320 const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
321
322 float ave_time = 0;
323
324 if(has_main_k0_block_loop)
325 {
326 const auto kernel = kernel_gemm_xdlops_skip_b_lds_v1<
327 GridwiseGemm,
328 ADataType, // TODO: distinguish A/B datatype
329 CDataType,
333 AElementwiseOperation,
334 BElementwiseOperation,
335 CElementwiseOperation,
337 true>;
338
// Fifth argument is launch_and_time_kernel's lds_byte parameter (0 here); the rest are
// forwarded to the kernel (descriptor arguments not shown in this listing).
339 ave_time = launch_and_time_kernel(stream_config,
340 kernel,
341 dim3(grid_size),
342 dim3(BlockSize),
343 0,
344 arg.p_a_grid_,
345 arg.p_b_grid_,
346 arg.p_c_grid_,
350 arg.a_element_op_,
351 arg.b_element_op_,
352 arg.c_element_op_,
353 block_2_ctile_map);
354 }
355 else
356 {
357 const auto kernel = kernel_gemm_xdlops_skip_b_lds_v1<
358 GridwiseGemm,
359 ADataType, // TODO: distinguish A/B datatype
360 CDataType,
364 AElementwiseOperation,
365 BElementwiseOperation,
366 CElementwiseOperation,
368 false>;
369
370 ave_time = launch_and_time_kernel(stream_config,
371 kernel,
372 dim3(grid_size),
373 dim3(BlockSize),
374 0,
375 arg.p_a_grid_,
376 arg.p_b_grid_,
377 arg.p_c_grid_,
381 arg.a_element_op_,
382 arg.b_element_op_,
383 arg.c_element_op_,
384 block_2_ctile_map);
385 }
386
387 return ave_time;
388 }
389
391
392 // polymorphic
393 float Run(const BaseArgument* p_arg,
394 const StreamConfig& stream_config = StreamConfig{}) override
395 {
396 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
397 }
398 };
399
// Compile-time sanity check on the template parameters. No checks are implemented yet,
// so every instantiation is reported as valid.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool all_parameters_valid = true;
    return all_parameters_valid;
}
405
// Runtime support check for a concrete Argument. Returns false when the early guard
// fires, when the wave-size-specific variant has no valid NXdlPerWave, or (presumably)
// when that variant's CheckValidity rejects the descriptors.
406 static bool IsSupportedArgument(const Argument& arg)
407 {
// Early rejection; the condition is not shown in this listing (presumably the
// is_xdl_wmma_supported hardware check referenced by this file — TODO confirm).
409 {
410 return false;
411 }
// Dispatch on wave size. Each branch body exists only if the matching NXdlPerWave64 /
// NXdlPerWave32 is positive; otherwise control falls through to `return false` below.
412 if(get_warp_size() == 64)
413 {
414 if constexpr(NXdlPerWave64 > 0)
415 {
// Presumably returns GridwiseGemm64::CheckValidity(descriptors, M01, N01); the call
// head is not shown in this listing — TODO confirm.
419 arg.M01_,
420 arg.N01_);
421 }
422 }
423 else
424 {
425 if constexpr(NXdlPerWave32 > 0)
426 {
// Presumably returns GridwiseGemm32::CheckValidity(descriptors, M01, N01) — TODO confirm.
430 arg.M01_,
431 arg.N01_);
432 }
433 }
434 return false;
435 }
436
437 // polymorphic
438 bool IsSupportedArgument(const BaseArgument* p_arg) override
439 {
440 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
441 }
442
443 static auto MakeArgument(const ADataType* p_a,
444 const BDataType* p_b,
445 CDataType* p_c,
446 index_t M,
447 index_t N,
448 index_t K,
449 index_t StrideA,
450 index_t StrideB,
451 index_t StrideC,
452 AElementwiseOperation a_element_op,
453 BElementwiseOperation b_element_op,
454 CElementwiseOperation c_element_op)
455 {
456 return Argument{p_a,
457 p_b,
458 p_c,
459 M,
460 N,
461 K,
462 StrideA,
463 StrideB,
464 StrideC,
465 1,
466 1,
467 a_element_op,
468 b_element_op,
469 c_element_op};
470 }
471
472 static auto MakeInvoker() { return Invoker{}; }
473
474 // polymorphic
475 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
476 const void* p_b,
477 void* p_c,
478 index_t M,
479 index_t N,
480 index_t K,
481 index_t StrideA,
482 index_t StrideB,
483 index_t StrideC,
484 AElementwiseOperation a_element_op,
485 BElementwiseOperation b_element_op,
486 CElementwiseOperation c_element_op) override
487 {
488 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
489 static_cast<const BDataType*>(p_b),
490 static_cast<CDataType*>(p_c),
491 M,
492 N,
493 K,
494 StrideA,
495 StrideB,
496 StrideC,
497 1,
498 1,
499 a_element_op,
500 b_element_op,
501 c_element_op);
502 }
503
504 // polymorphic
505 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
506 {
507 return std::make_unique<Invoker>(Invoker{});
508 }
509
510 // polymorphic
511 std::string GetTypeString() const override
512 {
513 auto str = std::stringstream();
514
515 // clang-format off
516 str << "DeviceGemmXdlSkipBLds"
517 << "<"
518 << BlockSize << ", "
519 << MPerBlock << ", "
520 << NPerBlock << ", "
521 << K0PerBlock << ", "
522 << K1 << ", "
523 << MPerXDL << ", "
524 << NPerXDL << ", "
525 << MXdlPerWave << ", "
526 << NXdlPerWave
527 << ">";
528 // clang-format on
529
530 return str.str();
531 }
532};
533
534} // namespace device
535} // namespace tensor_operation
536} // namespace ck
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MNPadding
Definition gemm_specialization.hpp:17
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
__global__ void kernel_gemm_xdlops_skip_b_lds_v1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDesc_M_N c_grid_desc_m_n, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_xdlops_skip_b_lds_v1.hpp:34
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdlops_skip_b_lds_v1.hpp:117
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition device_base.hpp:197
Definition device_gemm.hpp:22
Definition device_gemm_xdl_skip_b_lds.hpp:235
CDataType * p_c_grid_
Definition device_gemm_xdl_skip_b_lds.hpp:272
index_t N01_
Definition device_gemm_xdl_skip_b_lds.hpp:277
const BDataType * p_b_grid_
Definition device_gemm_xdl_skip_b_lds.hpp:271
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_
Definition device_gemm_xdl_skip_b_lds.hpp:273
CElementwiseOperation c_element_op_
Definition device_gemm_xdl_skip_b_lds.hpp:280
CGridDesc_M_N c_grid_desc_m_n_
Definition device_gemm_xdl_skip_b_lds.hpp:275
BElementwiseOperation b_element_op_
Definition device_gemm_xdl_skip_b_lds.hpp:279
const ADataType * p_a_grid_
Definition device_gemm_xdl_skip_b_lds.hpp:270
AElementwiseOperation a_element_op_
Definition device_gemm_xdl_skip_b_lds.hpp:278
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_
Definition device_gemm_xdl_skip_b_lds.hpp:274
index_t M01_
Definition device_gemm_xdl_skip_b_lds.hpp:276
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t M01, index_t N01, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_xdl_skip_b_lds.hpp:236
Definition device_gemm_xdl_skip_b_lds.hpp:285
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_skip_b_lds.hpp:393
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_skip_b_lds.hpp:289
DeviceGemmXdlSkipBLds::Argument Argument
Definition device_gemm_xdl_skip_b_lds.hpp:286
Definition device_gemm_xdl_skip_b_lds.hpp:65
static constexpr auto I2
Definition device_gemm_xdl_skip_b_lds.hpp:72
decltype(MakeCGridDescriptor_M_N(1, 1, 1)) CGridDesc_M_N
Definition device_gemm_xdl_skip_b_lds.hpp:192
static auto MakeInvoker()
Definition device_gemm_xdl_skip_b_lds.hpp:472
static constexpr auto I1
Definition device_gemm_xdl_skip_b_lds.hpp:71
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_skip_b_lds.hpp:68
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferSrcScalarPerVector, false, BBlockBufferSize, Sequence< 0, 2, 4, 5, 6, 1, 3, 7 >, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector > GridwiseGemmBase
Definition device_gemm_xdl_skip_b_lds.hpp:196
decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)) BGridDesc_K0_N_K1
Definition device_gemm_xdl_skip_b_lds.hpp:191
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_skip_b_lds.hpp:438
std::string GetTypeString() const override
Definition device_gemm_xdl_skip_b_lds.hpp:511
static constexpr auto I0
Definition device_gemm_xdl_skip_b_lds.hpp:70
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
Definition device_gemm_xdl_skip_b_lds.hpp:116
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_xdl_skip_b_lds.hpp:406
decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)) AGridDesc_K0_M_K1
Definition device_gemm_xdl_skip_b_lds.hpp:190
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl_skip_b_lds.hpp:67
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_xdl_skip_b_lds.hpp:443
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
Definition device_gemm_xdl_skip_b_lds.hpp:155
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl_skip_b_lds.hpp:230
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
Definition device_gemm_xdl_skip_b_lds.hpp:77
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_gemm_xdl_skip_b_lds.hpp:475
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl_skip_b_lds.hpp:505
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl_skip_b_lds.hpp:231
static constexpr auto K1Number
Definition device_gemm_xdl_skip_b_lds.hpp:74
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl_skip_b_lds.hpp:400
#define CK_ENV(name)
Definition utility/env.hpp:129