block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp Source File

block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp Source File#

Composable Kernel: block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp Source File
block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
10
11namespace ck_tile {
12
13// TODO: This class is a variant of the existing BlockFmhaFwdSplitKVPipelineQRKSVS pipeline.
14// Refactoring to extract shared logic is recommended as future work.
15
16template <typename Problem_, typename Policy_ = BlockFmhaFwdPagedKVPipelineQRKSVSDefaultPolicy>
18{
33
36 static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
37 static_assert(kQLoadOnce == Policy::QLoadOnce);
38
39 static constexpr index_t kBlockSize = Problem::kBlockSize;
40
41 static constexpr index_t kM0 = BlockFmhaShape::kM0;
42 static constexpr index_t kN0 = BlockFmhaShape::kN0;
43 static constexpr index_t kK0 = BlockFmhaShape::kK0;
44 static constexpr index_t kN1 = BlockFmhaShape::kN1;
45 static constexpr index_t kK1 = BlockFmhaShape::kK1;
46 static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
47 static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
48
49 static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
50
51 static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
52 static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
53 static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
54 static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
55 static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
56 static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
57 static constexpr auto BiasEnum = Problem::BiasEnum;
58 static constexpr bool kStoreLSE = Problem::kStoreLSE;
59 static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
60
61 static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
62 (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
65
66 // last-dimension vector length used to create the tensor view (and to decide the buffer_load vector length)
67 // ... together with the tensor distribution. The tensor distribution should be able to override this.
68 static constexpr index_t kAlignmentQ =
69 kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
70 static constexpr index_t kAlignmentK =
71 kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
72 static constexpr index_t kAlignmentV = []() {
73 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
74 return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
75 else
76 return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
77 }();
78
79 static constexpr index_t kAlignmentO =
80 kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
81 static constexpr index_t kAlignmentBias =
82 kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
83
// kBlockPerCu: taken from Problem::kBlockPerCu unless that value is -1, in which case a
// default is chosen from kQKHeaddim: 2 for <= 32, 3 for <= 64, 1 or 2 for <= 128, and 1 for
// anything larger.
// NOTE(review): the condition that selects 1 vs 2 in the <= 128 branch (listing line 99) is
// not present in this rendered listing -- confirm it against the original header.
84 static constexpr index_t kBlockPerCu = []() {
85 if constexpr(Problem::kBlockPerCu != -1)
86 return Problem::kBlockPerCu;
87 else
88 {
89 if constexpr(kQKHeaddim <= 32)
90 {
91 return 2;
92 }
93 else if constexpr(kQKHeaddim <= 64)
94 {
95 return 3;
96 }
97 else if constexpr(kQKHeaddim <= 128)
98 {
100 return 1;
101 else
102 return 2;
103 }
104 else if constexpr(kQKHeaddim <= 256)
105 {
106 return 1;
107 }
108 else
109 {
110 return 1;
111 }
112 }
113 }();
114
115 static constexpr const char* name = "qr_pagedkv";
116
118 {
119 return Policy::template GetSmemSize<Problem>();
120 }
121
122 template <typename QDramBlockWindowTmp,
123 typename KDramBlockWindowLengths,
124 typename KPageBlockNavigator,
125 typename VDramBlockWindowLengths,
126 typename VPageBlockNavigator,
127 typename BiasDramBlockWindowTmp,
128 typename LSEDramBlockWindowTmp,
129 typename QElementFunction,
130 typename KElementFunction,
131 typename VElementFunction,
132 typename BiasElementFunction,
133 typename LSEElementFunction,
134 typename SAccElementFunction,
135 typename PComputeElementFunction,
136 typename OAccElementFunction,
137 typename PositionEncoding,
138 typename AttentionVariantParams,
139 typename BlockIndices>
141 operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
142 const QElementFunction& q_element_func,
143 const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
144 const KPageBlockNavigator& k_page_block_navigator,
145 const KElementFunction& k_element_func,
146 const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
147 const VPageBlockNavigator& v_page_block_navigator,
148 const VElementFunction& v_element_func,
149 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
150 const BiasElementFunction& bias_element_func,
151 LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
152 const LSEElementFunction& lse_element_func,
153 const SAccElementFunction& s_acc_element_func,
154 const PComputeElementFunction& p_compute_element_func,
155 const OAccElementFunction& o_acc_element_func,
156 FmhaMask mask,
157 PositionEncoding position_encoding,
158 float scale_s,
159 const AttentionVariant& variant,
160 const AttentionVariantParams& variant_params,
161 const BlockIndices& block_indices,
162 index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
163 void* smem_ptr) const
164 {
165 static_assert(
166 std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
167 std::is_same_v<KDataType, remove_cvref_t<typename KPageBlockNavigator::DataType>> &&
168 std::is_same_v<VDataType, remove_cvref_t<typename VPageBlockNavigator::DataType>>,
169 "wrong!");
170
171 static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
172 kN0 == KDramBlockWindowLengths{}[number<0>{}] &&
173 kK0 == KDramBlockWindowLengths{}[number<1>{}] &&
174 kN1 == VDramBlockWindowLengths{}[number<0>{}] &&
175 kK1 == VDramBlockWindowLengths{}[number<1>{}] &&
176 kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
177 kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
178 "wrong!");
179
180 // K tile in LDS
181 KDataType* k_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
182 static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQ<Problem>()));
184 k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
185 auto k_lds_window =
187
188 // V tile in LDS
190 reinterpret_cast<VDataType*>(smem_ptr),
191 Policy::template MakeVLdsBlockDescriptor<Problem>());
192 auto v_lds_window = make_tile_window(
193 v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
194
195 // Block GEMM
196 constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
197 constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
198
199 auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
200 q_dram_block_window_tmp.get_window_lengths(),
201 q_dram_block_window_tmp.get_window_origin(),
202 Policy::template MakeQRegTileDistribution<Problem>());
203
204 auto q = load_tile(q_dram_window);
205
206 using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
207 auto s_acc = SaccBlockTileType{};
208
209 // reduction function for softmax
210 const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
211 const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
212
213 // infer Sacc, S, P, M, L, Oacc type
214 using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));
215
216 using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
217 SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
218
219 using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
220
221 // init Oacc, M, L
222 auto o_acc = OaccBlockTileType{};
223 auto m = MLBlockTileType{};
224 auto l = MLBlockTileType{};
225
226 clear_tile(o_acc);
228 clear_tile(l);
229
230 const auto q_origin = q_dram_window.get_window_origin();
231 const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
232 mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
233
234 // check early exit if no work to do
235 if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
236 {
237 const auto num_total_loop =
238 integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0);
239 if(num_total_loop <= 0)
240 {
241 if constexpr(kStoreLSE)
242 {
243 auto lse =
244 make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
245
247
248 store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
249 }
250
251 // Note: o_acc has already been cleared at this point, so return it as is.
252 // Note: q was loaded but no fence was issued; ignore it.
253 return o_acc;
254 }
255 }
256
257 // k_dram_block_window
258 const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
259 const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
260 // make sure the first tile lies completely inside a page-block (the page-block size should be
261 // divisible by kN0)
262 // relationship between the *_start variables: aligned_physical_seqlen_k_start <=
263 // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start
264 const index_t aligned_physical_seqlen_k_start =
265 [&, physical_seqlen_k_start_ = physical_seqlen_k_start] {
266 if constexpr(kIsPagedKV)
267 {
268 return kN0 * integer_divide_floor(physical_seqlen_k_start_, kN0);
269 }
270 else
271 {
272 return physical_seqlen_k_start_;
273 }
274 }();
275 const index_t num_total_loop =
276 integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
277
278 auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
279 k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
280
281 const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
282 auto bias_dram_window =
283 make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
284 bias_dram_block_window_tmp.get_window_lengths(),
285 {bias_origin.at(number<0>{}),
286 logical_seqlen_k_start - (physical_seqlen_k_start -
287 aligned_physical_seqlen_k_start)}, // M/N
288 Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
289
290 // v_dram_window
291 auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
292 v_dram_block_window_lengths,
293 {0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
294 Policy::template MakeVDramTileDistribution<Problem>());
295
296 auto q_tile = tile_elementwise_in(q_element_func, q);
297
298 // prefetch K tile
299 index_t i_total_loops = 0;
300 constexpr index_t k0_loops = kQKHeaddim / kK0;
301 constexpr index_t k1_loops = kN0 / kK1;
302
303 static_assert(2 <= k0_loops);
304 static_assert(1 <= k1_loops);
305 do
306 {
307 // STAGE 1, QK gemm
308 auto k_dram_window = make_tile_window(
309 k_dram_block_window,
310 Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
311 // load
312
313 auto k_block_tile = load_tile(k_dram_window);
314 {
315 // moving k_dram_window is an in-page-block operation, so there is
316 // no need to invoke k_page_block_navigator.move_tile_window() here.
317 move_tile_window(k_dram_window, {0, kK0});
318 clear_tile(s_acc); // initialize C
319 store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
320 k_block_tile = load_tile(k_dram_window);
321 }
322 auto physical_next_block_id_k =
323 amd_wave_read_first_lane(k_page_block_navigator.prefetch_table_id(
324 i_page_block_k, k_dram_block_window, {kN0, 0}));
325 auto physical_next_block_id_v = amd_wave_read_first_lane(
326 v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1}));
327
329 {
330 __builtin_amdgcn_sched_barrier(
331 0); // prevent from messing up the order of global loads
332 }
333 const auto bias_tile = load_tile(bias_dram_window); // load bias tile
335 {
336 __builtin_amdgcn_sched_barrier(
337 0); // prevent from messing up the order of global loads
338 }
339
340 if constexpr(k0_loops > 2)
341 {
342 static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
344 gemm_0(s_acc,
345 get_slice_tile(q_tile,
346 sequence<0, i_k0 * kK0>{},
347 sequence<kM0, (i_k0 + 1) * kK0>{}),
348 k_lds_window);
350 move_tile_window(k_dram_window, {0, kK0});
351
353 k_lds_window,
354 tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
355 k_block_tile = load_tile(k_dram_window); // global read i + 2
356 });
357 }
358
359 const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
360 { // tail
362 gemm_0(s_acc,
363 get_slice_tile(q_tile,
364 sequence<0, (k0_loops - 2) * kK0>{},
365 sequence<kM0, (k0_loops - 1) * kK0>{}),
366 k_lds_window);
368
369 store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
371
372 gemm_0(s_acc,
373 get_slice_tile(q_tile,
374 sequence<0, (k0_loops - 1) * kK0>{},
375 sequence<kM0, k0_loops * kK0>{}),
376 k_lds_window);
377 }
378
379 // STAGE 2, scale_s, add bias, mask, softmax
381 {
382 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
383 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
385 [&](auto& x, const auto& y) {
386#if !CK_TILE_FMHA_FWD_FAST_EXP2
387 x += type_convert<SaccDataType>(bias_element_func(y));
388#else
390 type_convert<SaccDataType>(bias_element_func(y));
391#endif
392 },
393 s_acc,
394 bias_tile);
395 }
396 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
397 {
398 const auto k_origin = k_page_block_navigator.to_global_window_origin(
399 i_page_block_k, k_dram_block_window.get_window_origin());
400 constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
401 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
402 sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
403 sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
404 const auto tile_idx = get_x_indices_from_distributed_indices(
405 s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
406
407 const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
408 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
409 constexpr auto i_j_idx = make_tuple(idx0, idx1);
410
411 s_acc(i_j_idx) *= scale_s;
412 // position_encoding accepts only logical coordinates, so do the conversion here
413 position_encoding.update(s_acc(i_j_idx), row, col - kv_l2p_offset);
414 });
415 });
416 }
417 else
418 {
419 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
420 if constexpr(kHasLogitsSoftCap)
421 {
422 auto apply_logits_transform =
423 [&variant, &variant_params, &block_indices](auto& x) {
424 x = variant.LogitsTransform(variant_params,
425 variant.QueryTransform(variant_params, x),
426 block_indices.batch_idx,
427 block_indices.qo_head_idx,
428 block_indices.kv_head_idx);
429 };
430#if !CK_TILE_FMHA_FWD_FAST_EXP2
431 tile_elementwise_inout(apply_logits_transform, s_acc);
432#else
433 tile_elementwise_inout(apply_logits_transform, s_acc);
434#endif
435 }
436 else
437 {
438#if !CK_TILE_FMHA_FWD_FAST_EXP2
439 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
440#endif
441 }
442 }
443 move_tile_window(bias_dram_window, {0, kN0});
444
445 {
446 const auto k_origin = k_page_block_navigator.to_global_window_origin(
447 i_page_block_k, k_dram_block_window.get_window_origin());
448
449 if constexpr(kIsPagedKV)
450 {
451 // check columns in [aligned_physical_seqlen_k_start, physical_seqlen_k_end)
452 if(kv_l2p_offset > 0)
453 {
455 s_acc,
457 [&, physical_seqlen_k_start_ = physical_seqlen_k_start](auto tile_idx) {
458 const auto col =
459 k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
460 return col < physical_seqlen_k_start_;
461 });
462 };
463 }
464
465 if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
466 {
467 // mask accepts only logical coordinates, so do the conversion here
468 bool need_perpixel_check =
469 mask.IsEdgeTile(q_origin.at(number<0>{}),
470 k_origin.at(number<0>{}) - kv_l2p_offset,
471 number<kM0>{},
472 number<kN0>{});
473 if(need_perpixel_check)
474 {
476 s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
477 const auto row =
478 q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
479 const auto col =
480 k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
481 return mask.IsOutOfBound(row, col - kv_l2p_offset);
482 });
483 }
484 }
485 }
486
487 const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
489 s,
490 sequence<1>{},
491 f_max,
492 -numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
494
495 const auto m_old = m; // m{j-1}
497 [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
498
500 s.get_tile_distribution()); // Pcompute{j}
501
502 static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
506 FmhaMask::IsMasking)
507 {
510 : raw_m;
511 }
512 else
513 {
514 return raw_m;
515 }
516 };
517
518 constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
519 sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
520 constexpr auto i_idx = make_tuple(idx0);
521#if CK_TILE_FMHA_FWD_FAST_EXP2
522 auto row_max = scale_s * get_validated_m(m[i_idx]);
523#endif
524 sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
525 constexpr auto i_j_idx = make_tuple(idx0, idx1);
526#if CK_TILE_FMHA_FWD_FAST_EXP2
529 {
530 p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
531 }
532 else
533 {
534 if constexpr(kHasLogitsSoftCap)
535 {
536 p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
537 }
538 else
539 {
540 p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
541 }
542 }
543#else
544 p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
545#endif
546 });
547 });
548
550 p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
551
553 // l{j}, Oacc{j}
554 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
555 sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
556 constexpr auto i_idx = make_tuple(idx0);
557#if CK_TILE_FMHA_FWD_FAST_EXP2
558 const auto tmp = [&]() {
561 {
562 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
563 }
564 else
565 {
566 if constexpr(kHasLogitsSoftCap)
567 {
568
569 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
570 }
571 else
572 {
573 auto row_max = scale_s * get_validated_m(m[i_idx]);
574 return exp2(scale_s * m_old[i_idx] - row_max);
575 }
576 }
577 }();
578#else
579 const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
580#endif
581 l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
582 sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
583 constexpr auto i_j_idx = make_tuple(idx0, idx1);
584 // FIXME: this uses a different equation from the FA v2 paper,
585 // but produces the correct result.
586 // Is the equation wrong?
587 o_acc(i_j_idx) *= tmp;
588 });
589 });
590
592 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
593 {
595 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
596 shuffle_tile(v_shuffle_tmp, v_prefetch);
598 v_lds_window,
599 tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
600 }
601 else
602 {
603 store_tile(v_lds_window,
604 tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
605 }
606 i_page_block_v = v_page_block_navigator.move_tile_window(
607 i_page_block_v, v_dram_window, {0, kK1}, physical_next_block_id_v);
608
609 const auto p =
610 cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
611
612 // STAGE 3, KV gemm
613 if constexpr(k1_loops > 1)
614 {
615 static_for<0, k1_loops - 1, 1>{}([&,
616 &i_page_block_v_ = i_page_block_v,
617 &v_dram_window_ = v_dram_window](auto i_k1) {
618 auto physical_next_block_id_v_ =
619 __builtin_amdgcn_readfirstlane(v_page_block_navigator.prefetch_table_id(
620 i_page_block_v_, v_dram_window_, {0, kK1}));
621 const auto v = load_tile(v_dram_window_); // load next v
623 gemm_1(o_acc,
625 p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
626 v_lds_window);
628 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
629 {
631 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
632 shuffle_tile(v_shuffle_tmp, v);
633 store_tile(v_lds_window,
634 tile_elementwise_in(v_element_func,
635 v_shuffle_tmp)); // store the prefetch
636 }
637 else
638 {
639 store_tile(v_lds_window,
640 tile_elementwise_in(v_element_func, v)); // store next v
641 }
642 i_page_block_v_ = v_page_block_navigator.move_tile_window(
643 i_page_block_v_, v_dram_window_, {0, kK1}, physical_next_block_id_v_);
644 });
645 }
646 // move K tile windows
647 i_page_block_k = k_page_block_navigator.move_tile_window(
648 i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k);
649 // tail
650 {
652 gemm_1(o_acc,
653 get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
654 v_lds_window);
656 }
657 } while(++i_total_loops < num_total_loop);
658
659 // store lse
660 if constexpr(kStoreLSE)
661 {
662 auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
663
664 constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
665 sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
666 constexpr auto i_idx = make_tuple(idx0);
667#if CK_TILE_FMHA_FWD_FAST_EXP2
670 {
671 lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
672 }
673 else
674 {
675 if constexpr(kHasLogitsSoftCap)
676 {
677 lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
678 }
679 else
680 {
681 lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
682 }
683 }
684#else
685 lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
686#endif
687 });
688
689 store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
690 }
691
692 // finally, O
693 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
694
695 sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
696 constexpr auto i_idx = make_tuple(idx0);
697 const auto tmp = [&]() {
698 if constexpr(FmhaMask::IsMasking)
699 {
700 return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
701 }
702 else
703 return 1 / l[i_idx];
704 }();
705 sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
706 constexpr auto i_j_idx = make_tuple(idx0, idx1);
707 o_acc(i_j_idx) *= tmp;
708 });
709 });
710
711 o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
712
713 return o_acc;
714 }
715
// Convenience overload: takes the same windows, navigators, mask, position encoding, scale,
// variant data, kv_l2p_offset and shared-memory pointer as the overload with explicit element
// functions, and forwards them to that overload with identity{} supplied for each of its
// eight element-function parameters. Returns whatever that overload returns.
716 template <typename QDramBlockWindowTmp,
717 typename KDramBlockWindowLengths,
718 typename KPageBlockNavigator,
719 typename VDramBlockWindowLengths,
720 typename VPageBlockNavigator,
721 typename BiasDramBlockWindowTmp,
722 typename LSEDramBlockWindowTmp,
723 typename PositionEncoding,
724 typename AttentionVariantParams,
725 typename BlockIndices>
727 operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
728 const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
729 const KPageBlockNavigator& k_page_block_navigator,
730 const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
731 const VPageBlockNavigator& v_page_block_navigator,
732 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
733 LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
734 FmhaMask mask,
735 PositionEncoding position_encoding,
736 float scale_s,
737 const AttentionVariant& variant,
738 const AttentionVariantParams& variant_params,
739 const BlockIndices& block_indices,
740 index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
741 void* smem_ptr) const
742 {
743 return operator()(q_dram_block_window_tmp,
744 identity{}, // q_element_func
745 k_dram_block_window_lengths,
746 k_page_block_navigator,
747 identity{}, // k_element_func
748 v_dram_block_window_lengths,
749 v_page_block_navigator,
750 identity{}, // v_element_func
751 bias_dram_block_window_tmp,
752 identity{}, // bias_element_func
753 lse_dram_block_window_tmp,
754 identity{}, // lse_element_func
755 identity{}, // s_acc_element_func
756 identity{}, // p_compute_element_func
757 identity{}, // o_acc_element_func
758 mask,
759 position_encoding,
760 scale_s,
761 variant,
762 variant_params,
763 block_indices,
764 kv_l2p_offset,
765 smem_ptr);
766 }
767};
768
769} // namespace ck_tile
#define CK_TILE_FMHA_FWD_FAST_EXP2
Definition config.hpp:234
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition bfloat16.hpp:428
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_HOST_DEVICE constexpr auto integer_divide_floor(X x, Y y)
Definition tile/core/numeric/math.hpp:143
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition static_distributed_tensor.hpp:159
CK_TILE_DEVICE constexpr auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition slice_tile.hpp:23
@ ALIBI
Definition block_attention_bias_enum.hpp:15
@ NO_BIAS
Definition block_attention_bias_enum.hpp:13
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={}, bool_constant< CrossWarp >={})
Definition block_reduce.hpp:21
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
constexpr T log2e_v
Definition tile/core/numeric/math.hpp:488
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE void shuffle_tile(OutTensor &out, const InTensor &in)
Definition shuffle_tile.hpp:154
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_ &acc_tensor, const InDistributedTensor_ &in_tensor, sequence< InReduceDims... >, const ReduceFunc &reduce_func)
Definition block_reduce.hpp:191
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition bfloat16.hpp:419
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition static_distributed_tensor.hpp:175
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition bfloat16.hpp:425
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:18
static constexpr bool kHasLogitsSoftCap
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:56
static constexpr bool kPadSeqLenQ
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:52
remove_cvref_t< typename Problem::AttentionVariant > AttentionVariant
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:31
remove_cvref_t< typename Problem::SMPLComputeDataType > SMPLComputeDataType
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:25
static constexpr index_t kAlignmentQ
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:68
remove_cvref_t< typename Problem::QDataType > QDataType
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:21
static constexpr index_t kM0
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:41
static constexpr bool kPadHeadDimQ
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:54
remove_cvref_t< Problem_ > Problem
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:19
remove_cvref_t< typename BlockFmhaShape::VLayout > VLayout
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:35
remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:32
static constexpr index_t kBlockPerCu
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:84
remove_cvref_t< typename Problem::KDataType > KDataType
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:22
static constexpr index_t kN0
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:42
remove_cvref_t< typename Problem::PDataType > PDataType
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:28
static constexpr auto BiasEnum
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:57
remove_cvref_t< typename Problem::VDataType > VDataType
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:23
static constexpr index_t kBlockSize
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:39
remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:27
remove_cvref_t< Policy_ > Policy
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:20
static constexpr bool kIsGroupMode
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:51
remove_cvref_t< typename Problem::BiasDataType > BiasDataType
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:26
static constexpr index_t kQKHeaddim
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:46
static constexpr bool kQLoadOnce
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:36
static constexpr index_t kK1
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:45
static constexpr index_t kAlignmentK
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:70
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:117
remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:34
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowLengths &k_dram_block_window_lengths, const KPageBlockNavigator &k_page_block_navigator, const KElementFunction &k_element_func, const VDramBlockWindowLengths &v_dram_block_window_lengths, const VPageBlockNavigator &v_page_block_navigator, const VElementFunction &v_element_func, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, const BiasElementFunction &bias_element_func, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, index_t kv_l2p_offset, void *smem_ptr) const
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:141
static constexpr bool kIsPagedKV
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:59
static constexpr bool kPadSeqLenK
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:53
static constexpr index_t kAlignmentBias
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:81
static constexpr index_t kK0
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:43
static constexpr index_t kN1
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:44
remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:30
static constexpr index_t kAlignmentO
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:79
static constexpr index_t kAlignmentV
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:72
static constexpr bool kPadHeadDimV
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:55
remove_cvref_t< typename Problem::OaccDataType > OaccDataType
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:29
static constexpr const char * name
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:115
remove_cvref_t< typename Problem::SaccDataType > SaccDataType
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:24
static constexpr index_t kSubQKHeaddim
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:47
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowLengths &k_dram_block_window_lengths, const KPageBlockNavigator &k_page_block_navigator, const VDramBlockWindowLengths &v_dram_block_window_lengths, const VPageBlockNavigator &v_page_block_navigator, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_block_window_tmp, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, index_t kv_l2p_offset, void *smem_ptr) const
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:727
static constexpr bool kStoreLSE
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:58
Definition tile/core/utility/functional.hpp:86
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition tile/core/container/sequence.hpp:49
#define C_LOG2E
Definition tile/core/numeric/math.hpp:469