gemm_pipeline_ag_bg_cr_mem.hpp Source File

gemm_pipeline_ag_bg_cr_mem.hpp Source File

Composable Kernel: gemm_pipeline_ag_bg_cr_mem.hpp Source File
gemm_pipeline_ag_bg_cr_mem.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
11
12namespace ck_tile {
13
14// A Tile Window: global memory
15// B Tile Window: global memory
16// C Distributed tensor: register
17template <typename Problem>
19{
23
// Divisors applied to sizeof(ADataType) / sizeof(BDataType) when the per-stage
// A+B footprint is computed below. Their initializers (lines 25, 27) are elided
// in this listing; presumably >1 only for packed sub-byte types — TODO confirm.
24 static constexpr index_t APackedSize =
26 static constexpr index_t BPackedSize =
28
29 CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
30
// Block-level configuration copied from the Problem / block GEMM shape.
31 static constexpr index_t BlockSize = Problem::kBlockSize;
32 static constexpr index_t MPerBlock = BlockGemmShape::kM;
33 static constexpr index_t NPerBlock = BlockGemmShape::kN;
34 static constexpr index_t KPerBlock = BlockGemmShape::kK;
35
36 // TODO: Is this 32K value gfx9 arch specific?
37 static constexpr index_t MinMemInFlyBytes = 32768;
38
// WgpPerCU: (4 * warp size) / BlockSize, clamped to at least 1 — an estimate of
// how many workgroups share a CU (the factor 4 presumably reflects SIMDs per CU;
// TODO confirm). The following declaration (its first lines, 41-42, are elided)
// uses the per-stage A+B tile footprint in bytes,
// (MPerBlock*sizeof(A)/APackedSize + NPerBlock*sizeof(B)/BPackedSize) * KPerBlock,
// presumably to derive how many stages keep MinMemInFlyBytes in flight.
39 static constexpr index_t WgpPerCU =
40 (4 * get_warp_size() / BlockSize) >= 1 ? 4 * get_warp_size() / BlockSize : 1;
43 (MPerBlock * sizeof(ADataType) / APackedSize +
44 NPerBlock * sizeof(BDataType) / BPackedSize) *
45 KPerBlock);
// Number of global-memory prefetch stages. The selection expression is partly
// elided (lines 47-48) and falls back to 2. Tail handling below names tails only
// up to Seven, so the elided lines presumably clamp this to <= 8 — TODO confirm.
46 static constexpr index_t PrefetchStages =
49 : 2;
50
// A single LDS prefill stage (matches "LocalPreFillStages: 1" in the pipeline
// summary comment further down).
51 static constexpr index_t LocalPrefillStages = 1;
53 static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
54
55 CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
56 {
57 return num_loop > PrefetchStages;
58 }
59
61 {
62 if(num_loop % PrefetchStages == 1)
63 {
64 return TailNumber::One;
65 }
66 else if(num_loop % PrefetchStages == 2)
67 {
68 return TailNumber::Two;
69 }
70 else if(num_loop % PrefetchStages == 3)
71 {
72 return TailNumber::Three;
73 }
74 else if(num_loop % PrefetchStages == 4)
75 {
76 return TailNumber::Four;
77 }
78 else if(num_loop % PrefetchStages == 5)
79 {
80 return TailNumber::Five;
81 }
82 else if(num_loop % PrefetchStages == 6)
83 {
84 return TailNumber::Six;
85 }
86 else if(num_loop % PrefetchStages == 7)
87 {
88 return TailNumber::Seven;
89 }
90 else
91 {
92 return TailNumber::Full;
93 }
94 }
95
// Runtime-to-compile-time dispatcher for the pipeline's hot-loop/tail shape.
// Calls run_func(bool_constant<has_hot_loop>{}, integral_constant<TailNumber, T>{})
// for the matching tail T and returns its result. A tail other than One/Full is
// dispatched only when PrefetchStages exceeds the PREFETCH_VALUE given to
// CHECK_TAIL_NUMBER for it; anything not dispatched reaches the error path at the
// end (__builtin_unreachable() in device code, std::logic_error on the host).
// NOTE(review): the One/Full branch bodies (lines 123, 127) and the
// CHECK_TAIL_NUMBER expansions (lines 129-134) are elided in this listing.
96 template <typename RunFunction>
97 CK_TILE_HOST_DEVICE static auto
98 TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
99 {
100 // Wrap the hot_loop dispatch first.
101 auto tail_dispatch = [&](auto tail_num_constant) {
102 if(has_hot_loop)
103 {
104 return run_func(bool_constant<true>{}, tail_num_constant);
105 }
106 else
107 {
108 return run_func(bool_constant<false>{}, tail_num_constant);
109 }
110 };
111
// Each expansion appends one `else if` arm to the chain below. If the arm's tail
// matches but PrefetchStages <= PREFETCH_VALUE, the arm body is discarded and
// control falls through to the error path after the chain.
112#define CHECK_TAIL_NUMBER(TAIL_NUMBER, PREFETCH_VALUE) \
113 else if(tail_number == TailNumber::TAIL_NUMBER) \
114 { \
115 if constexpr(PrefetchStages > PREFETCH_VALUE) \
116 { \
117 return tail_dispatch(integral_constant<TailNumber, TailNumber::TAIL_NUMBER>{}); \
118 } \
119 }
120 // Handle all the valid cases.
121 if(tail_number == TailNumber::One)
122 {
124 }
125 else if(tail_number == TailNumber::Full)
126 {
128 }
135#undef CHECK_TAIL_NUMBER
136
137 // We shouldn't get here unless we have a tail number larger than the prefetch stages.
138#if defined(__HIP_DEVICE_COMPILE__)
139 __builtin_unreachable();
140#else
141 throw std::logic_error("Invalid TailNumber: Only TailNumber::Full and smaller than "
142 "PrefetchStages are supported.");
143#endif
144 }
145};
146
147// Maximum global-memory throughput pipeline with >=32KB of data in flight
148// GlobalPrefetchStages: >=2
149// LocalPreFillStages: 1
150// LocalPreFetchStages: 0
151// LocalSharedMemoryBuffer: 1
152template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
154{
157
161
165
169
172
// Compile-time rejection: this pipeline does not handle BDataType == pk_int4_t.
175 static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
177
// Compile-time indices used to select elements of the window/tile tuples below.
178 using I0 = number<0>;
179 using I1 = number<1>;
180 using I2 = number<2>;
181
// Block tile extents taken from the block GEMM shape.
182 static constexpr index_t MPerBlock = BlockGemmShape::kM;
183 static constexpr index_t NPerBlock = BlockGemmShape::kN;
184 static constexpr index_t KPerBlock = BlockGemmShape::kK;
185
186 template <bool IsWave32Host = false>
187 static constexpr index_t GetVectorSizeA()
188 {
189 return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
190 }
191 template <bool IsWave32Host = false>
192 static constexpr index_t GetVectorSizeB()
193 {
194 return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
195 }
196 static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
197
198 static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
199 static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
200
// Padding flags copied from the Problem (also encoded into GetName()).
201 static constexpr bool kPadM = Problem::kPadM;
202 static constexpr bool kPadN = Problem::kPadN;
203 static constexpr bool kPadK = Problem::kPadK;
204
205 static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
206 static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
// NOTE(review): Preshuffle is declared index_t here; if Problem::Preshuffle is a
// boolean flag, `bool` would state the intent better — confirm before changing.
207 static constexpr index_t Preshuffle = Problem::Preshuffle;
208
// Compile-time hot-loop/tail shape used by the operator() overloads that do not
// take has_hot_loop/tail_number at run time; Scheduler selects PipelineImpl.
209 // Where is the right place for HasHotLoop and TailNum ???
210 static constexpr bool HasHotLoop = Problem::HasHotLoop;
211 static constexpr auto TailNum = Problem::TailNum;
212 static constexpr auto Scheduler = Problem::Scheduler;
213
216
// Host-side identifier for this pipeline configuration: '_'-joined parts that
// start with "pipeline_AgBgCrMe" and end with the 'x'-joined padding flags
// (kPadM x kPadN x kPadK); the middle parts (lines 221-222) are elided here.
// NOTE(review): returning `const std::string` by value prevents callers from
// moving the result; a plain `std::string` return type would be preferable.
217 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
218 {
219 // clang-format off
220 return concat('_', "pipeline_AgBgCrMe",
223 concat('x', kPadM, kPadN, kPadK));
224 // clang-format on
225 }
226
228
230 {
// Shared-memory size requirement is delegated entirely to the Policy.
// (The function signature, lines 227-229, is elided in this listing.)
231 return Policy::template GetSmemSize<Problem>();
232 }
233
// Scheduler-dispatched implementation. The primary template has an empty body;
// the Intrawave and Interwave specializations below provide operator().
// (Line 235 is elided in this listing.)
234 template <GemmPipelineScheduler Scheduler>
236 {
237 };
239 template <>
241 {
243
// Intrawave schedule of the memory-throughput pipeline.
// HasHotLoop selects whether the steady-state loop is emitted; the second template
// parameter (line 245, elided) is presumably the TailNumber `TailNum` used by the
// tail dispatch at the end — TODO confirm.
// Flow, as visible below:
//   1. Load A/B tile 0 from DRAM (applying a/b_element_func), advance each window
//      by KPerBlock along K, zero C, and write tile 0 to LDS. When A is
//      column-major / B is row-major and the LDS read is not a transposing load,
//      the tile goes through a shuffled register tile first.
//   2. Prefetch tiles 1..PrefetchStages-1 from DRAM into registers.
//   3. Hot loop (HasHotLoop): per stage, block_gemm.LocalPrefetch from the LDS GEMM
//      windows + block GEMM, write the next register tile to LDS, then refill the
//      freed register slot from DRAM; repeats while i < num_loop - PrefetchStages.
//   4. Tail: TailNum selects how many remaining stages are drained.
// Returns the accumulated C block tile (register-resident distributed tensor).
// NOTE(review): several lines are elided in this listing (e.g. 338-339, 376, 387,
// 417, 422, 471, 477, 505, 513). They presumably declare a/b_block_tiles and the
// shuffle temporaries and hold the block_sync_lds() barriers — confirm in source.
244 template <bool HasHotLoop,
246 typename AsDramBlockWindowTmp,
247 typename BsDramBlockWindowTmp,
248 typename AElementFunction,
249 typename BElementFunction,
250 typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
252 bool>* = nullptr>
253 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
254 const AElementFunction& a_element_func,
255 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
256 const BElementFunction& b_element_func,
257 index_t num_loop,
258 void* p_smem) const
259 {
260 using ADramBlockWindowTmp =
261 remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
262 using BDramBlockWindowTmp =
263 remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
264
265 static_assert(
266 std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
267 std::is_same_v<BDataType,
269 "A/B Dram block window should have the same data type as appropriate "
270 "([A|B]DataType) defined in Problem definition!");
271
272 constexpr bool is_a_col_major =
273 std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
274 constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
275
276 static_assert(is_a_col_major
277 ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
278 MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
279 : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
280 KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
281 "A block window has incorrect lengths for defined ALayout!");
282 static_assert(is_b_row_major
283 ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
284 NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
285 : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
286 KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
287 "B block window has incorrect lengths for defined BLayout!");
288
289 // ------------------------------------------------------------------------------------
290 // Definitions of all needed tiles
291
292 // A/B tiles in LDS
293 // With c++20 could simplify to below line.
294 // Currently get error: captured structured bindings are a C++20 extension
295 // auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
296 auto ab_lds_blocks = Base::GetABLdsTensorViews(p_smem);
297 auto& a_lds_block = ab_lds_blocks.at(I0{});
298 auto& b_lds_block = ab_lds_blocks.at(I1{});
299
300 // Tile distribution for load from lds
301 constexpr auto a_lds_load_tile_distr =
302 make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
303 constexpr auto b_lds_load_tile_distr =
304 make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
305
306 // A DRAM tile window for load
307 // A LDS tile window for store
308 // A LDS tile for block GEMM
309 auto a_windows =
310 Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
311 auto& a_copy_dram_window = a_windows.at(I0{});
312 auto& a_copy_lds_window = a_windows.at(I1{});
313 auto& a_lds_gemm_window = a_windows.at(I2{});
314
315 // B DRAM tile window for load
316 // B LDS tile window for store
317 // B LDS tile for block GEMM
318 auto b_windows =
319 Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
320 auto& b_copy_dram_window = b_windows.at(I0{});
321 auto& b_copy_lds_window = b_windows.at(I1{});
322 auto& b_lds_gemm_window = b_windows.at(I2{});
323
324 // Block GEMM
325 auto block_gemm = BlockGemm();
326 auto c_block_tile = block_gemm.MakeCBlockTile();
327
328 using ABlockTileDistr =
329 decltype(a_copy_dram_window[number<0>{}].get_tile_distribution());
330 using BBlockTileDistr =
331 decltype(b_copy_dram_window[number<0>{}].get_tile_distribution());
332
333 using ABlockTile =
334 decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
335 using BBlockTile =
336 decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
337
340
341 using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
342 using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
343
344 constexpr ADramTileWindowStep a_dram_tile_window_step =
345 is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
346 constexpr BDramTileWindowStep b_dram_tile_window_step =
347 is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
348
349 // -----------------------------------------------------------------------------------------
350 // Gemm pipeline start
351
352 // prefetch
353 // global read 0
354 // Load tile — during value loading, an elementwise function is executed for each A0,
355 // A1, … AN. The values A0, A1, … AN are read by the same thread.
356 a_block_tiles.at(I0{}) = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
357
358 // Move each A — the enhanced function move_tile_window is executed, which takes a tuple
359 // as input.
360 move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
361
362 // Load tile — during value loading, an elementwise function is executed for each B0,
363 // B1, … BN. The values B0, B1, … BN are read by the same thread.
364 b_block_tiles.at(I0{}) = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
365
366 // Move each B — the enhanced function move_tile_window is executed, which takes a tuple
367 // as input.
368 move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
369
370 // initialize C
371 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
372
373 // LDS write 0
374 if constexpr(is_a_col_major && !is_a_load_tr_v())
375 {
377 Policy::template MakeShuffledARegTileDistribution<Problem>());
378 transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
379 Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
380 }
381 else
382 {
383 Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}));
384 }
385 if constexpr(is_b_row_major && !is_b_load_tr_v())
386 {
388 Policy::template MakeShuffledBRegTileDistribution<Problem>());
389 transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
390 Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
391 }
392 else
393 {
394 Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}));
395 }
396
397 // Global prefetch [1, PrefetchStages)
398 static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
399 a_block_tiles.at(number<prefetch_idx>{}) =
400 load_tile_with_elementwise(a_copy_dram_window, a_element_func);
401
402 move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
403
404 b_block_tiles.at(number<prefetch_idx>{}) =
405 load_tile_with_elementwise(b_copy_dram_window, b_element_func);
406
407 move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
408 });
409
410 // main body
// Each sub-step: GEMM on the tile currently in LDS, stage register tile
// (prefetch_idx + 1) % PrefetchStages into LDS, then reload register slot
// prefetch_idx from DRAM. The do/while body always runs at least once.
411 if constexpr(HasHotLoop)
412 {
413 index_t i = 0;
414 do
415 {
416 static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
418 block_gemm.LocalPrefetch(
419 a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
420 block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
421
423
424 if constexpr(is_a_col_major && !is_a_load_tr_v())
425 {
427 Policy::template MakeShuffledARegTileDistribution<Problem>());
429 a_shuffle_tmp,
430 a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
431 Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
432 }
433 else
434 {
436 a_copy_lds_window,
437 a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
438 }
439 if constexpr(is_b_row_major && !is_b_load_tr_v())
440 {
442 Policy::template MakeShuffledBRegTileDistribution<Problem>());
444 b_shuffle_tmp,
445 b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
446 Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
447 }
448 else
449 {
451 b_copy_lds_window,
452 b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
453 }
454
455 a_block_tiles.at(number<prefetch_idx>{}) =
456 load_tile_with_elementwise(a_copy_dram_window, a_element_func);
457 move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
458
459 b_block_tiles.at(number<prefetch_idx>{}) =
460 load_tile_with_elementwise(b_copy_dram_window, b_element_func);
461
462 move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
463 });
464
465 i += PrefetchStages;
466 } while(i < (num_loop - PrefetchStages));
467 }
468
// Drain: for prefetch_idx in [1, tail_num) run one GEMM on the LDS tile and then
// stage register tile prefetch_idx into LDS (no further DRAM loads); finish with
// one more GEMM on the last staged tile.
469 auto HotLoopTail = [&](auto tail_num) {
470 static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
472
473 block_gemm.LocalPrefetch(
474 a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
475 block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
476
478
479 if constexpr(is_a_col_major && !is_a_load_tr_v())
480 {
482 Policy::template MakeShuffledARegTileDistribution<Problem>());
483 transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number<prefetch_idx>{}));
484 Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
485 }
486 else
487 {
488 Base::LocalPrefill(a_copy_lds_window,
489 a_block_tiles.get(number<prefetch_idx>{}));
490 }
491 if constexpr(is_b_row_major && !is_b_load_tr_v())
492 {
494 Policy::template MakeShuffledBRegTileDistribution<Problem>());
495 transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number<prefetch_idx>{}));
496 Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
497 }
498 else
499 {
500 Base::LocalPrefill(b_copy_lds_window,
501 b_block_tiles.get(number<prefetch_idx>{}));
502 }
503 });
504
506 block_gemm.LocalPrefetch(
507 a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
508 block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
509 };
510
// Tail dispatch: One needs only a GEMM on the tile already in LDS; Two..Seven drain
// that many stages, and Full drains PrefetchStages stages.
511 if constexpr(TailNum == TailNumber::One)
512 {
514 block_gemm.LocalPrefetch(
515 a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
516 block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
517 }
518 else if constexpr(TailNum == TailNumber::Two)
519 {
520 HotLoopTail(number<2>{});
521 }
522 else if constexpr(TailNum == TailNumber::Three)
523 {
524 HotLoopTail(number<3>{});
525 }
526 else if constexpr(TailNum == TailNumber::Four)
527 {
528 HotLoopTail(number<4>{});
529 }
530 else if constexpr(TailNum == TailNumber::Five)
531 {
532 HotLoopTail(number<5>{});
533 }
534 else if constexpr(TailNum == TailNumber::Six)
535 {
536 HotLoopTail(number<6>{});
537 }
538 else if constexpr(TailNum == TailNumber::Seven)
539 {
540 HotLoopTail(number<7>{});
541 }
542 else if constexpr(TailNum == TailNumber::Full)
543 {
544 HotLoopTail(number<PrefetchStages>{});
545 }
546
547 return c_block_tile;
548 }
549 };
550
551 template <>
553 {
555
// Interwave schedule of the memory-throughput pipeline.
// It follows the same stage structure as the Intrawave specialization:
//   - prefill tile 0 to LDS;
//   - prefetch tiles 1..PrefetchStages-1 into registers;
//   - hot loop of GEMM / LDS refill / DRAM reload;
//   - TailNum-selected drain.
// Visible differences:
//   - block_gemm is invoked directly on the LDS GEMM windows, with extra arguments
//     on elided lines (733-734, 789-790, 823-824, 833-834).
//   - Per the source comment, the second LDS barrier is omitted because the
//     schedule is interwave.
// Returns the accumulated C block tile (register-resident distributed tensor).
// NOTE(review): elided lines (e.g. 650-651, 688, 699, 729, 739, 785, 819, 829)
// presumably declare a/b_block_tiles and the shuffle temporaries and contain the
// block_sync_lds() calls — confirm against the source.
556 template <bool HasHotLoop,
558 typename AsDramBlockWindowTmp,
559 typename BsDramBlockWindowTmp,
560 typename AElementFunction,
561 typename BElementFunction,
562 typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
564 bool>* = nullptr>
565 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
566 const AElementFunction& a_element_func,
567 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
568 const BElementFunction& b_element_func,
569 index_t num_loop,
570 void* p_smem) const
571 {
572 using ADramBlockWindowTmp =
573 remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
574 using BDramBlockWindowTmp =
575 remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
576
577 static_assert(
578 std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
579 std::is_same_v<BDataType,
581 "A/B Dram block window should have the same data type as appropriate "
582 "([A|B]DataType) defined in Problem definition!");
583
584 constexpr bool is_a_col_major =
585 std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
586 constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
587
588 static_assert(is_a_col_major
589 ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
590 MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
591 : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
592 KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
593 "A block window has incorrect lengths for defined ALayout!");
594 static_assert(is_b_row_major
595 ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
596 NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
597 : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
598 KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
599 "B block window has incorrect lengths for defined BLayout!");
600
601 // ------------------------------------------------------------------------------------
602 // Definitions of all needed tiles
603
604 // A/B tiles in LDS
605 // With c++20 could simplify to below line.
606 // Currently get error: captured structured bindings are a C++20 extension
607 // auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
608 auto ab_lds_blocks = Base::GetABLdsTensorViews(p_smem);
609 auto& a_lds_block = ab_lds_blocks.at(I0{});
610 auto& b_lds_block = ab_lds_blocks.at(I1{});
611
612 // Tile distribution for load from lds
613 constexpr auto a_lds_load_tile_distr =
614 make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
615 constexpr auto b_lds_load_tile_distr =
616 make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
617
618 // A DRAM tile window for load
619 // A LDS tile window for store
620 // A LDS tile for block GEMM
621 auto a_windows =
622 Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
623 auto& a_copy_dram_window = a_windows.at(I0{});
624 auto& a_copy_lds_window = a_windows.at(I1{});
625 auto& a_lds_gemm_window = a_windows.at(I2{});
626
627 // B DRAM tile window for load
628 // B LDS tile window for store
629 // B LDS tile for block GEMM
630 auto b_windows =
631 Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
632 auto& b_copy_dram_window = b_windows.at(I0{});
633 auto& b_copy_lds_window = b_windows.at(I1{});
634 auto& b_lds_gemm_window = b_windows.at(I2{});
635
636 // Block GEMM
637 auto block_gemm = BlockGemm();
638 auto c_block_tile = block_gemm.MakeCBlockTile();
639
640 using ABlockTileDistr =
641 decltype(a_copy_dram_window[number<0>{}].get_tile_distribution());
642 using BBlockTileDistr =
643 decltype(b_copy_dram_window[number<0>{}].get_tile_distribution());
644
645 using ABlockTile =
646 decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
647 using BBlockTile =
648 decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
649
652
653 using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
654 using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
655
656 constexpr ADramTileWindowStep a_dram_tile_window_step =
657 is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
658 constexpr BDramTileWindowStep b_dram_tile_window_step =
659 is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
660 // -----------------------------------------------------------------------------------------
661 // Gemm pipeline start
662
663 // prefetch
664 // global read 0
665
666 // Load tile — during value loading, an elementwise function is executed for each A0,
667 // A1, … AN. The values A0, A1, … AN are read by the same thread.
668 a_block_tiles.at(I0{}) = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
669
670 // Move each A — the enhanced function move_tile_window is executed, which takes a tuple
671 // as input.
672 move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
673
674 // Load tile — during value loading, an elementwise function is executed for each B0,
675 // B1, … BN. The values B0, B1, … BN are read by the same thread.
676 b_block_tiles.at(I0{}) = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
677
678 // Move each B — the enhanced function move_tile_window is executed, which takes a tuple
679 // as input.
680 move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
681
682 // initialize C
683 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
684
685 // LDS write 0
686 if constexpr(is_a_col_major && !is_a_load_tr_v())
687 {
689 Policy::template MakeShuffledARegTileDistribution<Problem>());
690 transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
691 Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
692 }
693 else
694 {
695 Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}));
696 }
697 if constexpr(is_b_row_major && !is_b_load_tr_v())
698 {
700 Policy::template MakeShuffledBRegTileDistribution<Problem>());
701 transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
702 Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
703 }
704 else
705 {
706 Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}));
707 }
708
709 // Global prefetch [1, PrefetchStages)
710 static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
711 a_block_tiles.at(number<prefetch_idx>{}) =
712 load_tile_with_elementwise(a_copy_dram_window, a_element_func);
713
714 move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
715
716 b_block_tiles.at(number<prefetch_idx>{}) =
717 load_tile_with_elementwise(b_copy_dram_window, b_element_func);
718
719 move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
720 });
721
722 // main body
// Each sub-step: GEMM on the tile currently in LDS, stage register tile
// (prefetch_idx + 1) % PrefetchStages into LDS, then reload register slot
// prefetch_idx from DRAM. The do/while body always runs at least once.
723 if constexpr(HasHotLoop)
724 {
725 index_t i = 0;
726 do
727 {
728 static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
730 block_gemm(c_block_tile,
731 a_lds_gemm_window,
732 b_lds_gemm_window,
735 // no second block_sync_lds because it's interwave
736
737 if constexpr(is_a_col_major && !is_a_load_tr_v())
738 {
740 Policy::template MakeShuffledARegTileDistribution<Problem>());
742 a_shuffle_tmp,
743 a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
744 Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
745 }
746 else
747 {
749 a_copy_lds_window,
750 a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
751 }
752 if constexpr(is_b_row_major && !is_b_load_tr_v())
753 {
755 Policy::template MakeShuffledBRegTileDistribution<Problem>());
757 b_shuffle_tmp,
758 b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
759 Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
760 }
761 else
762 {
764 b_copy_lds_window,
765 b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
766 }
767
768 a_block_tiles.at(number<prefetch_idx>{}) =
769 load_tile_with_elementwise(a_copy_dram_window, a_element_func);
770
771 move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
772
773 b_block_tiles.at(number<prefetch_idx>{}) =
774 load_tile_with_elementwise(b_copy_dram_window, b_element_func);
775
776 move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
777 });
778
779 i += PrefetchStages;
780 } while(i < (num_loop - PrefetchStages));
781 }
782
// Drain: for prefetch_idx in [1, tail_num) run one GEMM on the LDS tile and then
// stage register tile prefetch_idx into LDS (no further DRAM loads); finish with
// one more GEMM on the last staged tile.
783 auto HotLoopTail = [&](auto tail_num) {
784 static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
786 block_gemm(c_block_tile,
787 a_lds_gemm_window,
788 b_lds_gemm_window,
791 // no second block_sync_lds because it's interwave
792
793 if constexpr(is_a_col_major && !is_a_load_tr_v())
794 {
796 Policy::template MakeShuffledARegTileDistribution<Problem>());
797 transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number<prefetch_idx>{}));
798 Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
799 }
800 else
801 {
802 Base::LocalPrefill(a_copy_lds_window,
803 a_block_tiles.get(number<prefetch_idx>{}));
804 }
805 if constexpr(is_b_row_major && !is_b_load_tr_v())
806 {
808 Policy::template MakeShuffledBRegTileDistribution<Problem>());
809 transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number<prefetch_idx>{}));
810 Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
811 }
812 else
813 {
814 Base::LocalPrefill(b_copy_lds_window,
815 b_block_tiles.get(number<prefetch_idx>{}));
816 }
817 });
818
820 block_gemm(c_block_tile,
821 a_lds_gemm_window,
822 b_lds_gemm_window,
825 };
826
// Tail dispatch: One needs only a GEMM on the tile already in LDS; Two..Seven drain
// that many stages, and Full drains PrefetchStages stages.
827 if constexpr(TailNum == TailNumber::One)
828 {
830 block_gemm(c_block_tile,
831 a_lds_gemm_window,
832 b_lds_gemm_window,
835 }
836 else if constexpr(TailNum == TailNumber::Two)
837 {
838 HotLoopTail(number<2>{});
839 }
840 else if constexpr(TailNum == TailNumber::Three)
841 {
842 HotLoopTail(number<3>{});
843 }
844 else if constexpr(TailNum == TailNumber::Four)
845 {
846 HotLoopTail(number<4>{});
847 }
848 else if constexpr(TailNum == TailNumber::Five)
849 {
850 HotLoopTail(number<5>{});
851 }
852 else if constexpr(TailNum == TailNumber::Six)
853 {
854 HotLoopTail(number<6>{});
855 }
856 else if constexpr(TailNum == TailNumber::Seven)
857 {
858 HotLoopTail(number<7>{});
859 }
860 else if constexpr(TailNum == TailNumber::Full)
861 {
862 HotLoopTail(number<PrefetchStages>{});
863 }
864
865 return c_block_tile;
866 }
867 };
868
// Tuple-of-windows entry point with caller-supplied element functions.
// Runs the Scheduler-selected PipelineImpl with the compile-time HasHotLoop and
// TailNum taken from the Problem, and returns the C block tile it produces.
// (Line 874, the second enable_if conjunct, is elided in this listing.)
869 template <typename AsDramBlockWindowTmp,
870 typename BsDramBlockWindowTmp,
871 typename AElementFunction,
872 typename BElementFunction,
873 typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
875 bool>* = nullptr>
876 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
877 const AElementFunction& a_element_func,
878 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
879 const BElementFunction& b_element_func,
880 index_t num_loop,
881 void* p_smem) const
882 {
883 return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
884 a_dram_block_window_tmp,
885 a_element_func,
886 b_dram_block_window_tmp,
887 b_element_func,
888 num_loop,
889 p_smem);
890 }
891
// Tuple-of-windows entry point whose hot-loop/tail shape is chosen at run time.
// Base::TailHandler converts (has_hot_loop, tail_number) into compile-time constants
// and invokes RunPipeline with them. RunPipeline defines a copying PassThrough
// element function; the arguments on elided lines 910/912 are presumably
// PassThrough for A and B — TODO confirm.
892 template <typename AsDramBlockWindowTmp,
893 typename BsDramBlockWindowTmp,
894 typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
896 bool>* = nullptr>
897 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
898 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
899 index_t num_loop,
900 bool has_hot_loop,
901 TailNumber tail_number,
902 void* p_smem) const
903 {
904 const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
905 constexpr bool hot_loop = hot_loop_.value;
906 constexpr auto tail_num = tail_num_.value;
907 constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; };
908 return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
909 a_dram_block_window_tmp,
911 b_dram_block_window_tmp,
913 num_loop,
914 p_smem);
915 };
916 return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
917 }
918
// Tuple-of-windows entry point using the Problem's compile-time HasHotLoop/TailNum
// and identity (copy) element functions for A and B.
// Fix: the B element function used to take `const ADataType&`, so every B element
// was implicitly converted to ADataType before being stored into the B tile. That
// is lossy (or fails to compile) whenever ADataType != BDataType. It now takes
// `const BDataType&`, matching the B window DataType static_assert in PipelineImpl
// and the generic PassThrough used by the run-time-dispatch overload.
// (Line 922, the second enable_if conjunct, is elided in this listing.)
919 template <typename AsDramBlockWindowTmp,
920 typename BsDramBlockWindowTmp,
921 typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
923 bool>* = nullptr>
924 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
925 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
926 index_t num_loop,
927 void* p_smem) const
928 {
929 return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
930 a_dram_block_window_tmp,
931 [](auto& e, const ADataType& a) { e = a; },
932 b_dram_block_window_tmp,
933 [](auto& e, const BDataType& b) { e = b; },
934 num_loop,
935 p_smem);
936 }
937
// Single-window convenience overload: wraps the A and B windows in one-element
// tuples and forwards to the tuple overload that takes element functions.
// (Line 943, the second enable_if conjunct, is elided in this listing.)
938 template <typename ADramBlockWindowTmp,
939 typename BDramBlockWindowTmp,
940 typename AElementFunction,
941 typename BElementFunction,
942 typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
944 bool>* = nullptr>
945 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
946 const AElementFunction& a_element_func,
947 const BDramBlockWindowTmp& b_dram_block_window_tmp,
948 const BElementFunction& b_element_func,
949 index_t num_loop,
950 void* p_smem) const
951 {
952 return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
953 a_element_func,
954 ck_tile::make_tuple(b_dram_block_window_tmp),
955 b_element_func,
956 num_loop,
957 p_smem);
958 }
959
// Single-window convenience overload with run-time hot-loop/tail selection: wraps
// A and B in one-element tuples and forwards to the tuple overload that takes
// has_hot_loop/tail_number. (Line 963 is elided in this listing.)
960 template <typename ADramBlockWindowTmp,
961 typename BDramBlockWindowTmp,
962 typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
964 bool>* = nullptr>
965 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
966 const BDramBlockWindowTmp& b_dram_block_window_tmp,
967 index_t num_loop,
968 bool has_hot_loop,
969 TailNumber tail_number,
970 void* p_smem) const
971 {
972 return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
973 ck_tile::make_tuple(b_dram_block_window_tmp),
974 num_loop,
975 has_hot_loop,
976 tail_number,
977 p_smem);
978 }
979
// Single-window convenience overload without element functions: wraps A and B in
// one-element tuples and forwards to the tuple overload that supplies copy
// element functions. (Line 983 is elided in this listing.)
980 template <typename ADramBlockWindowTmp,
981 typename BDramBlockWindowTmp,
982 typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
984 bool>* = nullptr>
985 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
986 const BDramBlockWindowTmp& b_dram_block_window_tmp,
987 index_t num_loop,
988 void* p_smem) const
989 {
990 return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
991 ck_tile::make_tuple(b_dram_block_window_tmp),
992 num_loop,
993 p_smem);
994 }
995};
996
997} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
#define CHECK_TAIL_NUMBER(TAIL_NUMBER, PREFETCH_VALUE)
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
TailNumber
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:21
@ One
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:27
@ Seven
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:33
@ Four
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:30
@ Two
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:28
@ Full
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:39
@ Three
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:29
@ Five
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:31
@ Six
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:32
CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_ &tile_window, ElementWise_ elementwise, number< i_access >={}, bool_constant< oob_conditional_check >={})
Load tile with elementwise function.
Definition load_tile.hpp:41
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition type_traits.hpp:67
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
ck_tile::element_wise::PassThrough PassThrough
Definition grouped_convolution_utils.hpp:47
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
CK_TILE_DEVICE void transpose_tile2d(OutTensor &out, const InTensor &in)
Definition transpose_tile.hpp:195
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
typename impl::tuple_array_impl< T, N >::type tuple_array
Definition tile/core/container/tuple.hpp:28
GemmPipelineScheduler
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:14
@ Intrawave
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:16
@ Interwave
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:17
CK_TILE_HOST_DEVICE constexpr details::return_type< D, Ts... > make_array(Ts &&... ts)
Definition tile/core/container/array.hpp:242
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition gemm_pipeline_ag_bg_cr_mem.hpp:19
static constexpr index_t WgpPerCU
Definition gemm_pipeline_ag_bg_cr_mem.hpp:39
static constexpr index_t MPerBlock
Definition gemm_pipeline_ag_bg_cr_mem.hpp:32
static constexpr bool UsePersistentKernel
Definition gemm_pipeline_ag_bg_cr_mem.hpp:53
remove_cvref_t< typename Problem::BDataType > BDataType
Definition gemm_pipeline_ag_bg_cr_mem.hpp:21
static constexpr index_t BPackedSize
Definition gemm_pipeline_ag_bg_cr_mem.hpp:26
static CK_TILE_HOST_DEVICE constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
Definition gemm_pipeline_ag_bg_cr_mem.hpp:60
static constexpr index_t BlockSize
Definition gemm_pipeline_ag_bg_cr_mem.hpp:31
static constexpr index_t PrefetchStages
Definition gemm_pipeline_ag_bg_cr_mem.hpp:46
static constexpr index_t GlobalBufferNum
Definition gemm_pipeline_ag_bg_cr_mem.hpp:52
static CK_TILE_HOST_DEVICE constexpr auto TransposeC()
Definition gemm_pipeline_ag_bg_cr_mem.hpp:29
static CK_TILE_HOST_DEVICE auto TailHandler(const RunFunction &run_func, bool has_hot_loop, TailNumber tail_number)
Definition gemm_pipeline_ag_bg_cr_mem.hpp:98
static constexpr index_t FullMemBandPrefetchStages
Definition gemm_pipeline_ag_bg_cr_mem.hpp:41
static constexpr index_t LocalPrefillStages
Definition gemm_pipeline_ag_bg_cr_mem.hpp:51
remove_cvref_t< typename Problem::ADataType > ADataType
Definition gemm_pipeline_ag_bg_cr_mem.hpp:20
static constexpr index_t MinMemInFlyBytes
Definition gemm_pipeline_ag_bg_cr_mem.hpp:37
static constexpr index_t KPerBlock
Definition gemm_pipeline_ag_bg_cr_mem.hpp:34
static constexpr index_t APackedSize
Definition gemm_pipeline_ag_bg_cr_mem.hpp:24
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition gemm_pipeline_ag_bg_cr_mem.hpp:22
static constexpr index_t NPerBlock
Definition gemm_pipeline_ag_bg_cr_mem.hpp:33
static CK_TILE_HOST_DEVICE constexpr bool BlockHasHotloop(index_t num_loop)
Definition gemm_pipeline_ag_bg_cr_mem.hpp:55
Definition gemm_pipeline_ag_bg_cr_base.hpp:13
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp &b_dram_block_window_tmp, const BLdsTensorView &b_lds_block_view, const BLdsLoadTileDistr &, const array< index_t, 2 > &offset={0, 0}) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:225
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataType > > BDataType
Definition gemm_pipeline_ag_bg_cr_base.hpp:22
CK_TILE_DEVICE auto GetABLdsTensorViews(void *p_smem) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:83
static constexpr index_t NPerBlock
Definition gemm_pipeline_ag_bg_cr_base.hpp:26
static constexpr index_t MPerBlock
Definition gemm_pipeline_ag_bg_cr_base.hpp:25
CK_TILE_DEVICE void LocalPrefill(DstTileWindow &lds_tile_window, const SrcBlockTile &src_block_tile, const ElementFunction &element_func) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:57
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp &a_dram_block_window_tmp, const ALdsTensorView &a_lds_block_view, const ALdsLoadTileDistr &, const array< index_t, 2 > &offset={0, 0}) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:190
static constexpr index_t KPerBlock
Definition gemm_pipeline_ag_bg_cr_base.hpp:27
PipelineImplBase Base
Definition gemm_pipeline_ag_bg_cr_mem.hpp:554
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BsDramBlockWindowTmp &b_dram_block_window_tmp, const BElementFunction &b_element_func, index_t num_loop, void *p_smem) const
Definition gemm_pipeline_ag_bg_cr_mem.hpp:565
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BsDramBlockWindowTmp &b_dram_block_window_tmp, const BElementFunction &b_element_func, index_t num_loop, void *p_smem) const
Definition gemm_pipeline_ag_bg_cr_mem.hpp:253
PipelineImplBase Base
Definition gemm_pipeline_ag_bg_cr_mem.hpp:242
Definition gemm_pipeline_ag_bg_cr_mem.hpp:236
Definition gemm_pipeline_ag_bg_cr_mem.hpp:154
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition gemm_pipeline_ag_bg_cr_mem.hpp:164
static constexpr index_t GetVectorSizeC()
Definition gemm_pipeline_ag_bg_cr_mem.hpp:196
static constexpr index_t GetVectorSizeB()
Definition gemm_pipeline_ag_bg_cr_mem.hpp:192
static constexpr index_t MPerBlock
Definition gemm_pipeline_ag_bg_cr_mem.hpp:182
remove_cvref_t< typename Problem::BsDataTypeTuple > BsDataType
Definition gemm_pipeline_ag_bg_cr_mem.hpp:159
static constexpr index_t Preshuffle
Definition gemm_pipeline_ag_bg_cr_mem.hpp:207
remove_cvref_t< std::tuple_element_t< 0, BsDataType > > BDataType
Definition gemm_pipeline_ag_bg_cr_mem.hpp:174
static constexpr bool DoubleSmemBuffer
Definition gemm_pipeline_ag_bg_cr_mem.hpp:205
BaseGemmPipelineAgBgCrMem< Problem > Base
Definition gemm_pipeline_ag_bg_cr_mem.hpp:155
static constexpr index_t GetSmemPackB()
Definition gemm_pipeline_ag_bg_cr_mem.hpp:199
remove_cvref_t< typename Problem::BElementWise > BElementWise
Definition gemm_pipeline_ag_bg_cr_mem.hpp:163
static constexpr auto TailNum
Definition gemm_pipeline_ag_bg_cr_mem.hpp:211
GemmPipelineAgBgCrImplBase< Problem, Policy > PipelineImplBase
Definition gemm_pipeline_ag_bg_cr_mem.hpp:156
static constexpr index_t GetSmemPackA()
Definition gemm_pipeline_ag_bg_cr_mem.hpp:198
static constexpr index_t NPerBlock
Definition gemm_pipeline_ag_bg_cr_mem.hpp:183
static constexpr index_t KPerBlock
Definition gemm_pipeline_ag_bg_cr_mem.hpp:184
static constexpr index_t GetVectorSizeA()
Definition gemm_pipeline_ag_bg_cr_mem.hpp:187
static constexpr index_t NumWaveGroups
Definition gemm_pipeline_ag_bg_cr_mem.hpp:206
static constexpr auto is_b_load_tr_v
Definition gemm_pipeline_ag_bg_cr_mem.hpp:215
remove_cvref_t< typename Problem::AsLayoutTuple > AsLayout
Definition gemm_pipeline_ag_bg_cr_mem.hpp:166
number< 2 > I2
Definition gemm_pipeline_ag_bg_cr_mem.hpp:180
remove_cvref_t< std::tuple_element_t< 0, BsLayout > > BLayout
Definition gemm_pipeline_ag_bg_cr_mem.hpp:171
remove_cvref_t< typename Problem::CLayout > CLayout
Definition gemm_pipeline_ag_bg_cr_mem.hpp:168
static constexpr index_t PrefetchStages
Definition gemm_pipeline_ag_bg_cr_mem.hpp:46
remove_cvref_t< std::tuple_element_t< 0, AsLayout > > ALayout
Definition gemm_pipeline_ag_bg_cr_mem.hpp:170
static constexpr auto Scheduler
Definition gemm_pipeline_ag_bg_cr_mem.hpp:212
remove_cvref_t< decltype(Policy::template GetBlockGemm< Problem >())> BlockGemm
Definition gemm_pipeline_ag_bg_cr_mem.hpp:176
remove_cvref_t< std::tuple_element_t< 0, AsDataType > > ADataType
Definition gemm_pipeline_ag_bg_cr_mem.hpp:173
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BsDramBlockWindowTmp &b_dram_block_window_tmp, const BElementFunction &b_element_func, index_t num_loop, void *p_smem) const
Definition gemm_pipeline_ag_bg_cr_mem.hpp:876
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp &a_dram_block_window_tmp, const BsDramBlockWindowTmp &b_dram_block_window_tmp, index_t num_loop, bool has_hot_loop, TailNumber tail_number, void *p_smem) const
Definition gemm_pipeline_ag_bg_cr_mem.hpp:897
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp &a_dram_block_window_tmp, const BsDramBlockWindowTmp &b_dram_block_window_tmp, index_t num_loop, void *p_smem) const
Definition gemm_pipeline_ag_bg_cr_mem.hpp:924
number< 0 > I0
Definition gemm_pipeline_ag_bg_cr_mem.hpp:178
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BDramBlockWindowTmp &b_dram_block_window_tmp, const BElementFunction &b_element_func, index_t num_loop, void *p_smem) const
Definition gemm_pipeline_ag_bg_cr_mem.hpp:945
static constexpr auto is_a_load_tr_v
Definition gemm_pipeline_ag_bg_cr_mem.hpp:214
number< 1 > I1
Definition gemm_pipeline_ag_bg_cr_mem.hpp:179
remove_cvref_t< typename Problem::AElementWise > AElementWise
Definition gemm_pipeline_ag_bg_cr_mem.hpp:162
static constexpr bool kPadN
Definition gemm_pipeline_ag_bg_cr_mem.hpp:202
remove_cvref_t< typename Problem::BsLayoutTuple > BsLayout
Definition gemm_pipeline_ag_bg_cr_mem.hpp:167
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition gemm_pipeline_ag_bg_cr_mem.hpp:229
static constexpr bool HasHotLoop
Definition gemm_pipeline_ag_bg_cr_mem.hpp:210
static constexpr bool kPadM
Definition gemm_pipeline_ag_bg_cr_mem.hpp:201
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BDramBlockWindowTmp &b_dram_block_window_tmp, index_t num_loop, void *p_smem) const
Definition gemm_pipeline_ag_bg_cr_mem.hpp:985
static constexpr bool kPadK
Definition gemm_pipeline_ag_bg_cr_mem.hpp:203
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BDramBlockWindowTmp &b_dram_block_window_tmp, index_t num_loop, bool has_hot_loop, TailNumber tail_number, void *p_smem) const
Definition gemm_pipeline_ag_bg_cr_mem.hpp:965
remove_cvref_t< typename Problem::CDataType > CDataType
Definition gemm_pipeline_ag_bg_cr_mem.hpp:160
static CK_TILE_HOST const std::string GetName()
Definition gemm_pipeline_ag_bg_cr_mem.hpp:217
remove_cvref_t< typename Problem::AsDataTypeTuple > AsDataType
Definition gemm_pipeline_ag_bg_cr_mem.hpp:158
Definition tile/core/numeric/integral_constant.hpp:30
Definition tile/core/numeric/numeric.hpp:81
Definition tile/core/utility/functional.hpp:43