block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp Source File

block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp Source File

Composable Kernel: block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp Source File
block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
10
11namespace ck_tile {
12
13// In this pipeline the Q, K and V tiles are all staged through LDS
14template <typename Problem_,
17{
32
// This pipeline always loads the whole Q tile (full QK head dimension) at once; the
// static_assert requires the Policy to be configured the same way.
35 static constexpr bool kQLoadOnce = true; // whether q_tile loads the whole block length (hdim) at once
36 static_assert(kQLoadOnce == Policy::QLoadOnce);
37
38 static constexpr index_t kBlockSize = Problem::kBlockSize;
39
// Tile extents taken from BlockFmhaShape. As checked in operator(): kM0 is the Q window's
// row extent, the K window is kN0 x kK0, the V window is kN1 x kK1 and the bias window is
// kM0 x kN0. The QK GEMM walks the head dimension in kQKHeaddim / kK0 steps, and the PV
// GEMM walks kN0 in kN0 / kK1 steps.
40 static constexpr index_t kM0 = BlockFmhaShape::kM0;
41 static constexpr index_t kN0 = BlockFmhaShape::kN0;
42 static constexpr index_t kK0 = BlockFmhaShape::kK0;
43 static constexpr index_t kN1 = BlockFmhaShape::kN1;
44 static constexpr index_t kK1 = BlockFmhaShape::kK1;
45 static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
46 static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
47
// Compile-time rejection of head dims larger than 256 for this pipeline.
48 static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
49
// Feature switches forwarded from Problem. operator() uses them to enable
// seqlen-K padding / masking checks, elementwise-bias or ALiBi handling, logits
// soft-capping, LSE output, paged-KV addressing and uneven-split masking.
50 static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
51 static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
52 static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
53 static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
54 static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
55 static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
56 static constexpr auto BiasEnum = Problem::BiasEnum;
57 static constexpr bool kStoreLSE = Problem::kStoreLSE;
58 static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
59 static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
60
61 static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
62 (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
65
66 // last dimension vector length used to create the tensor view (and decide the buffer_load vector length)
67 // ... together with the tensor distribution; the tensor distribution should be able to override this
// Rule for every alignment below: if the relevant dimension is padded, the alignment is 1
// (scalar access); otherwise it is the Policy's value. K is gated on kPadHeadDimQ because
// Q and K share the head dimension that the QK GEMM contracts over.
68 static constexpr index_t kAlignmentQ =
69 kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
70 static constexpr index_t kAlignmentK =
71 kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
// V's gating flag depends on VLayout: RowMajor uses head-dim-V padding, any other layout
// uses seqlen-K padding (presumably matching which dimension is last/vectorized).
72 static constexpr index_t kAlignmentV = []() {
73 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
74 return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
75 else
76 return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
77 }();
78
// Output accumulator alignment follows head-dim-V padding.
79 static constexpr index_t kAlignmentOacc =
80 kPadHeadDimV ? 1 : Policy::template GetAlignmentOacc<Problem>();
81
// Bias alignment follows seqlen-K padding (the bias window spans kM0 x kN0).
82 static constexpr index_t kAlignmentBias =
83 kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
84
85 static constexpr index_t kBlockPerCu = []() {
86 if constexpr(Problem::kBlockPerCu != -1)
87 return Problem::kBlockPerCu;
88 else
89 {
90 if constexpr(kQKHeaddim <= 32)
91 {
92 return 2;
93 }
94 else if constexpr(kQKHeaddim <= 64)
95 {
96 return 3;
97 }
98 else if constexpr(kQKHeaddim <= 128)
99 {
101 return 1;
102 else
103 return 2;
104 }
105 else if constexpr(kQKHeaddim <= 256)
106 {
107 return 1;
108 }
109 else
110 {
111 return 1;
112 }
113 }
114 }();
115
// Short name string identifying this pipeline variant.
116 static constexpr const char* name = "qr_nwarp_sshuffle";
117
119 {
120 return Policy::template GetSmemSize<Problem>();
121 }
122
123 template <typename QDramBlockWindowTmp,
124 typename KDramBlockWindowLengths,
125 typename KPageBlockNavigator,
126 typename VDramBlockWindowLengths,
127 typename VPageBlockNavigator,
128 typename BiasDramBlockWindowTmp,
129 typename LSEaccDramBlockWindowTmp,
130 typename QElementFunction,
131 typename KElementFunction,
132 typename VElementFunction,
133 typename BiasElementFunction,
134 typename LSEaccElementFunction,
135 typename SAccElementFunction,
136 typename PComputeElementFunction,
137 typename OAccElementFunction,
138 typename PositionEncoding,
139 typename AttentionVariantParams,
140 typename BlockIndices>
142 operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
143 const QElementFunction& q_element_func,
144 const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
145 const KPageBlockNavigator& k_page_block_navigator,
146 const KElementFunction& k_element_func,
147 const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
148 const VPageBlockNavigator& v_page_block_navigator,
149 const VElementFunction& v_element_func,
150 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
151 const BiasElementFunction& bias_element_func,
152 LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
153 const LSEaccElementFunction& lse_acc_element_func,
154 const SAccElementFunction& s_acc_element_func,
155 const PComputeElementFunction& p_compute_element_func,
156 const OAccElementFunction& o_acc_element_func,
157 index_t num_splits,
158 index_t i_split,
159 FmhaMask mask,
160 PositionEncoding position_encoding,
161 float scale_s,
162 const AttentionVariant& variant,
163 const AttentionVariantParams& variant_params,
164 const BlockIndices& block_indices,
165 index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
166 void* smem_ptr) const
167 {
168 static_assert(
169 std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
170 std::is_same_v<KDataType, remove_cvref_t<typename KPageBlockNavigator::DataType>> &&
171 std::is_same_v<VDataType, remove_cvref_t<typename VPageBlockNavigator::DataType>>,
172 "wrong!");
173
174 static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
176 QDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
177 kN0 == KDramBlockWindowLengths{}[number<0>{}] &&
178 kK0 == KDramBlockWindowLengths{}[number<1>{}] &&
179 kN1 == VDramBlockWindowLengths{}[number<0>{}] &&
180 kK1 == VDramBlockWindowLengths{}[number<1>{}] &&
181 kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
182 kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
183 "wrong!");
184 // Q tile in LDS
185 QDataType* q_lds_ptr =
186 static_cast<QDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
188 q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
189
190 // K tile in LDS
191 KDataType* k_lds_ptr =
192 static_cast<KDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
194 k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
195 auto k_lds_window =
197
198 // V tile in LDS
200 reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
201 max(Policy::template GetSmemSizeQ<Problem>(),
202 Policy::template GetSmemSizeK<Problem>())),
203 Policy::template MakeVLdsBlockDescriptor<Problem>());
204 auto v_lds_window = make_tile_window(
205 v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
206
207 // S tile in LDS
209 reinterpret_cast<SaccDataType*>(reinterpret_cast<char*>(smem_ptr) +
210 max(Policy::template GetSmemSizeQ<Problem>(),
211 Policy::template GetSmemSizeK<Problem>())),
212 Policy::template MakeSLdsBlockDescriptor<Problem>());
213 auto s_write_lds_window = make_tile_window(
214 s_lds, Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
215 auto s_read_lds_window =
216 make_tile_window(s_lds,
217 Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(),
218 {0, 0},
219 Policy::template MakeSRegTileDistribution<Problem>());
220
221 // Block GEMM
222 constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
223 constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
224
225 auto q_dram_window =
226 make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
227 q_dram_block_window_tmp.get_window_lengths(),
228 q_dram_block_window_tmp.get_window_origin(),
229 Policy::template MakeQDramTileDistribution<Problem>());
230
231 // load Q here, will store Q into LDS to maximize throughput
232 auto origin_q = load_tile(q_dram_window);
233
234 using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
235 auto s_acc = SaccBlockTileType{};
236
237 // reduction function for softmax
238 const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
239 const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
240
241 using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
242
243 auto o_acc = OaccBlockTileType{};
244
245 // infer Sacc, S, P, M, L, Oacc type
246 using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(o_acc));
247
248 using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
249 SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
250
251 // init M, L
252 auto m = MLBlockTileType{};
253 auto l = MLBlockTileType{};
254
255 clear_tile(o_acc);
257 clear_tile(l);
258
259 const auto q_origin = q_dram_window.get_window_origin();
260 const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
261 q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
262
263 // check early exit if no work to do
264 if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
265 {
266 const index_t logical_num_total_loop =
267 integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0);
268 if(logical_num_total_loop <= 0)
269 {
270 if constexpr(kStoreLSE)
271 {
272 auto lse_acc =
273 make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
274
276
278 {
279 store_tile(lse_acc_dram_window_tmp,
280 tile_elementwise_in(lse_acc_element_func, lse_acc));
281 }
282 }
283
284 // Note: here occ are all cleard, return it
285 // Note: q loaded but no fence, ignore it.
286 return o_acc;
287 }
288 }
289
290 const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
291 const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
292 // make sure the first tile is completely located in page-block (page-block size should be
293 // divisible by kN0)
294 // relationship between each *_start variables: aligned_physical_seqlen_k_start <=
295 // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start
296 const index_t aligned_physical_seqlen_k_start =
297 [&, physical_seqlen_k_start_ = physical_seqlen_k_start] {
298 if constexpr(kIsPagedKV)
299 {
300 return kN0 * integer_divide_floor(physical_seqlen_k_start_, kN0);
301 }
302 else
303 {
304 return physical_seqlen_k_start_;
305 }
306 }();
307 const index_t num_total_loop =
308 integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
309
310 auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
311 k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
312
313 const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
314 auto bias_dram_window =
315 make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
316 bias_dram_block_window_tmp.get_window_lengths(),
317 {bias_origin.at(number<0>{}),
318 logical_seqlen_k_start - (physical_seqlen_k_start -
319 aligned_physical_seqlen_k_start)}, // M/N
320 Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
321
322 auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
323 v_dram_block_window_lengths,
324 {0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
325 Policy::template MakeVDramTileDistribution<Problem>());
326
327 // store Q into LDS
328 __builtin_amdgcn_sched_barrier(0);
329 auto q_lds_window_for_store = make_tile_window(
330 q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
331
332 store_tile(q_lds_window_for_store, origin_q);
333 __builtin_amdgcn_sched_barrier(0);
334
335 // load Q from LDS
336 __builtin_amdgcn_sched_barrier(0);
337 auto q_lds_window_for_load =
338 make_tile_window(q_lds,
339 Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
340 {0, 0},
341 Policy::template MakeQRegTileDistribution<Problem>());
343 auto q = load_tile(q_lds_window_for_load);
344 __builtin_amdgcn_sched_barrier(0);
345 auto q_tile = tile_elementwise_in(q_element_func, q);
346
347 // prefetch K tile
348 index_t i_total_loops = 0;
349 constexpr index_t k0_loops = kQKHeaddim / kK0;
350 constexpr index_t k1_loops = kN0 / kK1;
351
352 static_assert(2 <= k0_loops);
353 static_assert(1 <= k1_loops);
354
355 auto k_dram_window = make_tile_window(
356 k_dram_block_window,
357 Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
358
359 // load the first tile of the first iteration and store to LDS
360 auto k_block_tile = load_tile(k_dram_window);
361 // moving k_dram_window is an in-page-block operation, so there is
362 // no need to invoke k_page_block_navigator.move_tile_window() here.
363 move_tile_window(k_dram_window, {0, kK0});
364 // ensure LDS access by Q is done before the over-writting by K
366 store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
367
368 do
369 {
370 // STAGE 1, QK gemm
371 clear_tile(s_acc); // initialize C
372
373 // load the second tile of the first iteration
374 k_block_tile = load_tile(k_dram_window);
375
377 {
378 __builtin_amdgcn_sched_barrier(
379 0); // prevent from messing up the order of global loads
380 }
381 const auto bias_tile = load_tile(bias_dram_window); // load bias tile
383 {
384 __builtin_amdgcn_sched_barrier(
385 0); // prevent from messing up the order of global loads
386 }
387
388 if constexpr(k0_loops > 2)
389 {
390 static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
392 gemm_0(s_acc,
393 get_slice_tile(q_tile,
394 sequence<0, i_k0 * kK0>{},
395 sequence<kM0, (i_k0 + 1) * kK0>{}),
396 k_lds_window);
398 move_tile_window(k_dram_window, {0, kK0});
399
401 k_lds_window,
402 tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
403 k_block_tile = load_tile(k_dram_window); // global read i + 2
404 });
405 }
406
407 const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
408 { // tail
410 gemm_0(s_acc,
411 get_slice_tile(q_tile,
412 sequence<0, (k0_loops - 2) * kK0>{},
413 sequence<kM0, (k0_loops - 1) * kK0>{}),
414 k_lds_window);
416
417 store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
419
420 gemm_0(s_acc,
421 get_slice_tile(q_tile,
422 sequence<0, (k0_loops - 1) * kK0>{},
423 sequence<kM0, k0_loops * kK0>{}),
424 k_lds_window);
425 }
426
427 // STAGE 2, scale_s, add bias, mask, softmax
429 {
430 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
431 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
433 [&](auto& x, const auto& y) {
434#if !CK_TILE_FMHA_FWD_FAST_EXP2
435 x += type_convert<SaccDataType>(bias_element_func(y));
436#else
438 type_convert<SaccDataType>(bias_element_func(y));
439#endif
440 },
441 s_acc,
442 bias_tile);
443 }
444 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
445 {
446 const auto k_origin = k_page_block_navigator.to_global_window_origin(
447 i_page_block_k, k_dram_block_window.get_window_origin());
448 constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
449 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
450 sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
451 sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
452 const auto tile_idx = get_x_indices_from_distributed_indices(
453 s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
454
455 const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
456 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
457 constexpr auto i_j_idx = make_tuple(idx0, idx1);
458
459 s_acc(i_j_idx) *= scale_s;
460 // position_encoding accept only logical coordinates, do conversion here
461 position_encoding.update(s_acc(i_j_idx), row, col - kv_l2p_offset);
462 });
463 });
464 }
465 else
466 {
467 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
468 if constexpr(kHasLogitsSoftCap)
469 {
470 auto apply_logits_transform =
471 [&variant, &variant_params, &block_indices](auto& x) {
472 x = variant.LogitsTransform(variant_params,
473 variant.QueryTransform(variant_params, x),
474 block_indices.batch_idx,
475 block_indices.qo_head_idx,
476 block_indices.kv_head_idx);
477 };
478#if !CK_TILE_FMHA_FWD_FAST_EXP2
479 for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
480 {
481 apply_logits_transform(s_acc.thread_buf_[i]);
482 }
483#else
484 for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
485 {
486 apply_logits_transform(s_acc.thread_buf_[i]);
487 }
488#endif
489 }
490 else
491 {
492#if !CK_TILE_FMHA_FWD_FAST_EXP2
493 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
494#endif
495 }
496 }
497 move_tile_window(bias_dram_window, {0, kN0});
498
500 if constexpr(kHasUnevenSplits)
501 {
502 const auto k_origin = k_page_block_navigator.to_global_window_origin(
503 i_page_block_k, k_dram_block_window.get_window_origin());
505 s_acc,
507 [&,
508 physical_seqlen_k_start_ = physical_seqlen_k_start,
509 physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
510 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
511 if constexpr(kIsPagedKV)
512 {
513 return col < physical_seqlen_k_start_ || physical_seqlen_k_end_ <= col;
514 }
515 else
516 {
517 return physical_seqlen_k_end_ <= col;
518 }
519 });
520 }
521
522 if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
523 {
524 const auto k_origin = k_page_block_navigator.to_global_window_origin(
525 i_page_block_k, k_dram_block_window.get_window_origin());
526 // mask accept only logical coordinates, do conversion here
527 bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
528 k_origin.at(number<0>{}) - kv_l2p_offset,
529 number<kM0>{},
530 number<kN0>{});
531 if(need_perpixel_check)
532 {
534 s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
535 const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
536 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
537 return mask.IsOutOfBound(row, col - kv_l2p_offset);
538 });
539 }
540 }
541
542 __builtin_amdgcn_sched_barrier(0);
543
544 // load the first tile for next iteration
545 if(i_total_loops < num_total_loop - 1)
546 {
547 // move K tile windows
548 i_page_block_k = k_page_block_navigator.move_tile_window(
549 i_page_block_k, k_dram_block_window, {kN0, 0});
550
551 k_dram_window = make_tile_window(
552 k_dram_block_window,
553 Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window
554
555 // laod the first tile of the first iteration and store to LDS
556 k_block_tile = load_tile(k_dram_window);
557 }
558
559 __builtin_amdgcn_sched_barrier(0);
560
561 const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
562
563 // shuffle through LDS so that the tile layout is consistent with required by Gemm1
564 store_tile(s_write_lds_window, s);
566 auto s_new = load_tile(s_read_lds_window);
567
569 s_new,
570 sequence<1>{},
571 f_max,
572 -numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
574
575 const auto m_old = m; // m{j-1}
577 [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
578
580 s_new.get_tile_distribution()); // Pcompute{j}
581
582 static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
586 FmhaMask::IsMasking)
587 {
590 : raw_m;
591 }
592 else
593 {
594 return raw_m;
595 }
596 };
597
598 constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
599 sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
600 constexpr auto i_idx = make_tuple(idx0);
601#if CK_TILE_FMHA_FWD_FAST_EXP2
602 auto row_max = scale_s * get_validated_m(m[i_idx]);
603#endif
604 sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
605 constexpr auto i_j_idx = make_tuple(idx0, idx1);
606#if CK_TILE_FMHA_FWD_FAST_EXP2
609 {
610 p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx]));
611 }
612 else
613 {
614 if constexpr(kHasLogitsSoftCap)
615 {
616 p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx]));
617 }
618 else
619 {
620 p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max);
621 }
622 }
623#else
624 p_compute(i_j_idx) = exp(s_new[i_j_idx] - get_validated_m(m[i_idx]));
625#endif
626 });
627 });
628
630 p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
631
633
634 const auto p =
635 cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
636
637 // l{j}, Oacc{j}
638 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
639 sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
640 constexpr auto i_idx = make_tuple(idx0);
641#if CK_TILE_FMHA_FWD_FAST_EXP2
642 const auto tmp = [&]() {
645 {
646 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
647 }
648 else
649 {
650 if constexpr(kHasLogitsSoftCap)
651 {
652 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
653 }
654 else
655 {
656 auto row_max = scale_s * get_validated_m(m[i_idx]);
657 return exp2(scale_s * m_old[i_idx] - row_max);
658 }
659 }
660 }();
661#else
662 const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
663#endif
664 l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
665 sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
666 constexpr auto i_j_idx = make_tuple(idx0, idx1);
667 // FIXME: this use different equation from FA v2 paper,
668 // but produce correc result.
669 // Is the equation wrong?
670 o_acc(i_j_idx) *= tmp;
671 });
672 });
673
675 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
676 {
678 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
679 shuffle_tile(v_shuffle_tmp, v_prefetch);
681 v_lds_window,
682 tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
683 }
684 else
685 {
686 store_tile(v_lds_window,
687 tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
688 }
689 i_page_block_v =
690 v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1});
691
692 // STAGE 3, KV gemm
693 if constexpr(k1_loops > 1)
694 {
695 static_for<0, k1_loops - 1, 1>{}([&,
696 &i_page_block_v_ = i_page_block_v,
697 &v_dram_window_ = v_dram_window](auto i_k1) {
698 const auto v = load_tile(v_dram_window_); // load next v
700
701 gemm_1(o_acc,
703 p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
704 v_lds_window);
706
707 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
708 {
710 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
711 shuffle_tile(v_shuffle_tmp, v);
712 store_tile(v_lds_window,
713 tile_elementwise_in(v_element_func,
714 v_shuffle_tmp)); // store the prefetch
715 }
716 else
717 {
718 store_tile(v_lds_window,
719 tile_elementwise_in(v_element_func, v)); // store next v
720 }
721 i_page_block_v_ = v_page_block_navigator.move_tile_window(
722 i_page_block_v_, v_dram_window_, {0, kK1});
723 });
724 }
725
726 // tail
727 {
729 gemm_1(o_acc,
731 p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, k1_loops * kK1>{}),
732 v_lds_window);
734 }
735
736 __builtin_amdgcn_sched_barrier(0);
737
738 // load the first tile for next iteration
739 if(i_total_loops < num_total_loop - 1)
740 {
741 // store the first tile for next iteration to LDS
742 // moving k_dram_window is an in-page-block operation, so there is
743 // no need to invoke k_page_block_navigator.move_tile_window() here.
744 move_tile_window(k_dram_window, {0, kK0});
745 store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
746 }
747 } while(++i_total_loops < num_total_loop);
748
749 if constexpr(kStoreLSE)
750 {
751 // store lse acc
752 auto lse_acc = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
753
754 constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans();
755 sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
756 constexpr auto i_idx = make_tuple(idx0);
757#if CK_TILE_FMHA_FWD_FAST_EXP2
760 {
761 lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
762 }
763 else
764 {
765 if constexpr(kHasLogitsSoftCap)
766 {
767 lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
768 }
769 else
770 {
771 lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
772 }
773 }
774#else
775 lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
776#endif
777 });
778
780 {
781 store_tile(lse_acc_dram_window_tmp,
782 tile_elementwise_in(lse_acc_element_func, lse_acc));
783 }
784 }
785
786 // finally, O
787 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
788
789 sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
790 constexpr auto i_idx = make_tuple(idx0);
791 const auto tmp = [&]() {
793 FmhaMask::IsMasking)
794 {
795 return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
796 }
797 else
798 return 1 / l[i_idx];
799 }();
800 sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
801 constexpr auto i_j_idx = make_tuple(idx0, idx1);
802 o_acc(i_j_idx) *= tmp;
803 });
804 });
805
806 o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
807
808 return o_acc;
809 }
810
// Convenience overload without element-wise functions. It forwards every argument to the
// full operator() overload and passes identity{} for each element-wise functor slot
// (q, k, v, bias, lse_acc, s_acc, p_compute and o_acc), in that order.
811 template <typename QDramBlockWindowTmp,
812 typename KDramBlockWindowLengths,
813 typename KPageBlockNavigator,
814 typename VDramBlockWindowLengths,
815 typename VPageBlockNavigator,
816 typename BiasDramBlockWindowTmp,
817 typename LSEaccDramBlockWindowTmp,
818 typename PositionEncoding,
819 typename AttentionVariantParams,
820 typename BlockIndices>
822 operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
823 const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
824 const KPageBlockNavigator& k_page_block_navigator,
825 const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
826 const VPageBlockNavigator& v_page_block_navigator,
827 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
828 LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile
829 index_t num_splits,
830 index_t i_split,
831 FmhaMask mask,
832 PositionEncoding position_encoding,
833 float scale_s,
834 const AttentionVariant& variant,
835 const AttentionVariantParams& variant_params,
836 const BlockIndices& block_indices,
837 index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
838 void* smem_ptr) const
839 {
// Argument order must match the full overload: each identity{} fills the element-wise
// functor slot that follows the corresponding window/navigator argument.
840 return operator()(q_dram_block_window_tmp,
841 identity{},
842 k_dram_block_window_lengths,
843 k_page_block_navigator,
844 identity{},
845 v_dram_block_window_lengths,
846 v_page_block_navigator,
847 identity{},
848 bias_dram_block_window_tmp,
849 identity{},
850 lse_acc_dram_block_window_tmp,
851 identity{},
852 identity{},
853 identity{},
854 identity{},
855 num_splits,
856 i_split,
857 mask,
858 position_encoding,
859 scale_s,
860 variant,
861 variant_params,
862 block_indices,
863 kv_l2p_offset,
864 smem_ptr);
865 }
866};
867
868} // namespace ck_tile
#define CK_TILE_FMHA_FWD_FAST_EXP2
Definition config.hpp:234
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition bfloat16.hpp:428
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_HOST_DEVICE constexpr auto integer_divide_floor(X x, Y y)
Definition tile/core/numeric/math.hpp:143
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition static_distributed_tensor.hpp:159
CK_TILE_DEVICE constexpr auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition slice_tile.hpp:23
@ ALIBI
Definition block_attention_bias_enum.hpp:15
@ NO_BIAS
Definition block_attention_bias_enum.hpp:13
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={}, bool_constant< CrossWarp >={})
Definition block_reduce.hpp:21
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
constexpr T log2e_v
Definition tile/core/numeric/math.hpp:488
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE void shuffle_tile(OutTensor &out, const InTensor &in)
Definition shuffle_tile.hpp:154
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_DEVICE index_t get_thread_local_1d_id()
Definition arch.hpp:94
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_ &acc_tensor, const InDistributedTensor_ &in_tensor, sequence< InReduceDims... >, const ReduceFunc &reduce_func)
Definition block_reduce.hpp:191
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition bfloat16.hpp:419
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition static_distributed_tensor.hpp:175
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition bfloat16.hpp:425
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp:19
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:17
static constexpr index_t kN1
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:43
remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:29
static constexpr index_t kAlignmentBias
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:82
static constexpr const char * name
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:116
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowLengths &k_dram_block_window_lengths, const KPageBlockNavigator &k_page_block_navigator, const VDramBlockWindowLengths &v_dram_block_window_lengths, const VPageBlockNavigator &v_page_block_navigator, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, LSEaccDramBlockWindowTmp &lse_acc_dram_block_window_tmp, index_t num_splits, index_t i_split, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, index_t kv_l2p_offset, void *smem_ptr) const
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:822
static constexpr index_t kBlockSize
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:38
static constexpr index_t kSubQKHeaddim
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:46
static constexpr bool kIsPagedKV
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:58
remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:26
static constexpr bool kHasLogitsSoftCap
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:55
remove_cvref_t< typename Problem::QDataType > QDataType
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:20
static constexpr index_t kN0
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:41
remove_cvref_t< typename Problem::KDataType > KDataType
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:21
remove_cvref_t< typename Problem::PDataType > PDataType
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:27
static constexpr index_t kK1
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:44
static constexpr bool kPadSeqLenQ
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:51
remove_cvref_t< typename Problem::SMPLComputeDataType > SMPLComputeDataType
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:24
remove_cvref_t< typename Problem::SaccDataType > SaccDataType
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:23
static constexpr index_t kAlignmentV
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:72
remove_cvref_t< Problem_ > Problem
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:18
remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:33
static constexpr index_t kQKHeaddim
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:45
static constexpr index_t kK0
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:42
remove_cvref_t< typename Problem::AttentionVariant > AttentionVariant
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:30
remove_cvref_t< typename Problem::OaccDataType > OaccDataType
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:28
static constexpr index_t kAlignmentQ
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:68
remove_cvref_t< typename BlockFmhaShape::VLayout > VLayout
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:34
remove_cvref_t< Policy_ > Policy
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:19
static constexpr index_t kM0
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:40
static constexpr bool kIsGroupMode
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:50
static constexpr index_t kBlockPerCu
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:85
static constexpr bool kQLoadOnce
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:35
static constexpr bool kPadHeadDimQ
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:53
static constexpr bool kStoreLSE
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:57
static constexpr bool kPadSeqLenK
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:52
static constexpr auto BiasEnum
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:56
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:118
remove_cvref_t< typename Problem::VDataType > VDataType
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:22
static constexpr bool kPadHeadDimV
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:54
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowLengths &k_dram_block_window_lengths, const KPageBlockNavigator &k_page_block_navigator, const KElementFunction &k_element_func, const VDramBlockWindowLengths &v_dram_block_window_lengths, const VPageBlockNavigator &v_page_block_navigator, const VElementFunction &v_element_func, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, const BiasElementFunction &bias_element_func, LSEaccDramBlockWindowTmp &lse_acc_dram_window_tmp, const LSEaccElementFunction &lse_acc_element_func, const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, index_t num_splits, index_t i_split, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, index_t kv_l2p_offset, void *smem_ptr) const
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:142
remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:31
static constexpr index_t kAlignmentK
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:70
remove_cvref_t< typename Problem::BiasDataType > BiasDataType
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:25
static constexpr index_t kAlignmentOacc
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:79
static constexpr bool kHasUnevenSplits
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp:59
Definition tile/core/utility/functional.hpp:86
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition tile/core/container/sequence.hpp:49
#define C_LOG2E
Definition tile/core/numeric/math.hpp:469