cshuffle_epilogue.hpp Source File

cshuffle_epilogue.hpp Source File

Composable Kernel: cshuffle_epilogue.hpp Source File
cshuffle_epilogue.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7#include "ck_tile/core.hpp"
12
13#include <type_traits>
14
15namespace ck_tile {
16
// Compile-time description of one epilogue instantiation: element types, layouts,
// the kM_ x kN_ block tile, the MWave_ x NWave_ wave grid and the per-XDL tile shape.
// FixedVectorSize_ / VectorSizeC_ let the caller force the C vector width.
// NOTE(review): the struct's name line and its type aliases are not visible here; the
// epilogue below reads members with these names from its Problem_ — confirm pairing.
17template <typename AsDataType_,
18 typename BsDataType_,
19 typename DsDataType_,
20 typename AccDataType_,
21 typename ODataType_,
22 typename DsLayout_,
23 typename ELayout_,
24 typename CDElementwise_,
25 index_t kM_,
26 index_t kN_,
27 index_t MWave_,
28 index_t NWave_,
29 index_t MPerXdl_,
30 index_t NPerXdl_,
31 index_t KPerXdl_,
32 bool isCTransposed_,
33 memory_operation_enum MemoryOperation_,
34 index_t kNumWaveGroups_ = 1,
35 bool FixedVectorSize_ = false,
36 index_t VectorSizeC_ = 1,
37 bool TiledMMAPermuteN_ = false,
38 index_t BlockedXDLN_PerWarp_ = 1> // Number of contiguous XDL outputs per warp (1 = non-blocked layout)
40{
// Threads per block: one wavefront for every slot of the MWave_ x NWave_ wave grid.
49 static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
50 static constexpr index_t kMPerBlock = kM_;
51 static constexpr index_t kNPerBlock = kN_;
52 static constexpr index_t MWave = MWave_;
53 static constexpr index_t NWave = NWave_;
54 static constexpr index_t MPerXdl = MPerXdl_;
55 static constexpr index_t NPerXdl = NPerXdl_;
56 static constexpr index_t KPerXdl = KPerXdl_;
// NOTE(review): stored as index_t although the template parameter is bool; the other
// flags (FixedVectorSize, TiledMMAPermuteN) keep bool.
57 static constexpr index_t isCTransposed = isCTransposed_;
58 static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
59 static constexpr bool FixedVectorSize = FixedVectorSize_;
60 static constexpr index_t VectorSizeC = VectorSizeC_;
61 static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
62 static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
63 static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
// Number of auxiliary D inputs; DsLayout must list exactly one layout per D tensor.
64 static constexpr index_t NumDTensor = DsDataType::size();
65
66 static_assert(NumDTensor == DsLayout::size(),
67 "The size of DsDataType and DsLayout should be the same");
68};
69
70template <typename Problem_, typename Policy_ = void>
72{
80
83
84 using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
87
88 using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
91
94
// A-side type used to pick the warp GEMM: a packed-int4 A is replaced by B's type
// (the mirror image of BTypeToUse below).
95 using ATypeToUse =
96 std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
97 // Weight-only quantization: a packed-int4 B is dequantized to A's data type, so the warp GEMM sees ADataType
98 using BTypeToUse =
99 std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
// Configuration mirrored from Problem.
102 static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
103 static constexpr index_t kBlockSize = Problem::kBlockSize;
104 static constexpr index_t kMPerBlock = Problem::kMPerBlock;
105 static constexpr index_t kNPerBlock = Problem::kNPerBlock;
106 static constexpr index_t MWave = Problem::MWave;
107 static constexpr index_t NWave = Problem::NWave;
108 static constexpr index_t MPerXdl = Problem::MPerXdl;
109 static constexpr index_t NPerXdl = Problem::NPerXdl;
110 static constexpr index_t KPerXdl = Problem::KPerXdl;
// NOTE(review): index_t, but it is used as a boolean (see the ternary in GetName).
111 static constexpr index_t isCTransposed = Problem::isCTransposed;
112 static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
113 static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
114 static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
115 static constexpr index_t VectorSizeC = Problem::VectorSizeC;
// Rows / columns covered by one XDL step across all waves.
116 static constexpr index_t MPerIteration = MPerXdl * MWave;
117 static constexpr index_t NPerIteration = NPerXdl * NWave;
118 static constexpr index_t NumDTensor = Problem::NumDTensor;
// XDL steps per block tile in M / N. Integer division: assumes kMPerBlock and kNPerBlock
// are multiples of MPerIteration / NPerIteration (TODO confirm callers guarantee this).
119 static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
120 static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
121
123
125
126 static_assert(NumDTensor == DsLayout::size(),
127 "The size of DsDataType and DsLayout should be the same");
128
// Host-side identifier for this epilogue configuration: '_'-joined fields starting with
// "CShuffleEpilogue", the wave grid as "<MWave>x<NWave>", and the C-transposition tag.
// Some of the joined fields are on lines not visible here.
// NOTE(review): returning `const std::string` by value prevents callers from moving the result.
129 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
130 {
131 // clang-format off
132 return concat('_', "CShuffleEpilogue",
133 concat('x', MWave, NWave),
136 isCTransposed ? "CTransposed" : "CNotTransposed",
138 // clang-format on
139 }
140
152 {
153 if constexpr(FixedVectorSize)
154 {
155 return VectorSizeC;
156 }
157 constexpr index_t max_vector_size = 16;
158 if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
159 {
160 return std::min(static_cast<int>(NPerIteration),
161 static_cast<int>(max_vector_size / sizeof(ODataType)));
162 }
163 else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
164 {
165 return std::min(static_cast<int>(MPerIteration),
166 static_cast<int>(max_vector_size / sizeof(ODataType)));
167 }
168 else
169 {
170 static_assert(false, "Unsupported ELayout!");
171 }
172 }
173
179 template <index_t I>
181 {
182 constexpr index_t max_vector_size = 16;
183 using DiDataType = remove_cvref_t<std::tuple_element_t<index.value, DsDataType>>;
184 using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
185 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
186 {
187 return std::min(static_cast<int>(NPerIteration),
188 static_cast<int>(max_vector_size / sizeof(DiDataType)));
189 }
190 else if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::ColumnMajor>)
191 {
192 return std::min(static_cast<int>(MPerIteration),
193 static_cast<int>(max_vector_size / sizeof(DiDataType)));
194 }
195 else
196 {
197 static_assert(false, "Unsupported DLayout!");
198 }
199 return max_vector_size / sizeof(DiDataType);
200 }
201
209 static constexpr auto shuffle_tile_tuple = [] {
210 constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
211 if constexpr(elem_per_thread >= GetVectorSizeC())
212 {
213 return std::make_tuple(1, 1);
214 }
215 else
216 {
217 constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
218 if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
219 {
220 static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
221 (kMPerBlock % num_xdl_shuffles == 0),
222 "kMPerBlock must be divisible by MPerXdl*MWave and "
223 "num_xdl_shuffles for CShuffleEpilogue");
224 return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
225 }
226 else
227 {
228 static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
229 (kNPerBlock % num_xdl_shuffles == 0),
230 "kNPerBlock must be divisible by NPerXdl*NWave and "
231 "num_xdl_shuffles for CShuffleEpilogue");
232 return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
233 }
234 }
235 }();
236 static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
239
240 static constexpr auto MNPerIterationShuffle = [] {
241 constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
242 constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle;
243 if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
244 return std::make_tuple(MPerXdl * MWave, NPerXdl * NWave);
245 else
246 return std::make_tuple(m_val, n_val);
247 }();
248 static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
249 static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);
250
254 MPerXdl,
255 NPerXdl,
256 KPerXdl,
258
259 using CWarpDstr = typename WG::CWarpDstr;
260 using CWarpTensor = typename WG::CWarpTensor;
261 using CWarpDstrEncoding = typename WG::CWarpDstrEncoding;
265
// Tensor descriptor for the LDS staging buffer of the LDS-shuffle operator(), which views
// p_smem as ODataType through it. The contiguous dimension follows ELayout (N for
// RowMajor, M for ColumnMajor); any other layout is rejected at compile time.
// The descriptor-building lines are not visible here.
// NOTE(review): the member-template parameter `Problem` reuses the name of the class alias
// Problem; the only visible caller passes that alias, so both name the same type.
266 template <typename Problem>
268 {
269 // N is contiguous dimension
270 if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
271 {
275 }
276 // M is contiguous dimension
277 else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
278 {
282 }
283 else
284 {
285 static_assert(false, "Unsupported ELayout!");
286 }
287 }
288
290 {
// Distribution encoding of the per-access LDS tile. An outer block-level encoding (one
// variant for BlockedXDLN_PerWarp == 1, a blocked variant otherwise) is embedded around
// the warp-level C encoding CWarpDstr::DstrEncode.
// NOTE(review): RakedXDLN_PerWarp uses integer division; nothing visible here checks that
// NumNXdlPerWavePerShuffle is a multiple of BlockedXDLN_PerWarp.
291 constexpr auto block_outer_dstr_encoding = [] {
292 if constexpr(BlockedXDLN_PerWarp == 1)
293 {
301 }
302 else
303 {
304 constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
305 // BlockedLayout
314 }
315 }();
316 constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
317 block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});
318
319 return block_dstr_encoding;
320 }
321
323 {
// LDS footprint of this epilogue. The return expression is not visible here; it is
// presumably derived from MakeLdsBlockDescriptor (TODO confirm).
325 }
326
// Multiply the current accumulator-precision tile by the M and N scales, in place.
// The case is chosen at compile time from the scale types:
//   * both EmptyScale  -> no-op;
//   * both AccDataType -> scalar scales applied to every element;
//   * otherwise        -> treated as tile windows: load, multiply element-wise, then
//                         advance both windows along SFC, except after the last access.
// NOTE(review): mixed combinations (one EmptyScale or scalar, the other a window) fall into
// the window branch and will not compile; confirm callers never pass them.
327 template <index_t iAccess, typename LdsTile, typename ScaleM, typename ScaleN>
328 CK_TILE_DEVICE void
329 scale_tile(LdsTile& lds_tile, ScaleM& scale_m_window, ScaleN& scale_n_window)
330 {
331 // Check if scales are EmptyScale first (no scaling needed)
332 if constexpr(std::is_same_v<ScaleM, EmptyScale> && std::is_same_v<ScaleN, EmptyScale>)
333 {
334 // No scaling needed - this is a no-op
335 }
336 // Check if scales are scalar AccDataType
337 else if constexpr(std::is_same_v<ScaleM, AccDataType> &&
338 std::is_same_v<ScaleN, AccDataType>)
339 {
340 // Handle scalar scales
341 const AccDataType scale_m = scale_m_window;
342 const AccDataType scale_n = scale_n_window;
343 tile_elementwise_inout([&](auto& element) { element = element * scale_m * scale_n; },
344 lds_tile);
345 }
346 // Otherwise, assume they are tile windows that can be loaded
347 else
348 {
349 // Load tiles
350 const auto scale_m_tile = load_tile(scale_m_window);
351 const auto scale_n_tile = load_tile(scale_n_window);
352
353 // Compute element-wise product in-place i.e. lds_tile = lds_tile * scale_m * scale_n
355 element_wise::MultiDMultiply{}, lds_tile, lds_tile, scale_m_tile, scale_n_tile);
356
357 // Move scale windows
// Same SFC step as the output window, so the scales stay aligned with the next access.
358 constexpr index_t num_access = SFC::get_num_of_access();
359 if constexpr(iAccess != num_access - 1)
360 {
361 constexpr auto step = SFC::get_forward_step(number<iAccess>{});
362
363 move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})});
364 move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})});
365 }
366 }
367 }
368
// Copy the accumulator sub-tile for SFC access iAccess into lds_tile's thread buffer.
// idx_y_start is the access's position in the block tile. Dividing it by the shuffle-step
// extents gives the shuffle-step coordinates (mIter, nIter) used to slice o_acc_tile.
// The exact start/length sequences are partly on lines not visible here.
369 template <index_t iAccess, typename OAccTile, typename LdsTile>
370 CK_TILE_DEVICE void slice_acc_tile(const OAccTile& o_acc_tile, LdsTile& lds_tile)
371 {
372 constexpr auto idx_y_start = SFC::get_index(number<iAccess>{});
373
374 constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
375 constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
// Per-warp Y extents and an all-zero Y start, used to take whole warp tiles.
376 constexpr auto c_warp_y_lengths =
377 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
378 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
379
380 lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
383 c_warp_y_index_zeros),
385 c_warp_y_lengths));
386 }
387
388 template <typename LdsTile, typename InLdsWindow>
389 CK_TILE_DEVICE void cast_lds_tile(LdsTile& lds_tile, InLdsWindow& in_lds_window)
390 {
391 const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
392
393 store_tile(in_lds_window, c_warptile_in_tensor_casted);
394 }
395
// Load the current tile of every D tensor and combine them with c_out_tensor in place.
// tie(c_out_tensor, c_out_tensor) lists c_out_tensor as both the destination and the first
// source; references to each loaded D tile follow.
// NOTE(review): the call that applies the element-wise functor is not visible here;
// presumably tile_elementwise_inout_unpack with elfunc_ (TODO confirm).
396 template <typename DramWindows, typename COutTensor>
397 CK_TILE_DEVICE void apply_d_tensors(DramWindows& d_dram_windows, COutTensor& c_out_tensor)
398 {
399 const auto ds_tensor = generate_tuple(
400 [&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
401
402 const auto c_ds_tiles = concat_tuple_of_reference(
403 tie(c_out_tensor, c_out_tensor),
404 generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
406
408 }
409
// Write c_out_tensor to the output window: store_tile on one branch, update_tile on the
// other. The branch condition is not visible here; presumably
// MemoryOperation == memory_operation_enum::set selects store_tile (TODO confirm).
410 template <typename OutDramWindow, typename COutTensor>
411 CK_TILE_DEVICE void store_to_dram(OutDramWindow& out_dram_window,
412 const COutTensor& c_out_tensor)
413 {
415 {
416 store_tile(out_dram_window, c_out_tensor);
417 }
418 else
419 {
420 update_tile(out_dram_window, c_out_tensor);
421 }
422 }
423
427 template <index_t iAccess, typename OutDramWindow, typename DDramWindows>
428 CK_TILE_DEVICE void move_windows(OutDramWindow& out_dram_window, DDramWindows& d_dram_windows)
429 {
430 constexpr index_t num_access = SFC::get_num_of_access();
431 if constexpr(iAccess != num_access - 1)
432 {
433 constexpr auto step = SFC::get_forward_step(number<iAccess>{});
434
435 // move the output dram window
436 move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
437
438 // move windows for each of the D matrices (inputs for element-wise)
439 static_for<0, NumDTensor, 1>{}([&](auto idx) {
440 move_tile_window(d_dram_windows[idx], {step.at(number<0>{}), step.at(number<1>{})});
441 });
442 }
443 }
444
// Tag type meaning "no scale supplied". It is the default ScaleM/ScaleN of both
// operator() overloads, and scale_tile treats a pair of EmptyScale as a no-op.
445 // TODO: Check if there would be nicer ways to overload rather than with EmptyScale or nullptr_t
447 {
448 };
449
// Element type of a scale argument. Falls back to float when the scale type has no nested
// DataType, e.g. EmptyScale; the permute-N operator() still sizes its scale tiles with it.
450 template <typename, typename = void>
452 {
453 using DataType = float;
454 };
455
456 template <typename T>
457 struct ScaleDataType<T, std::void_t<typename T::DataType>>
458 {
459 using DataType = typename T::DataType;
460 };
461
// Epilogue for the TiledMMAPermuteN path, selected by enable_if when EnablePermuateN_ is
// non-zero. No LDS is used (p_smem is ignored). For each of the MRepeat row chunks, every
// thread:
//   * repacks its own accumulators into the permuted-N register layout of
//     IntrThreadShuffleEncode;
//   * optionally applies the scales;
//   * converts to ODataType;
//   * writes an MPerXdl * MWave-row chunk to out_dram_window.
// NOTE(review): the D tensors are windowed and advanced but never loaded or combined with the
// output, and elfunc_ is not applied on this path. Confirm this is intentional.
462 template <typename ODramWindow,
463 typename OAccTile,
464 typename DsDramWindows,
465 typename ScaleM = EmptyScale,
466 typename ScaleN = EmptyScale,
467 int EnablePermuateN_ = TiledMMAPermuteN,
468 std::enable_if_t<EnablePermuateN_, int> = 0>
469 CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
470 const OAccTile& o_acc_tile,
471 const DsDramWindows& ds_dram_windows,
472 void* /* p_smem */,
473 const ScaleM& scale_m = {},
474 const ScaleN& scale_n = {})
475 {
// Accumulator elements each lane holds for one XDL tile; treated as rows below.
476 static constexpr int RowsPerLane = CWarpTensor::get_thread_buffer_size();
477
478 static_assert(MPerXdl % RowsPerLane == 0,
479 "CShuffle (permuteN): MPerXdl must be divisible by per-lane row count.");
// M is factored as (MWave, MPerXdl / RowsPerLane, RowsPerLane) and N as
// (NWave, NPerXdl, NRepeat), so the N repeat is the innermost N factor.
480 constexpr int kM0 = MWave;
481 constexpr int kM2 = RowsPerLane;
482 constexpr int kM1 = MPerXdl / kM2;
483
484 constexpr int kN0 = NWave;
485 constexpr int kN1 = NPerXdl;
486 constexpr int kN2 = NRepeat;
487
488 using IntrThreadShuffleEncode =
489 tile_distribution_encoding<sequence<>,
490 tuple<sequence<kM0, kM1, kM2>, sequence<kN0, kN1, kN2>>,
491 tuple<sequence<1, 2>, sequence<1, 2>>,
492 tuple<sequence<0, 0>, sequence<1, 1>>,
493 sequence<1, 2>,
494 sequence<2, 2>>;
495 constexpr auto dram_tile_distribution =
496 make_static_tile_distribution(IntrThreadShuffleEncode{});
497
498 auto d_dram_windows = generate_tuple(
499 [&](auto idx) {
500 return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
501 },
503
504 constexpr auto c_warp_y_lengths =
505 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
506 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
507
508 auto shuffle_acc = make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
509 auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
510
511 // Optional scales (must share the same distribution to match per-thread indexing)
512 constexpr bool has_scales =
513 !std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
514 constexpr bool has_scalar_scales =
515 std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;
516
517 // Tiles to hold row/col scales when present
518 using SMType = typename ScaleDataType<ScaleM>::DataType;
519 using SNType = typename ScaleDataType<ScaleN>::DataType;
520
521 auto sm_tile = make_static_distributed_tensor<SMType>(dram_tile_distribution);
522 auto sn_tile = make_static_distributed_tensor<SNType>(dram_tile_distribution);
523
524 // Build windows only if non-scalar scales are provided
525 auto scale_m_window = [&]() {
526 if constexpr(has_scales && !has_scalar_scales)
527 {
528 return make_tile_window(scale_m, dram_tile_distribution);
529 }
530 else
531 {
532 return EmptyScale{};
533 }
534 }();
535 auto scale_n_window = [&]() {
536 if constexpr(has_scales && !has_scalar_scales)
537 {
538 return make_tile_window(scale_n, dram_tile_distribution);
539 }
540 else
541 {
542 return EmptyScale{};
543 }
544 }();
545
546 static_for<0, MRepeat, 1>{}([&](auto mIter) {
547 // Slice accumulators for this M repeat into the permuted layout
548 shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
549 merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
550 merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
551
552 // If non-scalar scales provided, load them with identical distribution
553 if constexpr(has_scales && !has_scalar_scales)
554 {
555 sm_tile = load_tile(scale_m_window); // row scales in permuted layout
556 sn_tile = load_tile(scale_n_window); // col scales in permuted layout
557 }
558
559 // Repack the kM2 (= RowsPerLane) rows of every N repeat into the permuted output order
560 static_for<0, NRepeat, 1>{}([&](auto n_idx) {
561 // source indices in shuffle_acc: (n_idx * product(Y) + row)
562 const index_t plane = c_warp_y_lengths.product();
563
564 // Per element: apply the optional scales, then convert to ODataType
565 static_for<0, kM2, 1>{}([&](auto m_lane) {
566 const int src = n_idx * plane + m_lane; // source row in this N-plane
567 const int dst = n_idx + m_lane * NRepeat; // permuted N layout in output
568 AccDataType v = shuffle_acc.get_thread_buffer()[src];
569
570 if constexpr(has_scalar_scales)
571 {
572 v = static_cast<AccDataType>(v * scale_m * scale_n);
573 }
574 else if constexpr(has_scales && !has_scalar_scales)
575 {
// NOTE(review): scales are widened to float whatever AccDataType is, so the
// product is formed in float (or a wider promoted type) before the cast back.
576 const auto sm = static_cast<float>(sm_tile.get_thread_buffer()[dst]);
577 const auto sn = static_cast<float>(sn_tile.get_thread_buffer()[dst]);
578 v = static_cast<AccDataType>(v * sm * sn);
579 }
580
581 c_out_tensor.get_thread_buffer()[dst] = type_convert<ODataType>(v);
582 });
583 });
584
585 // store/update
// (branch condition not visible here; presumably MemoryOperation == memory_operation_enum::set)
587 {
588 store_tile(out_dram_window, c_out_tensor);
589 }
590 else
591 {
592 update_tile(out_dram_window, c_out_tensor);
593 }
594
595 // advance output (and any D-tensors) by one MPerXdl*MWave chunk
// NOTE(review): scale_m_window / scale_n_window are never advanced here, so with
// MRepeat > 1 every iteration reloads the scales of the first chunk. Confirm whether
// they should also move by MPerXdl * MWave along M.
596 move_tile_window(out_dram_window, {number<MPerXdl * MWave>{}, number<0>{}});
597 static_for<0, NumDTensor, 1>{}([&](auto idx) {
598 move_tile_window(d_dram_windows[idx], {number<MPerXdl * MWave>{}, number<0>{}});
599 });
600 });
601 }
602
// Default (LDS-shuffle) epilogue, selected by enable_if when EnablePermuateN_ is zero.
// For every access of the space-filling curve SFC it:
//   * slices the matching accumulator sub-tile;
//   * optionally scales it;
//   * casts it to ODataType and writes it into the LDS staging buffer at p_smem;
//   * reads it back with the DRAM-side distribution;
//   * combines the D tensors;
//   * stores/updates the output window;
//   * advances all windows.
// Only RowMajor output is supported (static_assert below).
603 template <typename ODramWindow,
604 typename OAccTile,
605 typename DsDramWindows,
606 typename ScaleM = EmptyScale,
607 typename ScaleN = EmptyScale,
608 int EnablePermuateN_ = TiledMMAPermuteN,
609 std::enable_if_t<!EnablePermuateN_, int> = 0>
610 CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
611 const OAccTile& o_acc_tile,
612 const DsDramWindows& ds_dram_windows,
613 void* p_smem,
614 const ScaleM& scale_m = {},
615 const ScaleN& scale_n = {})
616 {
// Register tile (AccDataType) in the LDS-side distribution.
617 constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
618
619 auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
620
// p_smem is viewed as ODataType through the LDS block descriptor.
621 constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
623 static_cast<ODataType*>(p_smem), lds_block_desc);
624
// Write side of the LDS buffer (LDS-tile distribution).
625 auto in_lds_window = make_tile_window(
626 o_lds_block,
628 {0, 0},
629 LdsTileDistr);
630
// Read side of the LDS buffer; a distribution is attached per access below.
631 auto out_lds_window = make_tile_window(
632 o_lds_block,
634 {0, 0});
635
636 constexpr index_t num_access = SFC::get_num_of_access();
637
638 static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
639 "Currently, the CShuffle Epilogue only supports the Row Major Output layout");
640
// Distribution used for the LDS read-back, the D loads and the output store.
641 using TileEncodingPattern =
642 tile_distribution_encoding_pattern_2d<kBlockSize,
647 Problem::kNumWaveGroups>;
648 constexpr auto dram_tile_distribution =
649 TileEncodingPattern::make_2d_static_tile_distribution();
650
651 auto d_dram_windows = generate_tuple(
652 [&](auto idx) {
653 return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
654 },
656
// Scale handling:
//   * scalar scales (both AccDataType) are passed through by value;
//   * tensor scales get a window in the LDS tile's distribution, so scale_tile can load
//     them per access;
//   * without scales, EmptyScale.
657 constexpr bool has_scales =
658 !std::is_same_v<ScaleM, EmptyScale> && !std::is_same_v<ScaleN, EmptyScale>;
659 constexpr bool has_scalar_scales =
660 std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;
661 auto scale_m_window = [&]() {
662 if constexpr(has_scalar_scales)
663 {
664 return scale_m;
665 }
666 else if constexpr(has_scales)
667 {
668 return make_tile_window(scale_m, lds_tile.get_tile_distribution());
669 }
670 else
671 {
672 return EmptyScale{};
673 }
674 }();
675 auto scale_n_window = [&]() {
676 if constexpr(has_scalar_scales)
677 {
678 return scale_n;
679 }
680 else if constexpr(has_scales)
681 {
682 return make_tile_window(scale_n, lds_tile.get_tile_distribution());
683 }
684 else
685 {
686 return EmptyScale{};
687 }
688 }();
689
// NOTE(review): the LDS buffer is rewritten every access, and each write (cast_lds_tile)
// must be separated from the following read-back (load_tile) by a block-level barrier.
// Lines not visible here sit at those points; confirm they are block_sync_lds() calls.
690 static_for<0, num_access, 1>{}([&](auto iAccess) {
692 slice_acc_tile<iAccess>(o_acc_tile, lds_tile);
693
694 if constexpr(has_scales)
695 {
696 scale_tile<iAccess>(lds_tile, scale_m_window, scale_n_window);
697 }
698
699 cast_lds_tile(lds_tile, in_lds_window);
701
702 auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
703
704 apply_d_tensors(d_dram_windows, c_out_tensor);
705 store_to_dram(out_dram_window, c_out_tensor);
706 move_windows<iAccess>(out_dram_window, d_dram_windows);
707 });
708 }
709};
710} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
constexpr tuple< Args &... > tie(Args &... args) noexcept
Definition tile/core/container/tuple.hpp:376
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
memory_operation_enum
Definition arch.hpp:56
@ set
Definition arch.hpp:57
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition type_traits.hpp:67
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
CK_TILE_HOST_DEVICE constexpr auto concat_tuple_of_reference(const tuple< X &... > &tx, const tuple< Y &... > &ty)
Definition tile/core/container/tuple.hpp:443
std::string mem_op_string()
Definition utils.hpp:42
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc &in_element_func, const Tuple &t, std::index_sequence< I... >)
Template function that "unpacks" a tuple and applies an element-wise operation.
Definition tile_elementwise.hpp:71
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
@ thread_raked
Thread raked pattern.
Definition static_encoding_pattern.hpp:94
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto generate_tie(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:435
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
CK_TILE_DEVICE void update_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition update_tile.hpp:22
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
STL namespace.
Definition cshuffle_epilogue.hpp:447
typename T::DataType DataType
Definition cshuffle_epilogue.hpp:459
Definition cshuffle_epilogue.hpp:452
float DataType
Definition cshuffle_epilogue.hpp:453
static constexpr index_t kBlockSize
Definition cshuffle_epilogue.hpp:103
CK_TILE_DEVICE void scale_tile(LdsTile &lds_tile, ScaleM &scale_m_window, ScaleN &scale_n_window)
Definition cshuffle_epilogue.hpp:329
static constexpr index_t NRepeat
Definition cshuffle_epilogue.hpp:120
CK_TILE_DEVICE void slice_acc_tile(const OAccTile &o_acc_tile, LdsTile &lds_tile)
Definition cshuffle_epilogue.hpp:370
static constexpr index_t MRepeat
Definition cshuffle_epilogue.hpp:119
typename WG::CWarpTensor CWarpTensor
Definition cshuffle_epilogue.hpp:260
typename WG::CWarpDstrEncoding CWarpDstrEncoding
Definition cshuffle_epilogue.hpp:261
remove_cvref_t< typename Problem::AsDataType > AsDataType
Definition cshuffle_epilogue.hpp:74
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataTypeTuple > > ADataType
Definition cshuffle_epilogue.hpp:92
remove_cvref_t< Problem_ > Problem
Definition cshuffle_epilogue.hpp:73
static constexpr index_t MPerXdl
Definition cshuffle_epilogue.hpp:108
static constexpr bool FixedVectorSize
Definition cshuffle_epilogue.hpp:112
remove_cvref_t< typename Problem::ODataType > ODataType
Definition cshuffle_epilogue.hpp:77
CK_TILE_DEVICE void store_to_dram(OutDramWindow &out_dram_window, const COutTensor &c_out_tensor)
Definition cshuffle_epilogue.hpp:411
static CK_TILE_HOST_DEVICE constexpr index_t GetVectorSizeC()
Get the vector store size for C tensor.
Definition cshuffle_epilogue.hpp:151
static constexpr bool ADataTypeIsTuple
Definition cshuffle_epilogue.hpp:81
static constexpr index_t kNPerBlock
Definition cshuffle_epilogue.hpp:105
static constexpr index_t BlockedXDLN_PerWarp
Definition cshuffle_epilogue.hpp:114
remove_cvref_t< typename Problem::ELayout > ELayout
Definition cshuffle_epilogue.hpp:100
static constexpr memory_operation_enum MemoryOperation
Definition cshuffle_epilogue.hpp:102
static constexpr bool TiledMMAPermuteN
Definition cshuffle_epilogue.hpp:113
static constexpr bool BDataTypeIsTuple
Definition cshuffle_epilogue.hpp:82
remove_cvref_t< typename Problem::DsLayout > DsLayout
Definition cshuffle_epilogue.hpp:79
static constexpr index_t MPerIteration
Definition cshuffle_epilogue.hpp:116
static constexpr auto MNPerIterationShuffle
Definition cshuffle_epilogue.hpp:240
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition cshuffle_epilogue.hpp:322
static constexpr index_t isCTransposed
Definition cshuffle_epilogue.hpp:111
CDElementwise elfunc_
Definition cshuffle_epilogue.hpp:122
CK_TILE_DEVICE void apply_d_tensors(DramWindows &d_dram_windows, COutTensor &c_out_tensor)
Definition cshuffle_epilogue.hpp:397
static constexpr index_t MWave
Definition cshuffle_epilogue.hpp:106
static constexpr index_t VectorSizeC
Definition cshuffle_epilogue.hpp:115
remove_cvref_t< typename Problem::DsDataType > DsDataType
Definition cshuffle_epilogue.hpp:78
CK_TILE_DEVICE void move_windows(OutDramWindow &out_dram_window, DDramWindows &d_dram_windows)
Move both the output and D tensors windows for the next access.
Definition cshuffle_epilogue.hpp:428
static CK_TILE_HOST_DEVICE constexpr auto MakeLdsBlockDescriptor()
Definition cshuffle_epilogue.hpp:267
remove_cvref_t< typename Problem::CDElementwise > CDElementwise
Definition cshuffle_epilogue.hpp:101
static CK_TILE_HOST const std::string GetName()
Definition cshuffle_epilogue.hpp:129
std::conditional_t< std::is_same_v< ADataType, pk_int4_t >, BDataType, ADataType > ATypeToUse
Definition cshuffle_epilogue.hpp:95
static CK_TILE_DEVICE constexpr auto MakeLdsDistributionEncode()
Definition cshuffle_epilogue.hpp:289
static constexpr index_t NPerIterationShuffle
Definition cshuffle_epilogue.hpp:249
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition cshuffle_epilogue.hpp:76
CK_TILE_DEVICE auto operator()(ODramWindow &out_dram_window, const OAccTile &o_acc_tile, const DsDramWindows &ds_dram_windows, void *, const ScaleM &scale_m={}, const ScaleN &scale_n={})
Definition cshuffle_epilogue.hpp:469
static constexpr index_t NumDTensor
Definition cshuffle_epilogue.hpp:118
static constexpr index_t KPerXdl
Definition cshuffle_epilogue.hpp:110
CK_TILE_DEVICE auto operator()(ODramWindow &out_dram_window, const OAccTile &o_acc_tile, const DsDramWindows &ds_dram_windows, void *p_smem, const ScaleM &scale_m={}, const ScaleN &scale_n={})
Definition cshuffle_epilogue.hpp:610
static constexpr index_t NumMXdlPerWavePerShuffle
Definition cshuffle_epilogue.hpp:236
static constexpr index_t NumNXdlPerWavePerShuffle
Definition cshuffle_epilogue.hpp:237
static CK_TILE_HOST_DEVICE constexpr index_t GetVectorSizeD(number< I > index)
Get the vector store size for Di tensor.
Definition cshuffle_epilogue.hpp:180
remove_cvref_t< typename Problem::BsDataType > BsDataType
Definition cshuffle_epilogue.hpp:75
space_filling_curve< sequence< kMPerBlock, kNPerBlock >, sequence< 0, 1 >, sequence< MPerIterationShuffle, NPerIterationShuffle > > SFC
Definition cshuffle_epilogue.hpp:262
std::conditional_t< ADataTypeIsTuple, remove_cvref_t< AsDataType >, remove_cvref_t< tuple< AsDataType > > > AsDataTypeTuple
Definition cshuffle_epilogue.hpp:84
static constexpr index_t MPerIterationShuffle
Definition cshuffle_epilogue.hpp:248
CK_TILE_DEVICE void cast_lds_tile(LdsTile &lds_tile, InLdsWindow &in_lds_window)
Definition cshuffle_epilogue.hpp:389
CK_TILE_DEVICE CShuffleEpilogue(CDElementwise elfunc=CDElementwise{})
Definition cshuffle_epilogue.hpp:124
static constexpr auto shuffle_tile_tuple
Shuffle tile configuration parameters.
Definition cshuffle_epilogue.hpp:209
WarpGemmDispatcher< ATypeToUse, BTypeToUse, AccDataType, MPerXdl, NPerXdl, KPerXdl, isCTransposed > WG
Definition cshuffle_epilogue.hpp:251
static constexpr index_t NWave
Definition cshuffle_epilogue.hpp:107
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataTypeTuple > > BDataType
Definition cshuffle_epilogue.hpp:93
std::conditional_t< BDataTypeIsTuple, remove_cvref_t< BsDataType >, remove_cvref_t< tuple< BsDataType > > > BsDataTypeTuple
Definition cshuffle_epilogue.hpp:88
std::conditional_t< std::is_same_v< BDataType, pk_int4_t >, ADataType, BDataType > BTypeToUse
Definition cshuffle_epilogue.hpp:98
static constexpr index_t kMPerBlock
Definition cshuffle_epilogue.hpp:104
static constexpr index_t NPerIteration
Definition cshuffle_epilogue.hpp:117
static constexpr index_t NPerXdl
Definition cshuffle_epilogue.hpp:109
typename WG::CWarpDstr CWarpDstr
Definition cshuffle_epilogue.hpp:259
Definition cshuffle_epilogue.hpp:40
static constexpr index_t kNPerBlock
Definition cshuffle_epilogue.hpp:51
remove_cvref_t< AccDataType_ > AccDataType
Definition cshuffle_epilogue.hpp:43
static constexpr index_t NumDTensor
Definition cshuffle_epilogue.hpp:64
static constexpr index_t MPerXdl
Definition cshuffle_epilogue.hpp:54
static constexpr index_t NPerXdl
Definition cshuffle_epilogue.hpp:55
remove_cvref_t< CDElementwise_ > CDElementwise
Definition cshuffle_epilogue.hpp:48
static constexpr index_t isCTransposed
Definition cshuffle_epilogue.hpp:57
static constexpr bool TiledMMAPermuteN
Definition cshuffle_epilogue.hpp:62
static constexpr index_t kBlockSize
Definition cshuffle_epilogue.hpp:49
static constexpr index_t VectorSizeC
Definition cshuffle_epilogue.hpp:60
static constexpr index_t kNumWaveGroups
Definition cshuffle_epilogue.hpp:63
static constexpr index_t NWave
Definition cshuffle_epilogue.hpp:53
remove_cvref_t< ODataType_ > ODataType
Definition cshuffle_epilogue.hpp:44
static constexpr index_t BlockedXDLN_PerWarp
Definition cshuffle_epilogue.hpp:61
remove_cvref_t< DsDataType_ > DsDataType
Definition cshuffle_epilogue.hpp:45
static constexpr memory_operation_enum MemoryOperation
Definition cshuffle_epilogue.hpp:58
static constexpr index_t KPerXdl
Definition cshuffle_epilogue.hpp:56
remove_cvref_t< BsDataType_ > BsDataType
Definition cshuffle_epilogue.hpp:42
static constexpr index_t kMPerBlock
Definition cshuffle_epilogue.hpp:50
static constexpr bool FixedVectorSize
Definition cshuffle_epilogue.hpp:59
remove_cvref_t< ELayout_ > ELayout
Definition cshuffle_epilogue.hpp:47
static constexpr index_t MWave
Definition cshuffle_epilogue.hpp:52
remove_cvref_t< DsLayout_ > DsLayout
Definition cshuffle_epilogue.hpp:46
remove_cvref_t< AsDataType_ > AsDataType
Definition cshuffle_epilogue.hpp:41
static constexpr value_type value
Definition tile/core/numeric/integral_constant.hpp:16
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:491
Definition tile/core/container/sequence.hpp:49
Definition space_filling_curve.hpp:20
static CK_TILE_HOST_DEVICE constexpr auto get_index(number< AccessIdx1d >)
Definition space_filling_curve.hpp:158
static CK_TILE_HOST_DEVICE constexpr auto get_forward_step(number< AccessIdx1d >)
Definition space_filling_curve.hpp:70
Definition tile/core/utility/functional.hpp:43
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192