fmha_fwd_kernel.hpp Source File

fmha_fwd_kernel.hpp Source File#

Composable Kernel: fmha_fwd_kernel.hpp Source File
fmha_fwd_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
10
11#include <string>
12#include <type_traits>
13#include <utility>
14#include <variant>
15
16#define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0
17// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
18// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
19// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
20// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
21// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
22
23namespace ck_tile {
24
25template <typename FmhaPipeline_, typename EpiloguePipeline_>
27{
// Compile-time traits of the kernel, forwarded from the pipeline type (FmhaPipeline),
// its Problem / Policy members and the mask type (FmhaMask).
// NOTE(review): the embedded listing numbers have gaps in this region (e.g. 35 -> 45,
// 59 -> 62), so some declarations are not visible here.
30 static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
31
32 static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
33 static_assert(kBlockPerCu > 0);
// Raw value taken from the Problem; GetName() leaves out the "o<N>_" tag when it is -1.
34 static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
35
45
47
48 static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
49 static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
50 static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
51 static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
52 static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
53 static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
54 static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
55 static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
56 static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
57 static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
58 static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
59
62 static constexpr bool kHasMask = FmhaMask::IsMasking;
63
64 static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
65
66 static constexpr bool kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad;
// kIsAvailable gates the kernel body: always true when compiling for gfx950, otherwise
// true only for configurations that do not request kUseTrLoad.
67#if defined(__gfx950__)
68 static constexpr bool kIsAvailable = true;
69#else
70 static constexpr bool kIsAvailable = !kUseTrLoad;
71#endif
// Pipeline name as a string_view so that it can be compared in `if constexpr` (see run_()).
72 static constexpr std::string_view kPipelineName = FmhaPipeline::name;
73
74 // clang-format off
// t2s maps a (T1, T2) data-type pair to the short tag used in kernel names
// (GetName() instantiates it as t2s<QDataType, ODataType>). T2 defaults to T1, so the
// single-argument specializations cover the case where both types are the same.
75 template <typename T1, typename T2 = T1> struct t2s;
76 template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
77 template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
78 template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
79 template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
80 template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// mixed-type pairs: fp8 first type combined with a bf16 / fp32 second type
81 template <> struct t2s<ck_tile::fp8_t, ck_tile::bf16_t> { static constexpr const char * name = "fp8bf16"; };
82 template <> struct t2s<ck_tile::fp8_t, ck_tile::fp32_t> { static constexpr const char * name = "fp8fp32"; };
83 // clang-format on
84
// Builds the name string of this kernel instance from its compile-time configuration:
// head dim, data-type tag, batch/group mode, block tile sizes ("b..."), block-warp layouts
// ("r..." for gemm0 and gemm1), warp tile sizes ("w..."), optional occupancy tag ("o<N>_"),
// pipeline name, V layout ("vr"/"vc"), padding tag and one suffix per optional feature.
// The existing comment below says the format must stay in sync with generate.py.
85 CK_TILE_HOST static std::string GetName()
86 {
87 // sync with generate.py
88 // clang-format off
89 using bfs = typename FmhaPipeline::BlockFmhaShape;
90 using g0br = typename bfs::Gemm0BlockWarps;
91 using g1br = typename bfs::Gemm1BlockWarps;
92 using g0wt = typename bfs::Gemm0WarpTile;
93 using g1wt = typename bfs::Gemm1WarpTile;
// local shorthands, undefined again at the end of the function
94 #define _SS_ std::string
95 #define _TS_ std::to_string
// pn: padding tag - "p" followed by one marker per padded dimension, or empty if none
96 auto pn = [&] () {
97 std::string n;
98 if (kPadSeqLenQ) n += "s";
99 if (kPadSeqLenK) n += "sk";
100 if (kPadHeadDimQ) n += "d";
101 if (kPadHeadDimV) n += "dv";
102 return n.empty() ? n : std::string("p") + n; }();
103 return
104 _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType, ODataType>::name) +
// the trailing "_" here and the leading "b" on the next line are adjacent string literals,
// i.e. they are concatenated by the compiler into "_b"
105 "_" + (kIsGroupMode ? "group" : "batch") + "_"
106 "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
107 _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
108 "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
109 "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
110 "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
111 "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
// the occupancy tag uses kBlockPerCu but is emitted only when the Problem's raw value is not -1
112 (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
113 "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
// NOTE(review): the embedded listing numbers skip from 113 to 115 here, so one term of the
// name may be missing from this listing - confirm against the original header.
115 (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload");
116 #undef _SS_
117 #undef _TS_
118 // clang-format on
119 }
120
121 template <ck_tile::index_t I> // the index makes every instantiation a distinct type, which
122 // avoids the duplicated-base-class problem when several placeholders are inherited
124 {
125 };
126
127 // Kargs use aggregate initialization, so no constructors are provided.
128 // Inheritance is used to minimize the kargs size.
129 // Users need to call the MakeKargs() function to create kargs.
158
160 {
162
// Stores the logits soft-cap value: a strictly positive input is kept as is, anything
// else (zero or negative) is stored as 0.f.
// NOTE(review): the embedded listing numbers skip 168 and 173, so each branch may contain
// one more statement that is not visible in this listing - confirm against the original.
163 void init_logits_soft_cap(float logits_soft_cap_)
164 {
165 if(0 < logits_soft_cap_)
166 {
167 logits_soft_cap = logits_soft_cap_;
169 }
170 else
171 {
172 logits_soft_cap = 0.f;
174 }
175 }
176
179 };
180
187
192
194 {
195 // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
196 const void* alibi_slope_ptr;
197 ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
198 };
199
201 {
202 // ck_tile::index_t window_size_left, window_size_right;
205 };
206
208 {
209 float scale_p;
210 float scale_o;
211 };
212
219
233
235 {
// Initializes the dropout parameters from host-side seed / offset VALUES.
//   p_drop : drop probability; the keep probability is p_undrop = 1 - p_drop and
//            rp_undrop is its reciprocal (p_drop == 1 therefore divides by zero).
//   seed, offset : stored in the .val members of drop_seed / drop_offset.
// NOTE(review): the embedded listing numbers skip 239 and 245. As shown, the uint8_t(...)
// expression below has no assignment target (presumably p_undrop_in_uint8_t, the member
// declared further down), and the pointer overload's is_drop_seed_offset_from_host
// assignment has no visible counterpart here - confirm against the original header.
236 void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
237 {
238 float p_undrop = 1.0 - p_drop;
// keep probability scaled to the 0..255 range of uint8_t, rounded down
240 uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
241 rp_undrop = 1.0 / p_undrop;
242
243 this->drop_seed.val = seed;
244 this->drop_offset.val = offset;
246 }
247
// Initializes the dropout parameters from POINTERS to the seed / offset values.
//   p_drop : drop probability; the keep probability is p_undrop = 1 - p_drop and
//            rp_undrop is its reciprocal (p_drop == 1 therefore divides by zero).
//   seed_ptr, offset_ptr : stored in the .ptr members of drop_seed / drop_offset; this
//            overload also clears is_drop_seed_offset_from_host.
// NOTE(review): the embedded listing numbers skip 251. As shown, the uint8_t(...) expression
// below has no assignment target (presumably p_undrop_in_uint8_t) - confirm against the
// original header.
248 void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
249 {
250 float p_undrop = 1.0 - p_drop;
// keep probability scaled to the 0..255 range of uint8_t, rounded down
252 uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
253 rp_undrop = 1.0 / p_undrop;
254
255 this->drop_seed.ptr = seed_ptr;
256 this->drop_offset.ptr = offset_ptr;
257 this->is_drop_seed_offset_from_host = false;
258 }
259
260 float rp_undrop = 1;
261 uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
262 bool is_store_randval = false;
263 void* rand_val_ptr = nullptr;
264
267 };
268
273
278
281 std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
282 FmhaFwdBatchModeBiasKargs,
283 std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
284 FmhaFwdAlibiKargs,
285 FmhaFwdEmptyKargs<0>>>,
286 std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
287 std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
288 std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
289 std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
290 std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
291 {
296
297 // Optional cumulative sequence length pointers for batch mode
298 // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
299 const int32_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD
300 const int32_t* cu_seqlen_k_ptr = nullptr; // cumulative, length without PAD
301 };
302
305 std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
306 FmhaFwdCommonBiasKargs,
307 std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
308 FmhaFwdAlibiKargs,
309 FmhaFwdEmptyKargs<0>>>,
310 std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
311 std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
312 std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
313 std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
314 std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
315 std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
316 {
321
322 // Optional per-sequence and cumulative logical (excluding padding) sequence length arrays
323 const int32_t* cu_seqlen_q_ptr = nullptr;
324 const int32_t* cu_seqlen_k_ptr = nullptr;
325 };
326
327 using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
328
335
// Batch-mode kernel-argument builder; by default it takes part in overload resolution only
// when kIsGroupMode is false (enable_if on Cond).
// The common fields and the four batch strides are set through positional aggregate
// initialization, with one empty `{}` per optional base; the optional fields (bias / alibi,
// mask, LSE, fp8 static quant, dropout, logits soft cap) are then assigned only when the
// matching compile-time feature is enabled, so arguments of disabled features are ignored.
//   drop_seed_offset : alternative 0 holds host-side (seed, offset) values, alternative 1
//                      holds pointers to them.
//   cu_seqlen_q_ptr / cu_seqlen_k_ptr : optional, stored as const int32_t*.
336 template <bool Cond = !kIsGroupMode>
337 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
338 MakeKargsImpl(const void* q_ptr,
339 const void* k_ptr,
340 const void* v_ptr,
341 const void* bias_ptr,
342 void* rand_val_ptr,
343 void* lse_ptr,
344 void* o_ptr,
345 ck_tile::index_t seqlen_q,
346 ck_tile::index_t seqlen_k,
347 ck_tile::index_t hdim_q,
348 ck_tile::index_t hdim_v,
349 ck_tile::index_t num_head_q,
350 ck_tile::index_t nhead_ratio_qk,
351 float scale_s,
352 float scale_p,
353 float scale_o,
354 float logits_soft_cap,
355 ck_tile::index_t stride_q,
356 ck_tile::index_t stride_k,
357 ck_tile::index_t stride_v,
358 ck_tile::index_t stride_bias,
359 ck_tile::index_t stride_randval,
360 ck_tile::index_t stride_o,
361 ck_tile::index_t nhead_stride_q,
362 ck_tile::index_t nhead_stride_k,
363 ck_tile::index_t nhead_stride_v,
364 ck_tile::index_t nhead_stride_bias,
365 ck_tile::index_t nhead_stride_randval,
366 ck_tile::index_t nhead_stride_lse,
367 ck_tile::index_t nhead_stride_o,
368 ck_tile::index_t batch_stride_q,
369 ck_tile::index_t batch_stride_k,
370 ck_tile::index_t batch_stride_v,
371 ck_tile::index_t batch_stride_bias,
372 ck_tile::index_t batch_stride_randval,
373 ck_tile::index_t batch_stride_lse,
374 ck_tile::index_t batch_stride_o,
375 ck_tile::index_t window_size_left,
376 ck_tile::index_t window_size_right,
377 ck_tile::index_t mask_type,
378 float p_drop,
379 bool s_randval,
380 std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
381 drop_seed_offset,
382 const void* cu_seqlen_q_ptr = nullptr,
383 const void* cu_seqlen_k_ptr = nullptr)
384 {
// positional aggregate initialization: the order below must match the member / base order
385 Kargs kargs{{q_ptr,
386 k_ptr,
387 v_ptr,
388 o_ptr,
389 seqlen_q,
390 seqlen_k,
391 hdim_q,
392 hdim_v,
393 num_head_q,
394 nhead_ratio_qk,
// with CK_TILE_FMHA_FWD_FAST_EXP2 the softmax scale is pre-multiplied by log2(e)
395#if CK_TILE_FMHA_FWD_FAST_EXP2
396 static_cast<float>(scale_s * ck_tile::log2e_v<>),
397#else
398 scale_s,
399#endif
400 stride_q,
401 stride_k,
402 stride_v,
403 stride_o,
404 nhead_stride_q,
405 nhead_stride_k,
406 nhead_stride_v,
407 nhead_stride_o}, // args for common karg
408 {}, // placeholder for bias
409 {}, // placeholder for mask
410 {}, // placeholder for lse
411 {}, // placeholder for fp8_static_quant args
412 {}, // placeholder for dropout
413 {}, // placeholder for logits_soft_cap
414 batch_stride_q,
415 batch_stride_k,
416 batch_stride_v,
417 batch_stride_o};
418
// NOTE(review): the embedded listing numbers skip 419, so the `if constexpr(...)` that opens
// this branch is not visible - presumably the ELEMENTWISE_BIAS check that pairs with the
// ALIBI `else if` below; confirm against the original header.
420 {
421 kargs.bias_ptr = bias_ptr;
422 kargs.stride_bias = stride_bias;
423 kargs.nhead_stride_bias = nhead_stride_bias;
424 kargs.batch_stride_bias = batch_stride_bias;
425 }
426 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
427 {
// alibi reuses the bias arguments: bias_ptr carries the slopes, stride_bias their stride
428 kargs.alibi_slope_ptr = bias_ptr;
429 kargs.alibi_slope_stride = stride_bias;
430 }
431 if constexpr(kHasMask)
432 {
433 kargs.window_size_left = window_size_left;
434 kargs.window_size_right = window_size_right;
435 kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
436 }
437 if constexpr(kStoreLSE)
438 {
439 kargs.lse_ptr = lse_ptr;
440 kargs.nhead_stride_lse = nhead_stride_lse;
441 kargs.batch_stride_lse = batch_stride_lse;
442 }
443 if constexpr(kDoFp8StaticQuant)
444 {
445 kargs.scale_p = scale_p;
446 kargs.scale_o = scale_o;
447 }
448 if constexpr(kHasDropout)
449 {
450 if(drop_seed_offset.index() == 0) // seed & offset come from host
451 {
452 const auto& [seed, offset] = std::get<0>(drop_seed_offset);
453 kargs.init_dropout(p_drop, seed, offset);
454 }
455 else // seed & offset come from device
456 {
457 const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
458 kargs.init_dropout(p_drop,
459 reinterpret_cast<const uint64_t*>(seed_ptr),
460 reinterpret_cast<const uint64_t*>(offset_ptr));
461 }
462
463 kargs.rand_val_ptr = rand_val_ptr;
464 kargs.stride_randval = stride_randval;
465 kargs.nhead_stride_randval = nhead_stride_randval;
466 kargs.batch_stride_randval = batch_stride_randval;
467 kargs.is_store_randval = s_randval;
468 }
469 if constexpr(kHasLogitsSoftCap)
470 {
471 kargs.init_logits_soft_cap(logits_soft_cap);
472 }
473
// always stored, independent of any feature flag; nullptr when the caller omits them
474 kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
475 kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
476 return kargs;
477 }
478
479 // Batch-mode overload kept for backward compatibility: std::variant<> cannot be built from a
// braced initializer list, so this overload accepts the host-side (seed, offset) VALUES as a
// std::tuple, repacks them into a std::pair and forwards every argument unchanged, in the
// same order, to MakeKargsImpl().
480 template <bool Cond = !kIsGroupMode>
481 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
482 MakeKargs(const void* q_ptr,
483 const void* k_ptr,
484 const void* v_ptr,
485 const void* bias_ptr,
486 void* rand_val_ptr,
487 void* lse_ptr,
488 void* o_ptr,
489 ck_tile::index_t seqlen_q,
490 ck_tile::index_t seqlen_k,
491 ck_tile::index_t hdim_q,
492 ck_tile::index_t hdim_v,
493 ck_tile::index_t num_head_q,
494 ck_tile::index_t nhead_ratio_qk,
495 float scale_s,
496 float scale_p,
497 float scale_o,
498 float logits_soft_cap,
499 ck_tile::index_t stride_q,
500 ck_tile::index_t stride_k,
501 ck_tile::index_t stride_v,
502 ck_tile::index_t stride_bias,
503 ck_tile::index_t stride_randval,
504 ck_tile::index_t stride_o,
505 ck_tile::index_t nhead_stride_q,
506 ck_tile::index_t nhead_stride_k,
507 ck_tile::index_t nhead_stride_v,
508 ck_tile::index_t nhead_stride_bias,
509 ck_tile::index_t nhead_stride_randval,
510 ck_tile::index_t nhead_stride_lse,
511 ck_tile::index_t nhead_stride_o,
512 ck_tile::index_t batch_stride_q,
513 ck_tile::index_t batch_stride_k,
514 ck_tile::index_t batch_stride_v,
515 ck_tile::index_t batch_stride_bias,
516 ck_tile::index_t batch_stride_randval,
517 ck_tile::index_t batch_stride_lse,
518 ck_tile::index_t batch_stride_o,
519 ck_tile::index_t window_size_left,
520 ck_tile::index_t window_size_right,
521 ck_tile::index_t mask_type,
522 float p_drop,
523 bool s_randval,
524 const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
525 const void* cu_seqlen_q_ptr = nullptr,
526 const void* cu_seqlen_k_ptr = nullptr)
527 {
528 return MakeKargsImpl(
529 q_ptr,
530 k_ptr,
531 v_ptr,
532 bias_ptr,
533 rand_val_ptr,
534 lse_ptr,
535 o_ptr,
536 seqlen_q,
537 seqlen_k,
538 hdim_q,
539 hdim_v,
540 num_head_q,
541 nhead_ratio_qk,
542 scale_s,
543 scale_p,
544 scale_o,
545 logits_soft_cap,
546 stride_q,
547 stride_k,
548 stride_v,
549 stride_bias,
550 stride_randval,
551 stride_o,
552 nhead_stride_q,
553 nhead_stride_k,
554 nhead_stride_v,
555 nhead_stride_bias,
556 nhead_stride_randval,
557 nhead_stride_lse,
558 nhead_stride_o,
559 batch_stride_q,
560 batch_stride_k,
561 batch_stride_v,
562 batch_stride_bias,
563 batch_stride_randval,
564 batch_stride_lse,
565 batch_stride_o,
566 window_size_left,
567 window_size_right,
568 mask_type,
569 p_drop,
570 s_randval,
// pair<uint64_t, uint64_t> converts to the first alternative of the variant parameter
571 std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
572 cu_seqlen_q_ptr,
573 cu_seqlen_k_ptr);
574 }
575
576 // Batch-mode overload kept for backward compatibility: std::variant<> cannot be built from a
// braced initializer list, so this overload accepts POINTERS to the seed / offset as a
// std::tuple, repacks them into a std::pair and forwards every argument unchanged, in the
// same order, to MakeKargsImpl().
577 template <bool Cond = !kIsGroupMode>
578 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
579 MakeKargs(const void* q_ptr,
580 const void* k_ptr,
581 const void* v_ptr,
582 const void* bias_ptr,
583 void* rand_val_ptr,
584 void* lse_ptr,
585 void* o_ptr,
586 ck_tile::index_t seqlen_q,
587 ck_tile::index_t seqlen_k,
588 ck_tile::index_t hdim_q,
589 ck_tile::index_t hdim_v,
590 ck_tile::index_t num_head_q,
591 ck_tile::index_t nhead_ratio_qk,
592 float scale_s,
593 float scale_p,
594 float scale_o,
595 float logits_soft_cap,
596 ck_tile::index_t stride_q,
597 ck_tile::index_t stride_k,
598 ck_tile::index_t stride_v,
599 ck_tile::index_t stride_bias,
600 ck_tile::index_t stride_randval,
601 ck_tile::index_t stride_o,
602 ck_tile::index_t nhead_stride_q,
603 ck_tile::index_t nhead_stride_k,
604 ck_tile::index_t nhead_stride_v,
605 ck_tile::index_t nhead_stride_bias,
606 ck_tile::index_t nhead_stride_randval,
607 ck_tile::index_t nhead_stride_lse,
608 ck_tile::index_t nhead_stride_o,
609 ck_tile::index_t batch_stride_q,
610 ck_tile::index_t batch_stride_k,
611 ck_tile::index_t batch_stride_v,
612 ck_tile::index_t batch_stride_bias,
613 ck_tile::index_t batch_stride_randval,
614 ck_tile::index_t batch_stride_lse,
615 ck_tile::index_t batch_stride_o,
616 ck_tile::index_t window_size_left,
617 ck_tile::index_t window_size_right,
618 ck_tile::index_t mask_type,
619 float p_drop,
620 bool s_randval,
621 const std::tuple<const void*, const void*>& drop_seed_offset,
622 const void* cu_seqlen_q_ptr = nullptr,
623 const void* cu_seqlen_k_ptr = nullptr)
624 {
625 return MakeKargsImpl(
626 q_ptr,
627 k_ptr,
628 v_ptr,
629 bias_ptr,
630 rand_val_ptr,
631 lse_ptr,
632 o_ptr,
633 seqlen_q,
634 seqlen_k,
635 hdim_q,
636 hdim_v,
637 num_head_q,
638 nhead_ratio_qk,
639 scale_s,
640 scale_p,
641 scale_o,
642 logits_soft_cap,
643 stride_q,
644 stride_k,
645 stride_v,
646 stride_bias,
647 stride_randval,
648 stride_o,
649 nhead_stride_q,
650 nhead_stride_k,
651 nhead_stride_v,
652 nhead_stride_bias,
653 nhead_stride_randval,
654 nhead_stride_lse,
655 nhead_stride_o,
656 batch_stride_q,
657 batch_stride_k,
658 batch_stride_v,
659 batch_stride_bias,
660 batch_stride_randval,
661 batch_stride_lse,
662 batch_stride_o,
663 window_size_left,
664 window_size_right,
665 mask_type,
666 p_drop,
667 s_randval,
// pair<const void*, const void*> converts to the second alternative of the variant parameter
668 std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
669 cu_seqlen_q_ptr,
670 cu_seqlen_k_ptr);
671 }
672
// Group-mode kernel-argument builder; by default it takes part in overload resolution only
// when kIsGroupMode is true (enable_if on Cond).
// Differences from the batch-mode builder visible here: sequence lengths are passed as
// pointers (seqstart_q/k, seqlen_q/k) and the scalar seqlen fields start as -1; there are no
// batch strides; min_seqlen_q is stored when kSkipMinSeqlenQ is enabled.
// Optional fields (bias / alibi, mask, LSE, fp8 static quant, dropout, logits soft cap) are
// assigned only when the matching compile-time feature is enabled.
//   drop_seed_offset : alternative 0 holds host-side (seed, offset) values, alternative 1
//                      holds pointers to them.
673 template <bool Cond = kIsGroupMode>
674 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
675 MakeKargsImpl(const void* q_ptr,
676 const void* k_ptr,
677 const void* v_ptr,
678 const void* bias_ptr,
679 void* rand_val_ptr,
680 void* lse_ptr,
681 void* o_ptr,
682 const void* seqstart_q_ptr,
683 const void* seqstart_k_ptr,
684 const void* seqlen_q_ptr,
685 const void* seqlen_k_ptr,
686 ck_tile::index_t hdim_q,
687 ck_tile::index_t hdim_v,
688 ck_tile::index_t num_head_q,
689 ck_tile::index_t nhead_ratio_qk,
690 float scale_s,
691 float scale_p,
692 float scale_o,
693 float logits_soft_cap,
694 ck_tile::index_t stride_q,
695 ck_tile::index_t stride_k,
696 ck_tile::index_t stride_v,
697 ck_tile::index_t stride_bias,
698 ck_tile::index_t stride_randval,
699 ck_tile::index_t stride_o,
700 ck_tile::index_t nhead_stride_q,
701 ck_tile::index_t nhead_stride_k,
702 ck_tile::index_t nhead_stride_v,
703 ck_tile::index_t nhead_stride_bias,
704 ck_tile::index_t nhead_stride_randval,
705 ck_tile::index_t nhead_stride_lse,
706 ck_tile::index_t nhead_stride_o,
707 ck_tile::index_t window_size_left,
708 ck_tile::index_t window_size_right,
709 ck_tile::index_t mask_type,
710 ck_tile::index_t min_seqlen_q,
711 float p_drop,
712 bool s_randval,
713 std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
714 drop_seed_offset,
715 const void* cu_seqlen_q_ptr = nullptr,
716 const void* cu_seqlen_k_ptr = nullptr)
717 {
// positional aggregate initialization: the order below must match the member / base order
718 Kargs kargs{{q_ptr,
719 k_ptr,
720 v_ptr,
721 o_ptr,
722 -1, // seqlen will be updated by another pointer
723 -1, //
724 hdim_q,
725 hdim_v,
726 num_head_q,
727 nhead_ratio_qk,
// with CK_TILE_FMHA_FWD_FAST_EXP2 the softmax scale is pre-multiplied by log2(e)
728#if CK_TILE_FMHA_FWD_FAST_EXP2
729 static_cast<float>(scale_s * ck_tile::log2e_v<>),
730#else
731 scale_s,
732#endif
733 stride_q,
734 stride_k,
735 stride_v,
736 stride_o,
737 nhead_stride_q,
738 nhead_stride_k,
739 nhead_stride_v,
740 nhead_stride_o}, // args for common karg
741 {}, // placeholder for bias
742 {}, // placeholder for mask
743 {}, // placeholder for lse
744 {}, // placeholder for fp8_static_quant args
745 {}, // placeholder for dropout
746 {}, // placeholder for logits_soft_cap
747 {}, // placeholder for min_seqlen_q
748 reinterpret_cast<const int32_t*>(seqstart_q_ptr),
749 reinterpret_cast<const int32_t*>(seqstart_k_ptr),
750 reinterpret_cast<const int32_t*>(seqlen_q_ptr),
751 reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
752
// NOTE(review): the embedded listing numbers skip 753, so the `if constexpr(...)` that opens
// this branch is not visible - presumably the ELEMENTWISE_BIAS check that pairs with the
// ALIBI `else if` below; confirm against the original header.
754 {
755 kargs.bias_ptr = bias_ptr;
756 kargs.stride_bias = stride_bias;
757 kargs.nhead_stride_bias = nhead_stride_bias;
758 }
759 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
760 {
// alibi reuses the bias arguments: bias_ptr carries the slopes, stride_bias their stride
761 kargs.alibi_slope_ptr = bias_ptr;
762 kargs.alibi_slope_stride = stride_bias;
763 }
764 if constexpr(kHasMask)
765 {
766 kargs.window_size_left = window_size_left;
767 kargs.window_size_right = window_size_right;
768 kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
769 }
770 if constexpr(kStoreLSE)
771 {
772 kargs.lse_ptr = lse_ptr;
773 kargs.nhead_stride_lse = nhead_stride_lse;
774 }
775 if constexpr(kDoFp8StaticQuant)
776 {
777 kargs.scale_p = scale_p;
778 kargs.scale_o = scale_o;
779 }
780 if constexpr(kHasDropout)
781 {
782 if(drop_seed_offset.index() == 0) // seed & offset come from host
783 {
784 const auto& [seed, offset] = std::get<0>(drop_seed_offset);
785 kargs.init_dropout(p_drop, seed, offset);
786 }
787 else // seed & offset come from device
788 {
789 const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
790 kargs.init_dropout(p_drop,
791 reinterpret_cast<const uint64_t*>(seed_ptr),
792 reinterpret_cast<const uint64_t*>(offset_ptr));
793 }
794
795 kargs.rand_val_ptr = rand_val_ptr;
796 kargs.stride_randval = stride_randval;
797 kargs.nhead_stride_randval = nhead_stride_randval;
798 kargs.is_store_randval = s_randval;
799 }
800 if constexpr(kHasLogitsSoftCap)
801 {
802 kargs.init_logits_soft_cap(logits_soft_cap);
803 }
804 if constexpr(kSkipMinSeqlenQ)
805 {
806 kargs.min_seqlen_q = min_seqlen_q;
807 }
808
// always stored, independent of any feature flag; nullptr when the caller omits them
809 kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
810 kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
811 return kargs;
812 }
813
814 // Group-mode overload kept for backward compatibility: std::variant<> cannot be built from a
// braced initializer list, so this overload accepts the host-side (seed, offset) VALUES as a
// std::tuple, repacks them into a std::pair and forwards every argument unchanged, in the
// same order, to MakeKargsImpl().
815 template <bool Cond = kIsGroupMode>
816 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
817 MakeKargs(const void* q_ptr,
818 const void* k_ptr,
819 const void* v_ptr,
820 const void* bias_ptr,
821 void* rand_val_ptr,
822 void* lse_ptr,
823 void* o_ptr,
824 const void* seqstart_q_ptr,
825 const void* seqstart_k_ptr,
826 const void* seqlen_q_ptr,
827 const void* seqlen_k_ptr,
828 ck_tile::index_t hdim_q,
829 ck_tile::index_t hdim_v,
830 ck_tile::index_t num_head_q,
831 ck_tile::index_t nhead_ratio_qk,
832 float scale_s,
833 float scale_p,
834 float scale_o,
835 float logits_soft_cap,
836 ck_tile::index_t stride_q,
837 ck_tile::index_t stride_k,
838 ck_tile::index_t stride_v,
839 ck_tile::index_t stride_bias,
840 ck_tile::index_t stride_randval,
841 ck_tile::index_t stride_o,
842 ck_tile::index_t nhead_stride_q,
843 ck_tile::index_t nhead_stride_k,
844 ck_tile::index_t nhead_stride_v,
845 ck_tile::index_t nhead_stride_bias,
846 ck_tile::index_t nhead_stride_randval,
847 ck_tile::index_t nhead_stride_lse,
848 ck_tile::index_t nhead_stride_o,
849 ck_tile::index_t window_size_left,
850 ck_tile::index_t window_size_right,
851 ck_tile::index_t mask_type,
852 ck_tile::index_t min_seqlen_q,
853 float p_drop,
854 bool s_randval,
855 const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
856 const void* cu_seqlen_q_ptr = nullptr,
857 const void* cu_seqlen_k_ptr = nullptr)
858 {
859 return MakeKargsImpl(
860 q_ptr,
861 k_ptr,
862 v_ptr,
863 bias_ptr,
864 rand_val_ptr,
865 lse_ptr,
866 o_ptr,
867 seqstart_q_ptr,
868 seqstart_k_ptr,
869 seqlen_q_ptr,
870 seqlen_k_ptr,
871 hdim_q,
872 hdim_v,
873 num_head_q,
874 nhead_ratio_qk,
875 scale_s,
876 scale_p,
877 scale_o,
878 logits_soft_cap,
879 stride_q,
880 stride_k,
881 stride_v,
882 stride_bias,
883 stride_randval,
884 stride_o,
885 nhead_stride_q,
886 nhead_stride_k,
887 nhead_stride_v,
888 nhead_stride_bias,
889 nhead_stride_randval,
890 nhead_stride_lse,
891 nhead_stride_o,
892 window_size_left,
893 window_size_right,
894 mask_type,
895 min_seqlen_q,
896 p_drop,
897 s_randval,
// pair<uint64_t, uint64_t> converts to the first alternative of the variant parameter
898 std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
899 cu_seqlen_q_ptr,
900 cu_seqlen_k_ptr);
901 }
902
903 // Group-mode overload kept for backward compatibility: std::variant<> cannot be built from a
// braced initializer list, so this overload accepts POINTERS to the seed / offset as a
// std::tuple, repacks them into a std::pair and forwards every argument unchanged, in the
// same order, to MakeKargsImpl().
904 template <bool Cond = kIsGroupMode>
905 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
906 MakeKargs(const void* q_ptr,
907 const void* k_ptr,
908 const void* v_ptr,
909 const void* bias_ptr,
910 void* rand_val_ptr,
911 void* lse_ptr,
912 void* o_ptr,
913 const void* seqstart_q_ptr,
914 const void* seqstart_k_ptr,
915 const void* seqlen_q_ptr,
916 const void* seqlen_k_ptr,
917 ck_tile::index_t hdim_q,
918 ck_tile::index_t hdim_v,
919 ck_tile::index_t num_head_q,
920 ck_tile::index_t nhead_ratio_qk,
921 float scale_s,
922 float scale_p,
923 float scale_o,
924 float logits_soft_cap,
925 ck_tile::index_t stride_q,
926 ck_tile::index_t stride_k,
927 ck_tile::index_t stride_v,
928 ck_tile::index_t stride_bias,
929 ck_tile::index_t stride_randval,
930 ck_tile::index_t stride_o,
931 ck_tile::index_t nhead_stride_q,
932 ck_tile::index_t nhead_stride_k,
933 ck_tile::index_t nhead_stride_v,
934 ck_tile::index_t nhead_stride_bias,
935 ck_tile::index_t nhead_stride_randval,
936 ck_tile::index_t nhead_stride_lse,
937 ck_tile::index_t nhead_stride_o,
938 ck_tile::index_t window_size_left,
939 ck_tile::index_t window_size_right,
940 ck_tile::index_t mask_type,
941 ck_tile::index_t min_seqlen_q,
942 float p_drop,
943 bool s_randval,
944 const std::tuple<const void*, const void*>& drop_seed_offset,
945 const void* cu_seqlen_q_ptr = nullptr,
946 const void* cu_seqlen_k_ptr = nullptr)
947 {
948 return MakeKargsImpl(
949 q_ptr,
950 k_ptr,
951 v_ptr,
952 bias_ptr,
953 rand_val_ptr,
954 lse_ptr,
955 o_ptr,
956 seqstart_q_ptr,
957 seqstart_k_ptr,
958 seqlen_q_ptr,
959 seqlen_k_ptr,
960 hdim_q,
961 hdim_v,
962 num_head_q,
963 nhead_ratio_qk,
964 scale_s,
965 scale_p,
966 scale_o,
967 logits_soft_cap,
968 stride_q,
969 stride_k,
970 stride_v,
971 stride_bias,
972 stride_randval,
973 stride_o,
974 nhead_stride_q,
975 nhead_stride_k,
976 nhead_stride_v,
977 nhead_stride_bias,
978 nhead_stride_randval,
979 nhead_stride_lse,
980 nhead_stride_o,
981 window_size_left,
982 window_size_right,
983 mask_type,
984 min_seqlen_q,
985 p_drop,
986 s_randval,
// pair<const void*, const void*> converts to the second alternative of the variant parameter
987 std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
988 cu_seqlen_q_ptr,
989 cu_seqlen_k_ptr);
990 }
991
992 CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
993 ck_tile::index_t nhead_,
994 ck_tile::index_t seqlen_q_,
995 ck_tile::index_t hdim_v_,
996 bool has_padded_seqlen_k = false)
997 {
998 // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr)
999 if(has_padded_seqlen_k)
1000 {
1001 // TODO: this may need tuning
1002 return dim3(nhead_,
1003 batch_size_,
1004 ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
1005 ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
1006 }
1007 else
1008 {
1009 // TODO: this may need tuning
1010 return dim3(nhead_,
1011 ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
1012 ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
1013 batch_size_);
1014 }
1015 }
1016
1017 CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
1018 {
1019 bool has_padded_seqlen_k = false;
1020
1021 if constexpr(kIsGroupMode)
1022 has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr);
1023
1024 if(has_padded_seqlen_k)
1025 {
1026 // const index_t num_tile_m0 = seqlen_q / kM0;
1027 const index_t num_tile_n1 =
1028 ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
1029
1030 const index_t i_block = blockIdx.z;
1031 const index_t i_nhead = blockIdx.x;
1032 const index_t i_batch = blockIdx.y;
1033
1034 const auto f = [](index_t dividend, index_t divisor) {
1035 index_t quotient = dividend / divisor;
1036 index_t modulus = dividend - quotient * divisor;
1037 return ck_tile::make_tuple(quotient, modulus);
1038 };
1039
1040 const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
1041
1042 if constexpr(kHasMask)
1043 {
1044 // assume that num_tile_n1 is always 1
1045 return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
1046 }
1047 else
1048 {
1049 return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
1050 }
1051 }
1052 else
1053 {
1054 // const index_t num_tile_m0 = seqlen_q / kM0;
1055 const index_t num_tile_n1 =
1056 ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
1057
1058 const index_t i_block = blockIdx.y; // blockIdx.x
1059 const index_t i_nhead = blockIdx.x; // blockIdx.y
1060 const index_t i_batch = blockIdx.z;
1061
1062 const auto f = [](index_t dividend, index_t divisor) {
1063 index_t quotient = dividend / divisor;
1064 index_t modulus = dividend - quotient * divisor;
1065 return ck_tile::make_tuple(quotient, modulus);
1066 };
1067
1068 const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
1069
1070 if constexpr(kHasMask)
1071 {
1072 // assume that num_tile_n1 is always 1
1073 return ck_tile::make_tuple(gridDim.y - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
1074 }
1075 else
1076 {
1077 return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
1078 }
1079 }
1080 }
1081
1083 {
1084 if(is_wave32())
1085 {
1086 return dim3(kBlockSize / 2);
1087 }
1088 else
1089 {
1090 return dim3(kBlockSize);
1091 }
1092 }
1093
1095 {
1096 return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
1097 }
1098
1100 {
1101 if constexpr(kIsAvailable)
1102 run_(std::move(kargs));
1103 }
1104
1105 CK_TILE_DEVICE void run_(Kargs kargs) const
1106 {
1107 if constexpr(kPipelineName != "qr_async_trload")
1108 {
1109 // allocate LDS
1110 __shared__ char smem_ptr[GetSmemSize()];
1111
1112 // divide problem
1113 const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
1114
1115 const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
1116 const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
1117
1118 long_index_t batch_offset_q = 0;
1119 long_index_t batch_offset_k = 0;
1120 long_index_t batch_offset_v = 0;
1121 long_index_t batch_offset_bias = 0;
1122 long_index_t batch_offset_randval = 0;
1123 long_index_t batch_offset_lse = 0;
1124 long_index_t batch_offset_o = 0;
1125
1126 if constexpr(kIsGroupMode)
1127 {
     // Group mode: per-batch starting rows come from seqstart_q_ptr / seqstart_k_ptr and
     // every offset is "start row * row stride" of the corresponding tensor.
1128 // Use seqstart_q_ptr and seqstart_k_ptr for physical starts
1129 const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1130 const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
1131
1132 // DRAM base offsets use physical starts
1133 batch_offset_q = query_start * kargs.stride_q;
1134 batch_offset_k = key_start * kargs.stride_k;
1135 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1136 {
1137 batch_offset_v = key_start * kargs.stride_v;
1138 }
1139 else
1140 {
     // Non-row-major V is viewed below as {hdim_v, seqlen_k} with strides {stride_v, 1},
     // so stepping along seqlen_k is a unit step and the offset is key_start itself.
1141 batch_offset_v = key_start;
1142 }
     // NOTE(review): the condition guarding the next braces (listing line 1143) is missing
     // from this rendering; presumably a check that a bias tensor is in use -- confirm.
1144 {
1145 batch_offset_bias = query_start * kargs.stride_bias;
1146 }
1147 if constexpr(kStoreLSE)
1148 {
1149 // LSE follows the physical layout to stay consistent with other tensors
1150 batch_offset_lse = query_start;
1151 }
1152 if constexpr(kHasDropout)
1153 {
1154 batch_offset_randval = query_start * kargs.stride_randval;
1155 }
1156 batch_offset_o = query_start * kargs.stride_o;
1157
1158 // real logical lengths (exclude PAD)
1159 // Priority: seqlen_q_ptr > cu_seqlen_q_ptr > calculated from seqstart_q_ptr
1160 if(kargs.seqlen_q_ptr != nullptr)
1161 {
1162 kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
1163 }
1164 else if(kargs.cu_seqlen_q_ptr != nullptr)
1165 {
1166 kargs.seqlen_q =
1167 kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1168 }
1169 else
1170 {
1171 const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
1172 kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
1173 }
1174
     // Optional early exit: when kSkipMinSeqlenQ is enabled, batches whose seqlen_q is not
     // greater than min_seqlen_q do no work at all.
1175 if constexpr(kSkipMinSeqlenQ)
1176 {
1177 if(kargs.seqlen_q <= kargs.min_seqlen_q)
1178 {
1179 return;
1180 }
1181 }
1182
1183 // terminate unnecessary blocks earlier
     // (this block's first row i_m0 lies at or past the end of this batch's queries)
1184 if(kargs.seqlen_q <= i_m0)
1185 {
1186 return;
1187 }
1188
     // Same priority order as for seqlen_q: seqlen_k_ptr, then cu_seqlen_k_ptr, then the
     // difference of consecutive seqstart_k_ptr entries.
1189 if(kargs.seqlen_k_ptr != nullptr)
1190 {
1191 kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
1192 }
1193 else if(kargs.cu_seqlen_k_ptr != nullptr)
1194 {
1195 kargs.seqlen_k =
1196 kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1197 }
1198 else
1199 {
1200 const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
1201 kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
1202 }
1203 }
1204 else
1205 {
     // Batch mode: every tensor has a fixed per-batch stride. i_batch is cast to
     // long_index_t before the multiply so the product is formed in the wider type.
1206 batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
1207 batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
1208 batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
     // NOTE(review): the condition guarding the next braces (listing line 1209) is missing
     // from this rendering; presumably the same bias check as in the group-mode branch.
1210 {
1211 batch_offset_bias =
1212 static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
1213 }
1214 if constexpr(kStoreLSE)
1215 {
1216 batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
1217 }
1218 if constexpr(kHasDropout)
1219 {
1220 batch_offset_randval =
1221 static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
1222 }
1223 batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1224
1225 // If cumulative seqlen pointers are provided, override per-batch effective lengths
     // (otherwise seqlen_q / seqlen_k keep the values passed in through kargs)
1226 if(kargs.cu_seqlen_q_ptr != nullptr)
1227 {
1228 kargs.seqlen_q =
1229 kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1230 }
1231 if(kargs.cu_seqlen_k_ptr != nullptr)
1232 {
1233 kargs.seqlen_k =
1234 kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1235 }
1236 }
1237
1238 // for simplicity, batch stride we just modify the pointer
     // Each base pointer = kargs pointer + head index * per-head stride + batch offset.
     // Q and O are indexed by i_nhead; K and V by i_nhead / nhead_ratio_qk, so consecutive
     // groups of nhead_ratio_qk query heads read the same K/V head.
1239 const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
1240 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
1241 batch_offset_q;
1242 const KDataType* k_ptr =
1243 reinterpret_cast<const KDataType*>(kargs.k_ptr) +
1244 static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
1245 batch_offset_k;
1246 const VDataType* v_ptr =
1247 reinterpret_cast<const VDataType*>(kargs.v_ptr) +
1248 static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
1249 batch_offset_v;
1250 ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
1251 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1252 batch_offset_o;
1253
1254 // Q/K/V DRAM and DRAM window
     // NOTE(review): several lines of each lambda below are missing from this rendered
     // listing (the listing numbers skip, e.g. 1256, 1260, 1265-1267): the statement that
     // opens each *_dram_naive view and the tile-length / padding arguments of
     // pad_tensor_view. Comments here describe only what the remaining lines show.
     // Q view: lengths {seqlen_q, hdim_q}, strides {stride_q, 1}; padded differently
     // depending on FmhaPipeline::kQLoadOnce.
1255 const auto q_dram = [&]() {
1257 q_ptr,
1258 make_tuple(kargs.seqlen_q, kargs.hdim_q),
1259 make_tuple(kargs.stride_q, 1),
1261 number<1>{});
1262 if constexpr(FmhaPipeline::kQLoadOnce)
1263 {
1264 return pad_tensor_view(q_dram_naive,
1268 }
1269 else
1270 {
1271 return pad_tensor_view(
1272 q_dram_naive,
1275 }
1276 }();
     // K view: lengths {seqlen_k, hdim_q}, strides {stride_k, 1}.
1277 const auto k_dram = [&]() {
1279 k_ptr,
1280 make_tuple(kargs.seqlen_k, kargs.hdim_q),
1281 make_tuple(kargs.stride_k, 1),
1283 number<1>{});
1284
     // kPadSeqLenK is only honoured when kUseAsyncCopy is set; otherwise the local flag is
     // false (its use is on a line missing from this listing).
1285 constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
1286 return pad_tensor_view(
1287 k_dram_naive,
1290 }();
     // V view: the row-major layout starts from {seqlen_k, hdim_v} and is passed through
     // transform_tensor_view (named v_dram_transposed); the other layout is viewed directly
     // as {hdim_v, seqlen_k}.
1291 const auto v_dram = [&]() {
1292 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1293 {
1295 v_ptr,
1296 make_tuple(kargs.seqlen_k, kargs.hdim_v),
1297 make_tuple(kargs.stride_v, 1),
1299 number<1>{});
1300
1301 const auto v_dram_transposed = transform_tensor_view(
1302 v_dram_naive,
1304 make_pass_through_transform(kargs.seqlen_k)),
1307
1308 constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
1309 return pad_tensor_view(
1310 v_dram_transposed,
1313 }
1314 else
1315 {
1317 v_ptr,
1318 make_tuple(kargs.hdim_v, kargs.seqlen_k),
1319 make_tuple(kargs.stride_v, 1),
1321 number<1>{});
1322
     // Same gating as kPadSeqLenK_ above, applied to the head-dimension padding flag.
1323 constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false;
1324 return pad_tensor_view(
1325 v_dram_naive,
1328 }
1329 }();
1330
1331 auto q_dram_window = make_tile_window(
1332 q_dram,
     // Window lengths depend on FmhaPipeline::kQLoadOnce; the two make_tuple expressions are
     // missing from this rendered listing (lines 1335-1336, 1338).
1333 [&]() {
1334 if constexpr(FmhaPipeline::kQLoadOnce)
1337 else
1339 }(),
     // Q window starts at this block's first query row.
1340 {i_m0, 0});
1341
     // K window starts at the beginning of the key sequence (window lengths on a line
     // missing from this listing).
1342 auto k_dram_window = make_tile_window(
1343 k_dram,
1345 {0, 0});
1346
     // V window starts at i_n1 along the first dimension of the v_dram view.
1347 auto v_dram_window = make_tile_window(
1348 v_dram,
1350 {i_n1, 0});
1353 const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
     // Builds the tile window over the bias tensor, or a null window when the (missing)
     // compile-time condition is false. i_nhead is re-captured by value as i_nhead_
     // (structured bindings cannot be captured directly before C++20; see the workaround
     // comment on position_encoding further down).
     // NOTE(review): the window-length expression (listing line 1355) and the
     // `if constexpr` line (1356) are missing from this rendering.
1354 constexpr auto bias_dram_window_lengths =
1357 {
1358 const BiasDataType* bias_ptr =
1359 reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
1360 static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
1361 batch_offset_bias;
1362
     // Bias view: lengths {seqlen_q, seqlen_k}, strides {stride_bias, 1}.
1363 const auto bias_dram = [&]() {
1364 const auto bias_dram_naive =
1366 bias_ptr,
1367 make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1368 make_tuple(kargs.stride_bias, 1),
1370 number<1>{});
1371
1372 return pad_tensor_view(bias_dram_naive,
1373 bias_dram_window_lengths,
1375 }();
1376
1377 return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
1378 }
1379 else
1380 {
1381 return make_null_tile_window(bias_dram_window_lengths);
1382 }
1383 }();
1384
1385 // lse
     // One-dimensional window of kM0 entries over the LSE output, starting at i_m0; a null
     // window when kStoreLSE is false.
1386 auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
1387 constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
1388 if constexpr(kStoreLSE)
1389 {
1390 LSEDataType* lse_ptr =
1391 reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
1392 static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse +
1393 batch_offset_lse;
1394
     // LSE view: length {seqlen_q}, unit stride; padded along seqlen_q per kPadSeqLenQ.
     // (The call that opens lse_dram_naive, listing line 1397, is missing here.)
1395 const auto lse_dram = [&]() {
1396 const auto lse_dram_naive =
1398 lse_ptr,
1399 make_tuple(kargs.seqlen_q),
1400 make_tuple(1),
1401 number<1>{},
1402 number<1>{});
1403
1404 return pad_tensor_view(
1405 lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
1406 }();
1407
1408 return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
1409 }
1410 else
1411 {
1412 return make_null_tile_window(lse_dram_window_lengths);
1413 }
1414 }();
1415
1416 auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
     // Dropout state object: a BlockDropout when kHasDropout, otherwise a NullBlockDropout.
1417 if constexpr(kHasDropout)
1418 {
     // Seed and offset are taken either as immediate values (.val) or read through
     // pointers (.ptr), selected at run time by is_drop_seed_offset_from_host.
1419 return BlockDropout{i_batch_,
1420 i_nhead_,
1421 kargs.num_head_q,
1422 kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
1423 : *kargs.drop_seed.ptr,
1424 kargs.is_drop_seed_offset_from_host
1425 ? kargs.drop_offset.val
1426 : *kargs.drop_offset.ptr,
1427 kargs.rp_undrop,
1428 kargs.p_undrop_in_uint8_t,
1429 kargs.is_store_randval};
1430 }
1431 else
1432 {
1433 return NullBlockDropout{};
     // NOTE(review): the ';' after the closing brace below is an empty statement.
1434 };
1435 }();
1436
     // Window over the random-value output tensor, starting at row i_m0; a null window when
     // kHasDropout is false.
     // NOTE(review): the window-length expression (listing line 1439) is missing from this
     // rendering.
1437 auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
1438 constexpr auto randval_dram_window_lengths =
1440 if constexpr(kHasDropout)
1441 {
1442 RandValOutputDataType* rand_val_ptr =
1443 reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
1444 static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
1445 batch_offset_randval;
1446
     // Randval view: lengths {seqlen_q, seqlen_k}, strides {stride_randval, 1}.
1447 const auto randval_dram = [&]() {
1448 const auto randval_dram_naive =
1450 rand_val_ptr,
1451 make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1452 make_tuple(kargs.stride_randval, 1),
1453 number<1>{},
1454 number<1>{});
1455
1456 return pad_tensor_view(randval_dram_naive,
1457 randval_dram_window_lengths,
1459 }();
1460
1461 return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
1462 }
1463 else
1464 {
1465 return make_null_tile_window(randval_dram_window_lengths);
1466 }
1467 }();
1468
1469 FmhaMask mask = [&]() {
     // With masking enabled the mask is built from the left/right window sizes and the
     // per-batch seqlen_q / seqlen_k (the factory call on listing line 1471 and its last
     // argument on 1476 are missing from this rendering); without masking it only carries
     // the two lengths.
1470 if constexpr(kHasMask)
1472 kargs.window_size_left,
1473 kargs.window_size_right,
1474 kargs.seqlen_q,
1475 kargs.seqlen_k,
1477 else
1478 return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1479 }();
1480
1481 // WA i_batch capture structure binding before c++20
1482 auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
     // NOTE(review): the `if constexpr` selecting this branch (listing line 1483) is missing
     // from this rendering; the body reads an ALiBi slope, so presumably it tests for the
     // ALiBi bias mode -- confirm.
1484 {
1485 // data loading, shared by entire wg
1486 // TODO: how to use s_read?
     // One slope per (batch, head): element i_batch_ * alibi_slope_stride + i_nhead_.
1487 SaccDataType slope =
1488 *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1489 i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1490#if CK_TILE_FMHA_FWD_FAST_EXP2
     // Pre-scale by log2(e) when the fast exp2 path is compiled in.
1491 slope *= ck_tile::log2e_v<>;
1492#endif
1493 if constexpr(kHasMask)
1494 {
1496 kargs.window_size_left,
1497 kargs.window_size_right,
1498 kargs.seqlen_q,
1499 kargs.seqlen_k,
1500 kargs.mask_type);
1501 }
1502 else
1503 {
1505 slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1506 }
1507 }
1508 else
1509 {
     // (the return statement of this branch, listing line 1510, is missing from this
     // rendering)
1511 }
1512 }();
1513
1514 AttentionVariant variant;
     // Parameter bundle handed to the pipeline together with `variant`: the soft-cap
     // flavour additionally carries logits_soft_cap and its reciprocal.
     // (The type name on listing line 1518 is missing from this rendering.)
1515 const auto variant_params = [&] {
1516 if constexpr(kHasLogitsSoftCap)
1517 {
1519 mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1520 }
1521 else
1522 {
1523 return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
1524 }
1525 }();
1526
     // Batch index, query head index and the matching K/V head index.
1527 BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
1528
     // Runs the attention pipeline and yields the output accumulator tile. The FP8
     // static-quant path uses the overload that takes per-tensor element functions; the
     // other path uses the shorter overload.
1529 auto o_acc_tile = [&]() {
1530 if constexpr(kDoFp8StaticQuant)
1531 {
     // Output element function: for fp8_t output the scale by scale_o is wrapped in a
     // further functor (listing line 1534, missing from this rendering); otherwise it is
     // the plain scale.
1532 auto o_acc_element_func = [&]() {
1533 if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
1535 ck_tile::scales{kargs.scale_o});
1536 else
1537 return ck_tile::scales{kargs.scale_o};
1538 }();
1539 return FmhaPipeline{}(q_dram_window,
1540 identity{}, // q_element_func
1541 k_dram_window,
1542 identity{}, // k_element_func
1543 v_dram_window,
1544 identity{}, // v_element_func
1545 bias_dram_window,
1546 identity{}, // bias_element_func
1547 randval_dram_window,
1548 lse_dram_window,
1549 identity{}, // lse_element_func
1550 identity{}, // s_acc_element_func
1551 scales{kargs.scale_p}, // p_compute_element_func
1552 o_acc_element_func, // o_acc_element_func
1553 mask,
1554 position_encoding,
1555 kargs.scale_s,
1556 variant,
1557 variant_params,
1558 block_indices,
1559 smem_ptr,
1560 dropout);
1561 }
1562 else
1563 {
1564 return FmhaPipeline{}(q_dram_window,
1565 k_dram_window,
1566 v_dram_window,
1567 bias_dram_window,
1568 randval_dram_window,
1569 lse_dram_window,
1570 mask,
1571 position_encoding,
1572 kargs.scale_s,
1573 variant,
1574 variant_params,
1575 block_indices,
1576 smem_ptr,
1577 dropout);
1578 }
1579 }();
1580
1581 // O DRAM and O DRAM window
     // Output view: lengths {seqlen_q, hdim_v}, strides {stride_o, 1}.
     // NOTE(review): the call opening o_dram_naive (listing line 1583) and the padding
     // arguments (1592-1593) are missing from this rendering.
1582 auto o_dram = [&]() {
1584 o_ptr,
1585 make_tuple(kargs.seqlen_q, kargs.hdim_v),
1586 make_tuple(kargs.stride_o, 1),
1588 number<1>{});
1589
1590 return pad_tensor_view(
1591 o_dram_naive,
1594 }();
1595
     // Output window origin: row i_m0, column i_n1 (window lengths on listing line 1598 are
     // missing from this rendering).
1596 auto o_dram_window = make_tile_window(
1597 o_dram,
1599 {i_m0, i_n1});
1600
     // Hand the accumulator tile to the epilogue together with the output window.
1601 EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
1602 }
1603 else
1604 {
     // Path taken when kPipelineName == "qr_async_trload".
1605 // TODO: Refine the logic here.
1606 // In Decode case
1607 // 1. we don't expect KV data reused by different ThreadGroups, bypass the cache
1608 // 2. limit the LDS usage, as we want higher occupancy
1609 // In Prefill case
1610 // 1. we expect KV data reused by different ThreadGroups, use cache
1611 // 2. use more LDS, as we want better memory latency hiding
1612 // If SplitKV off, we don't expect Q data reused by different ThreadGroups, bypass the
1613 // cache
     // Prefill vs. decode is decided purely at compile time from the M0 tile size.
1614 constexpr bool PrefillCase = FmhaPipeline::kM0 >= 128;
1615 // divide problem
1616 const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
1617
     // NOTE(review): unlike the first branch of run_, these two values are not passed
     // through amd_wave_read_first_lane -- confirm whether that difference is intended.
1618 const index_t i_m0 = i_tile_m * FmhaPipeline::kM0;
1619 const index_t i_n1 = i_tile_n * FmhaPipeline::kN1;
1620
     // Per-batch element offsets added to the base pointers below. This branch declares no
     // randval offset.
1621 long_index_t batch_offset_q = 0;
1622 long_index_t batch_offset_k = 0; // unused for paged-kvcache
1623 long_index_t batch_offset_v = 0; // unused for paged-kvcache
1624 long_index_t batch_offset_bias = 0;
1625 long_index_t batch_offset_lse = 0;
1626 long_index_t batch_offset_o = 0;
1627 // index_t kv_l2p_offset =
1628 // 0; // logical-to-physical offset of seqlen_k coordinate. only used for
1629 // paged-kvcache
1630
1631 if constexpr(kIsGroupMode)
1632 {
1633 // get starting offset for each batch - use seqstart_q_ptr/seqstart_k_ptr for
1634 // physical starts
1635 const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1636 const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
1637
     // Offsets are "start row * row stride" of the corresponding tensor.
1638 batch_offset_q = query_start * kargs.stride_q;
1639 batch_offset_k = key_start * kargs.stride_k;
1640 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1641 {
1642 batch_offset_v = key_start * kargs.stride_v;
1643 }
1644 else
1645 {
1646 // col-major V: offset along seqlen dimension is scalar index
1647 batch_offset_v = key_start;
1648 }
     // NOTE(review): the condition guarding the next braces (listing line 1649) is missing
     // from this rendering; presumably a check that a bias tensor is in use -- confirm.
1650 {
1651 batch_offset_bias = query_start * kargs.stride_bias;
1652 }
1653
1654 // LSE layout is [nhead, total_seqlen] following the physical layout for Q/O
     // NOTE(review): assigned unconditionally here, whereas the batch-mode branch below
     // guards the LSE offset with kStoreLSE; the value is only consumed under kStoreLSE.
1655 batch_offset_lse = query_start;
1656 batch_offset_o = query_start * kargs.stride_o;
1657
1658 // get real # queries & # keys under group mode
     // Priority: seqlen_q_ptr, then cu_seqlen_q_ptr, then consecutive seqstart_q_ptr entries.
1659 if(kargs.seqlen_q_ptr != nullptr)
1660 {
1661 kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
1662 }
1663 else if(kargs.cu_seqlen_q_ptr != nullptr)
1664 {
1665 kargs.seqlen_q =
1666 kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1667 }
1668 else
1669 {
1670 kargs.seqlen_q =
1671 kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
1672 }
1673
1674 // # of required blocks is different in each groups, terminate unnecessary blocks
1675 // earlier
1676 if(kargs.seqlen_q <= i_m0)
1677 {
1678 return;
1679 }
1680
     // Same priority order for the key length.
1681 if(kargs.seqlen_k_ptr != nullptr)
1682 {
1683 kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
1684 }
1685 else if(kargs.cu_seqlen_k_ptr != nullptr)
1686 {
1687 kargs.seqlen_k =
1688 kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1689 }
1690 else
1691 {
1692 kargs.seqlen_k =
1693 kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
1694 }
1695 }
1696 else
1697 {
     // Batch mode: fixed per-batch strides; i_batch is widened to long_index_t before the
     // multiply.
1698 batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
1699 batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
1700 batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
1701 if constexpr(kStoreLSE)
1702 {
1703 batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
1704 }
1705 batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1706
     // NOTE(review): the condition guarding the next braces (listing line 1707) is missing
     // from this rendering; presumably the same bias check as above.
1708 {
1709 batch_offset_bias =
1710 static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
1711 }
1712
1713 // If cumulative seqlen pointers are provided, override per-batch effective lengths
1714 if(kargs.cu_seqlen_q_ptr != nullptr)
1715 {
1716 kargs.seqlen_q =
1717 kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1718 }
1719 if(kargs.cu_seqlen_k_ptr != nullptr)
1720 {
1721 kargs.seqlen_k =
1722 kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1723 }
1724 }
1725
1726 // for simplicity, batch stride we just modify the pointer
     // K/V head index: groups of nhead_ratio_qk consecutive query heads map to one K/V head.
1727 const index_t i_nhead_k = i_nhead / kargs.nhead_ratio_qk;
1728
     // Each base pointer = kargs pointer + head index * per-head stride + batch offset.
1729 const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
1730 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
1731 batch_offset_q;
1732 const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
1733 static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k +
1734 batch_offset_k;
1735 const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) +
1736 static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v +
1737 batch_offset_v;
1738
1739 ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
1740 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1741 batch_offset_o;
1742
1743 // Q/K/V DRAM and DRAM window
     // NOTE(review): many argument lines of the transform_tensor_view / pad_tensor_view
     // calls in this lambda are missing from this rendered listing (the listing numbers
     // skip), so the exact transforms cannot be read off here.
1744 const auto q_dram = [&] {
     // Base Q view: lengths {seqlen_q, hdim_q}, strides {stride_q, 1}.
1745 const auto q_dram_naive = [&] {
1746 {
1750 q_ptr,
1751 make_tuple(kargs.seqlen_q, kargs.hdim_q),
1752 make_tuple(kargs.stride_q, 1),
1754 number<1>{});
1755 }
1756 }();
1757
     // When Q is loaded once, the padded view is further re-laid-out through a chain of
     // transform_tensor_view steps; otherwise only padding is applied.
1758 if constexpr(FmhaPipeline::kQLoadOnce)
1759 {
1760 const auto seqlen_q = kargs.seqlen_q;
1761 const auto q_dram_pad = pad_tensor_view(
1762 q_dram_naive,
     // The guarded variant below is compiled only when CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD is
     // non-zero; the macro is defined as 0 near the top of this file.
1765#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1766 constexpr index_t LDSLayerSize = 256 / sizeof(QDataType);
1767 constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
1768
     // Folding variant: rows are split as {seqlen_q / XorLengthFold, XorLengthFold} and
     // taken through six chained views (unmerged, merged, unmerged_xor, permuted, tmp,
     // final).
1769 if constexpr(XorLengthFold > 1)
1770 {
1771 const auto q_dram_unmerged = transform_tensor_view(
1772 q_dram_pad,
1773 make_tuple(
1775 make_tuple(seqlen_q / XorLengthFold, XorLengthFold)),
1779
1780 const auto q_dram_merged = transform_tensor_view(
1781 q_dram_unmerged,
1782 make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold),
1784 XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
1787
1788 const auto q_dram_unmerged_xor = transform_tensor_view(
1789 q_dram_merged,
1790 make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold),
1796
1797 const auto q_dram_permuted = transform_tensor_view(
1798 q_dram_unmerged_xor,
1799 make_tuple(
1801 make_tuple(seqlen_q / XorLengthFold,
1806
1807 const auto q_dram_tmp = transform_tensor_view(
1808 q_dram_permuted,
1809 make_tuple(
1810 make_pass_through_transform(seqlen_q / XorLengthFold),
1813 number<FmhaPipeline::kQKHeaddim /
1814 FmhaPipeline::kAlignmentQ>{})),
1818
1819 return transform_tensor_view(
1820 q_dram_tmp,
1821 make_tuple(
1823 make_tuple(seqlen_q / XorLengthFold, number<XorLengthFold>{})),
1829 }
1830 else
1831#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1832 {
     // Default variant: three chained views (unmerged, permuted, final) over the padded
     // Q view.
1833 const auto q_dram_unmerged = transform_tensor_view(
1834 q_dram_pad,
1835 make_tuple(
1842
1843 const auto q_dram_permuted = transform_tensor_view(
1844 q_dram_unmerged,
1845 make_tuple(
1847 number<FmhaPipeline::kQKHeaddim /
1848 FmhaPipeline::kAlignmentQ>{})),
1852
1853 return transform_tensor_view(
1854 q_dram_permuted,
1855 make_tuple(
1862 }
1863 }
1864 else
1865 {
1866 return pad_tensor_view(
1867 q_dram_naive,
1870 }
1871 }();
1872
1873 const auto make_k_dram = [&](const KDataType* data, index_t height) {
     // Builds the K view for `height` rows starting at `data`: lengths {height, hdim_q},
     // strides {stride_k, 1}, then padding and a chain of transform_tensor_view steps.
     // NOTE(review): the call opening k_dram_naive (listing line 1874) and most transform
     // arguments are missing from this rendered listing.
1875 data, // will update this pointer if using paged-kvcache
1876 make_tuple(height, kargs.hdim_q),
1877 make_tuple(kargs.stride_k, 1),
1879 number<1>{});
1880
1881 const auto k_dram_pad = pad_tensor_view(
1882 k_dram_naive,
1885
     // Tile extent along the head dimension: the full kQKHeaddim when K is loaded once,
     // otherwise kK0.
1886 constexpr auto kDramTileK =
1887 FmhaPipeline::kKLoadOnce ? FmhaPipeline::kQKHeaddim : FmhaPipeline::kK0;
1888
     // The guarded variant below is compiled only when CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD is
     // non-zero; the macro is defined as 0 near the top of this file.
1889#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1890 constexpr index_t LDSLayerSize = 256 / sizeof(KDataType);
1891 constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
1892
1893 if constexpr(XorLengthFold > 1)
1894 {
1895 const auto k_dram_unmerged = transform_tensor_view(
1896 k_dram_pad,
1898 make_tuple(height / XorLengthFold, XorLengthFold)),
1902
1903 const auto k_dram_merged = transform_tensor_view(
1904 k_dram_unmerged,
1905 make_tuple(make_pass_through_transform(height / XorLengthFold),
1907 XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
1910
1911 const auto k_dram_unmerged_xor = transform_tensor_view(
1912 k_dram_merged,
1913 make_tuple(make_pass_through_transform(height / XorLengthFold),
1919
1920 const auto k_dram_permuted = transform_tensor_view(
1921 k_dram_unmerged_xor,
1922 make_tuple(
1924 make_tuple(height / XorLengthFold,
1929
1930 const auto k_dram_tmp = transform_tensor_view(
1931 k_dram_permuted,
1932 make_tuple(
1933 make_pass_through_transform(height / XorLengthFold),
1936 number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{})),
1940
1941 return transform_tensor_view(
1942 k_dram_tmp,
1943 make_tuple(
1945 make_tuple(height / XorLengthFold, number<XorLengthFold>{})),
1951 }
1952 else
1953#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1954 {
     // Default variant: three chained views (unmerged, permuted, final); the head dimension
     // is split using kDramTileK and kAlignmentK.
1955 const auto k_dram_unmerged = transform_tensor_view(
1956 k_dram_pad,
1959 make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
1960 FmhaPipeline::kAlignmentK>{},
1961 number<kDramTileK / FmhaPipeline::kAlignmentK>{},
1965
1966 const auto k_dram_permuted = transform_tensor_view(
1967 k_dram_unmerged,
1968 make_tuple(
1972 number<FmhaPipeline::kQKHeaddim / kDramTileK /
1973 FmhaPipeline::kAlignmentK>{}),
1977
1978 return transform_tensor_view(
1979 k_dram_permuted,
1982 make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
1983 FmhaPipeline::kAlignmentK>{},
1984 number<kDramTileK / FmhaPipeline::kAlignmentK>{},
1988 }
1989 };
     // K view for this batch/head: all seqlen_k rows starting at k_ptr.
1990 const auto k_dram = [&]() {
1991 {
1992 return make_k_dram(k_ptr, kargs.seqlen_k);
1993 }
1994 }();
1995
1996 const auto make_v_dram = [&](const VDataType* data, index_t length) {
     // Builds the V view for `length` rows starting at `data`: lengths {length, hdim_v},
     // strides {stride_v, 1}, then padding and a chain of transform_tensor_view steps.
     // NOTE(review): the call opening v_dram_naive (listing line 1997) and most transform
     // arguments are missing from this rendered listing.
1998 data, // will update this pointer if using paged-kvcache
1999 make_tuple(length, kargs.hdim_v),
2000 make_tuple(kargs.stride_v, 1),
2002 number<1>{});
2003
2004 // TODO: Add kVHeadDim
     // Group size taken from element 0 of the Gemm1 warp tile shape.
2005 constexpr index_t XorGroupSize =
2006 FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
2007
2008 const auto v_dram_pad = pad_tensor_view(
2009 v_dram_naive,
2012
     // The guarded variant below is compiled only when CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD is
     // non-zero; the macro is defined as 0 near the top of this file.
     // NOTE(review): this variant sizes the fold with kQKHeaddim although the view is over
     // hdim_v; see the TODO above.
2013#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
2014 constexpr index_t LDSLayerSize = 256 / sizeof(VDataType);
2015 constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
2016
2017 if constexpr(XorLengthFold > 1)
2018 {
2019 const auto v_dram_unmerged = transform_tensor_view(
2020 v_dram_pad,
2022 make_tuple(length / XorLengthFold, XorLengthFold)),
2026
2027 const auto v_dram_merged = transform_tensor_view(
2028 v_dram_unmerged,
2029 make_tuple(make_pass_through_transform(length / XorLengthFold),
2031 XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
2034
2035 const auto v_dram_unmerged_xor = transform_tensor_view(
2036 v_dram_merged,
2037 make_tuple(
2038 make_pass_through_transform(length / XorLengthFold),
2043
2044 const auto v_dram_permuted = transform_tensor_view(
2045 v_dram_unmerged_xor,
2046 make_tuple(
2047 make_xor_transform(make_tuple(length / XorLengthFold,
2052
2053 const auto v_dram_tmp = transform_tensor_view(
2054 v_dram_permuted,
2055 make_tuple(make_pass_through_transform(length / XorLengthFold),
2058 number<FmhaPipeline::kQKHeaddim / XorGroupSize>{})),
2062
2063 return transform_tensor_view(
2064 v_dram_tmp,
2066 make_tuple(length / XorLengthFold, number<XorLengthFold>{})),
2072 }
2073 else
2074#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
2075 {
     // Default variant: three chained views (unmerged, permuted, final) over the padded
     // V view.
2076 const auto v_dram_unmerged = transform_tensor_view(
2077 v_dram_pad,
2084
2085 const auto v_dram_permuted = transform_tensor_view(
2086 v_dram_unmerged,
2092
2093 return transform_tensor_view(
2094 v_dram_permuted,
2101 }
2102 };
2103
     // V view for this batch/head: all seqlen_k rows starting at v_ptr.
2104 const auto v_dram = [&]() {
2105 {
2106 return make_v_dram(v_ptr, kargs.seqlen_k);
2107 }
2108 }();
2109
2110 auto q_dram_window = make_tile_window(
2111 q_dram,
     // Window lengths depend on FmhaPipeline::kQLoadOnce; the two make_tuple expressions are
     // missing from this rendered listing (lines 2114-2115, 2117).
2112 [&]() {
2113 if constexpr(FmhaPipeline::kQLoadOnce)
2116 else
2118 }(),
     // Q window starts at this block's first query row.
2119 {i_m0, 0});
2120
     // K window starts at the beginning of the key sequence (window lengths on a line
     // missing from this listing).
2121 auto k_dram_window = make_tile_window(
2122 k_dram,
2124 {0, 0});
2125
     // NOTE(review): the V window origin here is {0, 0}, whereas the first branch of run_
     // uses {i_n1, 0}; the V view is also built differently in this branch -- confirm the
     // difference is intended.
2126 auto v_dram_window = make_tile_window(
2127 v_dram,
2129 {0, 0});
2130
2133 const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
     // Builds the tile window over the bias tensor, or a null window when the (missing)
     // compile-time condition is false. i_nhead is re-captured by value as i_nhead_
     // (structured bindings cannot be captured directly before C++20).
     // NOTE(review): the window-length expression (listing line 2135) and the
     // `if constexpr` line (2136) are missing from this rendering.
2134 constexpr auto bias_dram_window_lengths =
2137 {
2138 const BiasDataType* bias_ptr =
2139 reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
2140 static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
2141 batch_offset_bias;
2142
     // Bias view: lengths {seqlen_q, seqlen_k}, strides {stride_bias, 1}.
2143 const auto bias_dram = [&]() {
2144 const auto bias_dram_naive =
2146 bias_ptr,
2147 make_tuple(kargs.seqlen_q, kargs.seqlen_k),
2148 make_tuple(kargs.stride_bias, 1),
2150 number<1>{});
2151
2152 return pad_tensor_view(bias_dram_naive,
2153 bias_dram_window_lengths,
2155 }();
2156
2157 return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
2158 }
2159 else
2160 {
2161 return make_null_tile_window(bias_dram_window_lengths);
2162 }
2163 }();
2164
2165 // lse acc
     // One-dimensional window of kM0 entries over the LSE output, starting at i_m0; a null
     // window when kStoreLSE is false.
2166 auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
2167 constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
2168 if constexpr(kStoreLSE)
2169 {
2170 LSEDataType* lse_ptr =
2171 reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
2172 static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse +
2173 batch_offset_lse;
2174
     // LSE view: length {seqlen_q}, unit stride; padded along seqlen_q per kPadSeqLenQ.
     // (The call that opens the naive view, listing line 2178, is missing here.)
2175 const auto lse_dram = [&] {
2176 const auto lse_dram_naive = [&] {
2177 {
2179 lse_ptr,
2180 make_tuple(kargs.seqlen_q),
2181 make_tuple(1),
2182 number<1>{},
2183 number<1>{});
2184 }
2185 }();
2186 return pad_tensor_view(
2187 lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
2188 }();
2189
2190 return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
2191 }
2192 else
2193 {
2194 return make_null_tile_window(lse_dram_window_lengths);
2195 }
2196 }();
2197
2198 FmhaMask mask = [&]() {
     // With masking enabled the mask is built from the left/right window sizes and the
     // per-batch seqlen_q / seqlen_k (the factory call on listing line 2200 and its last
     // argument on 2205 are missing from this rendering); without masking it only carries
     // the two lengths.
2199 if constexpr(kHasMask)
2201 kargs.window_size_left,
2202 kargs.window_size_right,
2203 kargs.seqlen_q,
2204 kargs.seqlen_k,
2206 else
2207 return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
2208 }();
2209
2210 // WA i_batch capture structure binding before c++20
2211 auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
     // NOTE(review): the `if constexpr` selecting this branch (listing line 2212) is missing
     // from this rendering; the body reads an ALiBi slope, so presumably it tests for the
     // ALiBi bias mode -- confirm.
2213 {
2214 // data loading, shared by entire wg
2215 // TODO: how to use s_read?
     // One slope per (batch, head): element i_batch_ * alibi_slope_stride + i_nhead_.
2216 SaccDataType slope =
2217 *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
2218 i_batch_ * kargs.alibi_slope_stride + i_nhead_);
2219#if CK_TILE_FMHA_FWD_FAST_EXP2
     // Pre-scale by log2(e) when the fast exp2 path is compiled in.
2220 slope *= ck_tile::log2e_v<>;
2221#endif
2222 if constexpr(kHasMask)
2223 {
2225 slope,
2226 kargs.window_size_left,
2227 kargs.window_size_right,
2228 kargs.seqlen_q,
2229 kargs.seqlen_k,
2230 kargs.mask_type);
2231 }
2232 else
2233 {
2235 slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
2236 }
2237 }
2238 else
2239 {
     // (the return statement of this branch, listing line 2240, is missing from this
     // rendering)
2241 }
2242 }();
2243
2244 auto o_acc_tile = [&]() {
     // Runs the attention pipeline and yields the output accumulator tile. Neither call in
     // this branch passes a dropout object, a randval window, or variant parameters.
2245 if constexpr(PrefillCase)
2246 {
2247 // allocate double lds
2248 // add __restrict__ here to avoid aliasing
     // NOTE(review): no __restrict__ qualifier appears on the four declarations below --
     // confirm whether the comment above is a pending TODO.
     // Four separate __shared__ buffers: two sized by Policy::GetSmemSizeK<Problem, true>()
     // and two by Policy::GetSmemSizeV<Problem>().
2249 __shared__ char smem_ptrk0
2250 [FmhaPipeline::Policy::template GetSmemSizeK<typename FmhaPipeline::Problem,
2251 true>()];
2252 __shared__ char smem_ptrk1
2253 [FmhaPipeline::Policy::template GetSmemSizeK<typename FmhaPipeline::Problem,
2254 true>()];
2255 __shared__ char smem_ptrv0[FmhaPipeline::Policy::template GetSmemSizeV<
2256 typename FmhaPipeline::Problem>()];
2257 __shared__ char smem_ptrv1[FmhaPipeline::Policy::template GetSmemSizeV<
2258 typename FmhaPipeline::Problem>()];
2259
2260 return FmhaPipeline{}(q_dram_window,
2261 k_dram_window,
2262 v_dram_window,
2263 bias_dram_window,
2264 lse_dram_window,
2265 mask,
2266 position_encoding,
2267 kargs.scale_s,
2268 smem_ptrk0,
2269 smem_ptrk1,
2270 smem_ptrv0,
2271 smem_ptrv1);
2272 }
2273 else
2274 {
     // Non-prefill case: a single __shared__ buffer sized by GetSmemSize().
2275 __shared__ char smem_ptr[GetSmemSize()];
2276 return FmhaPipeline{}(q_dram_window,
2277 k_dram_window,
2278 v_dram_window,
2279 bias_dram_window,
2280 lse_dram_window,
2281 mask,
2282 position_encoding,
2283 kargs.scale_s,
2284 smem_ptr);
2285 }
2286 }();
2287
2288 // Oacc DRAM and Oacc DRAM window
     // Output view: lengths {seqlen_q, hdim_v}, strides {stride_o, 1}.
     // NOTE(review): the call opening the naive view (listing line 2292) and the padding
     // arguments (2303-2304) are missing from this rendering.
2289 auto o_dram = [&] {
2290 const auto o_dram_naive = [&] {
2291 {
2293 o_ptr,
2294 make_tuple(kargs.seqlen_q, kargs.hdim_v),
2295 make_tuple(kargs.stride_o, 1),
2297 number<1>{});
2298 }
2299 }();
2300
2301 return pad_tensor_view(
2302 o_dram_naive,
2305 }();
2306
     // Output window origin: row i_m0, column i_n1 (window lengths on listing line 2309 are
     // missing from this rendering).
2307 auto o_dram_window = make_tile_window(
2308 o_dram,
2310 {i_m0, i_n1});
2311
     // Hand the accumulator tile to the epilogue together with the output window.
2312 EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
2313 }
2314 }
2315};
2316
2317} // namespace ck_tile
#define _TS_
#define _SS_
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_generic_attention_mask_from_lr_window(index_t left_size, index_t right_size, index_t y_total, index_t x_total, bool is_top_left=true)
Definition block_masking.hpp:632
@ set
Definition arch.hpp:57
@ ALIBI
Definition block_attention_bias_enum.hpp:15
@ NO_BIAS
Definition block_attention_bias_enum.hpp:13
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
bfloat16_t bf16_t
Definition bfloat16.hpp:113
_Float16 fp16_t
Definition half.hpp:110
@ SYSTEM_NT1
Definition tile/core/arch/amd_buffer_addressing.hpp:1419
_BitInt(8) fp8_t
Definition float8.hpp:204
constexpr T log2e_v
Definition tile/core/numeric/math.hpp:488
CK_TILE_HOST_DEVICE constexpr auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition coordinate_transform.hpp:1622
int64_t long_index_t
Definition integer.hpp:11
CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope, index_t window_left_size, index_t window_right_size, index_t y_total, index_t x_total, GenericAttentionMaskEnum mask_enum)
Definition block_position_encoding.hpp:148
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1662
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
int32_t int32_t
Definition integer.hpp:10
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths &window_lengths)
Definition null_tile_window.hpp:66
unsigned _BitInt(8) bf8_t
Definition float8.hpp:206
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1609
@ global
Definition arch.hpp:48
GenericAttentionMaskEnum
Definition block_masking.hpp:11
@ MASK_FROM_TOP_LEFT
Definition block_masking.hpp:15
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
int32_t index_t
Definition integer.hpp:9
@ FROM_BOTTOM_RIGHT
Definition block_position_encoding.hpp:43
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_view.hpp:511
float fp32_t
Definition pk_fp4.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST bool is_wave32()
Definition arch.hpp:72
unsigned char uint8_t
Definition stdint.h:124
unsigned __int64 uint64_t
Definition stdint.h:136
Definition block_position_encoding.hpp:48
Definition block_attention_bias_enum.hpp:19
Definition block_dropout.hpp:53
Definition block_position_encoding.hpp:137
Definition fmha_fwd_kernel.hpp:330
ck_tile::index_t kv_head_idx
Definition fmha_fwd_kernel.hpp:333
ck_tile::index_t batch_idx
Definition fmha_fwd_kernel.hpp:331
ck_tile::index_t qo_head_idx
Definition fmha_fwd_kernel.hpp:332
Definition fmha_fwd_kernel.hpp:194
ck_tile::index_t alibi_slope_stride
Definition fmha_fwd_kernel.hpp:197
const void * alibi_slope_ptr
Definition fmha_fwd_kernel.hpp:196
Definition fmha_fwd_kernel.hpp:189
ck_tile::index_t batch_stride_bias
Definition fmha_fwd_kernel.hpp:190
ck_tile::index_t batch_stride_randval
Definition fmha_fwd_kernel.hpp:271
Definition fmha_fwd_kernel.hpp:291
ck_tile::index_t batch_stride_o
Definition fmha_fwd_kernel.hpp:295
const int32_t * cu_seqlen_k_ptr
Definition fmha_fwd_kernel.hpp:300
ck_tile::index_t batch_stride_q
Definition fmha_fwd_kernel.hpp:292
ck_tile::index_t batch_stride_k
Definition fmha_fwd_kernel.hpp:293
const int32_t * cu_seqlen_q_ptr
Definition fmha_fwd_kernel.hpp:299
ck_tile::index_t batch_stride_v
Definition fmha_fwd_kernel.hpp:294
Definition fmha_fwd_kernel.hpp:182
const void * bias_ptr
Definition fmha_fwd_kernel.hpp:183
ck_tile::index_t stride_bias
Definition fmha_fwd_kernel.hpp:184
ck_tile::index_t nhead_stride_bias
Definition fmha_fwd_kernel.hpp:185
Definition fmha_fwd_kernel.hpp:235
void init_dropout(float p_drop, const uint64_t *seed_ptr, const uint64_t *offset_ptr)
Definition fmha_fwd_kernel.hpp:248
float rp_undrop
Definition fmha_fwd_kernel.hpp:260
ck_tile::index_t stride_randval
Definition fmha_fwd_kernel.hpp:265
ck_tile::index_t nhead_stride_randval
Definition fmha_fwd_kernel.hpp:266
void * rand_val_ptr
Definition fmha_fwd_kernel.hpp:263
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
Definition fmha_fwd_kernel.hpp:236
bool is_store_randval
Definition fmha_fwd_kernel.hpp:262
uint8_t p_undrop_in_uint8_t
Definition fmha_fwd_kernel.hpp:261
Definition fmha_fwd_kernel.hpp:131
ck_tile::index_t nhead_stride_k
Definition fmha_fwd_kernel.hpp:154
float scale_s
Definition fmha_fwd_kernel.hpp:146
ck_tile::index_t seqlen_k
Definition fmha_fwd_kernel.hpp:138
ck_tile::index_t nhead_stride_o
Definition fmha_fwd_kernel.hpp:156
ck_tile::index_t nhead_ratio_qk
Definition fmha_fwd_kernel.hpp:145
ck_tile::index_t num_head_q
Definition fmha_fwd_kernel.hpp:142
ck_tile::index_t hdim_q
Definition fmha_fwd_kernel.hpp:139
const void * v_ptr
Definition fmha_fwd_kernel.hpp:134
void * o_ptr
Definition fmha_fwd_kernel.hpp:135
const void * k_ptr
Definition fmha_fwd_kernel.hpp:133
ck_tile::index_t nhead_stride_q
Definition fmha_fwd_kernel.hpp:153
ck_tile::index_t stride_k
Definition fmha_fwd_kernel.hpp:149
ck_tile::index_t stride_o
Definition fmha_fwd_kernel.hpp:151
ck_tile::index_t stride_v
Definition fmha_fwd_kernel.hpp:150
ck_tile::index_t hdim_v
Definition fmha_fwd_kernel.hpp:140
ck_tile::index_t nhead_stride_v
Definition fmha_fwd_kernel.hpp:155
const void * q_ptr
Definition fmha_fwd_kernel.hpp:132
ck_tile::index_t seqlen_q
Definition fmha_fwd_kernel.hpp:137
ck_tile::index_t stride_q
Definition fmha_fwd_kernel.hpp:148
Definition fmha_fwd_kernel.hpp:214
ck_tile::index_t batch_stride_lse
Definition fmha_fwd_kernel.hpp:217
void * lse_ptr
Definition fmha_fwd_kernel.hpp:215
ck_tile::index_t nhead_stride_lse
Definition fmha_fwd_kernel.hpp:216
Definition fmha_fwd_kernel.hpp:221
bool is_drop_seed_offset_from_host
Definition fmha_fwd_kernel.hpp:231
ValueOrPointer< uint64_t > drop_seed
Definition fmha_fwd_kernel.hpp:229
ValueOrPointer< uint64_t > drop_offset
Definition fmha_fwd_kernel.hpp:230
Definition fmha_fwd_kernel.hpp:124
Definition fmha_fwd_kernel.hpp:208
float scale_o
Definition fmha_fwd_kernel.hpp:210
float scale_p
Definition fmha_fwd_kernel.hpp:209
Definition fmha_fwd_kernel.hpp:316
const int32_t * seqlen_q_ptr
Definition fmha_fwd_kernel.hpp:319
const int32_t * seqstart_q_ptr
Definition fmha_fwd_kernel.hpp:317
const int32_t * seqlen_k_ptr
Definition fmha_fwd_kernel.hpp:320
const int32_t * cu_seqlen_k_ptr
Definition fmha_fwd_kernel.hpp:324
const int32_t * cu_seqlen_q_ptr
Definition fmha_fwd_kernel.hpp:323
const int32_t * seqstart_k_ptr
Definition fmha_fwd_kernel.hpp:318
float logits_soft_cap
Definition fmha_fwd_kernel.hpp:177
float logits_soft_cap_rcp
Definition fmha_fwd_kernel.hpp:178
void init_logits_soft_cap(float logits_soft_cap_)
Definition fmha_fwd_kernel.hpp:163
Definition fmha_fwd_kernel.hpp:201
ck_tile::GenericAttentionMaskEnum mask_type
Definition fmha_fwd_kernel.hpp:204
ck_tile::index_t window_size_right
Definition fmha_fwd_kernel.hpp:203
ck_tile::index_t window_size_left
Definition fmha_fwd_kernel.hpp:203
Definition fmha_fwd_kernel.hpp:275
ck_tile::index_t min_seqlen_q
Definition fmha_fwd_kernel.hpp:276
static constexpr const char * name
Definition fmha_fwd_kernel.hpp:78
static constexpr const char * name
Definition fmha_fwd_kernel.hpp:80
static constexpr const char * name
Definition fmha_fwd_kernel.hpp:77
static constexpr const char * name
Definition fmha_fwd_kernel.hpp:81
static constexpr const char * name
Definition fmha_fwd_kernel.hpp:82
static constexpr const char * name
Definition fmha_fwd_kernel.hpp:79
static constexpr const char * name
Definition fmha_fwd_kernel.hpp:76
Definition fmha_fwd_kernel.hpp:75
Definition fmha_fwd_kernel.hpp:27
static constexpr bool kHasDropout
Definition fmha_fwd_kernel.hpp:56
static CK_TILE_HOST std::string GetName()
Definition fmha_fwd_kernel.hpp:85
ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType > RandValOutputDataType
Definition fmha_fwd_kernel.hpp:40
static constexpr bool kIsAvailable
Definition fmha_fwd_kernel.hpp:70
static constexpr bool kDoFp8StaticQuant
Definition fmha_fwd_kernel.hpp:57
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< const void *, const void * > &drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
Definition fmha_fwd_kernel.hpp:579
static constexpr bool kStoreLSE
Definition fmha_fwd_kernel.hpp:55
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition fmha_fwd_kernel.hpp:37
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition fmha_fwd_kernel.hpp:327
static constexpr ck_tile::index_t kBlockPerCu
Definition fmha_fwd_kernel.hpp:32
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * > > drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
Definition fmha_fwd_kernel.hpp:675
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< uint64_t, uint64_t > &drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
Definition fmha_fwd_kernel.hpp:482
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition fmha_fwd_kernel.hpp:43
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition fmha_fwd_kernel.hpp:46
static constexpr ck_tile::index_t kBlockSize
Definition fmha_fwd_kernel.hpp:30
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition fmha_fwd_kernel.hpp:39
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition fmha_fwd_kernel.hpp:38
static constexpr ck_tile::index_t kBlockPerCuInput
Definition fmha_fwd_kernel.hpp:34
static constexpr auto BiasEnum
Definition fmha_fwd_kernel.hpp:54
static constexpr bool kPadHeadDimV
Definition fmha_fwd_kernel.hpp:52
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, const std::tuple< uint64_t, uint64_t > &drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
Definition fmha_fwd_kernel.hpp:817
static CK_TILE_DEVICE constexpr auto GetTileIndex(const Kargs &kargs)
Definition fmha_fwd_kernel.hpp:1017
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, const std::tuple< const void *, const void * > &drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
Definition fmha_fwd_kernel.hpp:906
static constexpr bool kSkipMinSeqlenQ
Definition fmha_fwd_kernel.hpp:58
static constexpr std::string_view kPipelineName
Definition fmha_fwd_kernel.hpp:72
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition fmha_fwd_kernel.hpp:1094
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition fmha_fwd_kernel.hpp:42
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition fmha_fwd_kernel.hpp:36
static CK_TILE_HOST dim3 BlockSize()
Definition fmha_fwd_kernel.hpp:1082
CK_TILE_DEVICE void run_(Kargs kargs) const
Definition fmha_fwd_kernel.hpp:1105
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition fmha_fwd_kernel.hpp:60
static CK_TILE_HOST constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_, bool has_padded_seqlen_k=false)
Definition fmha_fwd_kernel.hpp:992
static constexpr bool kUseTrLoad
Definition fmha_fwd_kernel.hpp:66
static constexpr bool kHasMask
Definition fmha_fwd_kernel.hpp:62
static constexpr bool kUseAsyncCopy
Definition fmha_fwd_kernel.hpp:64
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition fmha_fwd_kernel.hpp:28
static constexpr bool kPadHeadDimQ
Definition fmha_fwd_kernel.hpp:51
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition fmha_fwd_kernel.hpp:44
static constexpr bool kPadSeqLenQ
Definition fmha_fwd_kernel.hpp:49
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition fmha_fwd_kernel.hpp:61
static constexpr bool kHasLogitsSoftCap
Definition fmha_fwd_kernel.hpp:53
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition fmha_fwd_kernel.hpp:29
static constexpr bool kPadSeqLenK
Definition fmha_fwd_kernel.hpp:50
static constexpr bool kIsGroupMode
Definition fmha_fwd_kernel.hpp:48
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * > > drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
Definition fmha_fwd_kernel.hpp:338
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition fmha_fwd_kernel.hpp:1099
Definition variants.hpp:63
Definition block_dropout.hpp:39
Definition variants.hpp:51
Definition unary_element_function.hpp:12
Definition tile/core/utility/functional.hpp:86
Definition coordinate_transform.hpp:1392
Definition unary_element_function.hpp:56
Definition tile/core/numeric/math.hpp:28
Definition tile/core/container/sequence.hpp:49
const T * ptr
Definition fmha_fwd_kernel.hpp:226