blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp Source File

blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp Source File#

Composable Kernel: blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp Source File
blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Naive pipeline with lowest resource request per WGP
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 2
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t ThreadBlockSize,
18 index_t ScaleBlockSize,
19 typename ADataType,
20 typename AScaleDataType,
21 typename BDataType,
22 typename BScaleDataType,
23 typename ATileDesc,
24 typename BTileDesc,
25 typename AMmaTileDesc,
26 typename BMmaTileDesc,
27 index_t ABlockTransferSrcScalarPerVector,
28 index_t BBlockTransferSrcScalarPerVector,
29 index_t MPerBlock,
30 index_t NPerBlock,
31 index_t KPerBlock,
32 index_t MPerXDL,
33 index_t NPerXDL,
34 index_t MRepeat, // MXdlPerWave
35 index_t NRepeat, // NXdlPerWave
36 index_t KPack>
40
41template <index_t ThreadBlockSize,
42 index_t ScaleBlockSize,
43 typename ADataType,
44 typename AScaleDataType,
45 typename BDataType,
46 typename BScaleDataType,
47 typename ATileDesc,
48 typename BTileDesc,
49 typename AMmaTileDesc,
50 typename BMmaTileDesc,
51 index_t ABlockTransferSrcScalarPerVector,
52 index_t BBlockTransferSrcScalarPerVector,
53 index_t MPerBlock,
54 index_t NPerBlock,
55 index_t KPerBlock,
56 index_t MPerXDL,
57 index_t NPerXDL,
58 index_t MRepeat, // MXdlPerWave
59 index_t NRepeat, // NXdlPerWave
60 index_t KPack>
62 ThreadBlockSize,
63 ScaleBlockSize,
64 ADataType,
65 AScaleDataType,
66 BDataType,
67 BScaleDataType,
68 ATileDesc,
69 BTileDesc,
70 AMmaTileDesc,
71 BMmaTileDesc,
72 ABlockTransferSrcScalarPerVector,
73 BBlockTransferSrcScalarPerVector,
74 MPerBlock,
75 NPerBlock,
76 KPerBlock,
77 MPerXDL,
78 NPerXDL,
79 MRepeat,
80 NRepeat,
81 KPack>
83 ADataType,
84 BDataType,
85 ATileDesc,
86 BTileDesc,
87 AMmaTileDesc,
88 BMmaTileDesc,
89 ABlockTransferSrcScalarPerVector,
90 BBlockTransferSrcScalarPerVector,
91 MPerBlock,
92 NPerBlock,
93 KPerBlock,
94 MPerXDL,
95 NPerXDL,
96 MRepeat,
97 NRepeat,
98 KPack>
99
100{
101
103 ADataType,
104 BDataType,
105 ATileDesc,
106 BTileDesc,
107 AMmaTileDesc,
108 BMmaTileDesc,
109 ABlockTransferSrcScalarPerVector,
110 BBlockTransferSrcScalarPerVector,
111 MPerBlock,
112 NPerBlock,
113 KPerBlock,
114 MPerXDL,
115 NPerXDL,
116 MRepeat,
117 NRepeat,
118 KPack>;
119 using Base::A_K1;
120 using Base::I0;
121 using Base::I1;
122 using Base::KRepeat;
123 using Base::MWaves;
124 using Base::NWaves;
125 using Base::WaveSize;
126 using Base::xdlops_gemm;
127 using typename Base::HotLoopInstList;
128
137 using Base::GetWaveIdx;
140
143
144 using Base::AMmaKStride;
145 using Base::APackedSize;
146 using Base::BMmaKStride;
147 using Base::BPackedSize;
148 using Base::KThreadChunk;
149
150 using Base::KXdlPack;
151 using Base::MXdlPack;
152 using Base::NXdlPack;
153
154 using AccType = typename Base::AccType;
155 using Tuple5 = typename Base::Tuple5;
158
// Pipeline stage constants for this specialization.
//
// PrefetchStages: BlockHasHotloop() compares num_loop against this value;
// Run() issues two global prefetches ("Global prefetch 1" / "Global prefetch 2")
// before entering the main body.
159 static constexpr index_t PrefetchStages = 2;
// LocalPrefetchStages: Run() computes SwitchM = MRepeat - LocalPrefetchStages and
// uses it as the m0 index from which A thread reads select the other
// a_block_bufs entry.
160 static constexpr index_t LocalPrefetchStages = 2;
// NOTE(review): PrefillStages and GlobalBufferNum are not referenced in the
// visible part of this file; presumably consumed by the enclosing gridwise
// kernel - confirm against callers.
161 static constexpr index_t PrefillStages = 1;
162 static constexpr index_t GlobalBufferNum = 1;
// 0 when MRepeat is even, 1 when MRepeat is odd.
// NOTE(review): not referenced in the visible part of this file - confirm usage.
163 static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
164
165 static constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack;
166 static constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack;
167 static constexpr auto async_vmcnt =
169 static constexpr auto async_vmcnt_encoding = 3952 + async_vmcnt % 16 + async_vmcnt / 16 * 16384;
170
// Compile-time scale bookkeeping derived from the template parameters.
// ScalesPerKBlockSize = KPerBlock / ScaleBlockSize (integer division).
171 static constexpr auto ScalesPerKBlockSize =
172 KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
173
174 //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
175 static constexpr auto ScalesPerXdlopsRun =
176 (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
177
178 //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
// Run() static_asserts that this value is greater than zero.
179 static constexpr auto ScalesPerXdlopsRunPerThread =
180 ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
181
// Number of mx_scale_t elements that fit in one A / B scale data type,
// computed as a sizeof ratio.
183 static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
184 static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
// Compile-time checks: KXdlPack * MXdlPack (resp. KXdlPack * NXdlPack) must be
// an exact multiple of the corresponding scale pack size.
185 static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
186 "A scale pack data type too large!");
187 static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
188 "B scale pack data type too large!");
191
// Reports whether the main (hot) loop is needed for a given loop count.
// Returns true only when num_loop is strictly greater than PrefetchStages.
192 __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
193 {
194 return PrefetchStages < num_loop;
195 }
196
// Classifies num_loop by parity: TailNumber::Odd when num_loop is not divisible
// by 2, TailNumber::Even otherwise. Run()'s tail section branches on a
// TailNumber template argument with these two values.
197 __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
198 {
199 return num_loop % 2 != 0 ? TailNumber::Odd : TailNumber::Even;
200 }
201
202 __device__ static constexpr auto HotLoopScheduler()
203 {
204 // A/B split schedule
205 // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
206 constexpr auto num_ds_read_inst_a =
207 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
210
211 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
212 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
213 constexpr auto num_buffer_load_stage1 =
214 num_buffer_load_inst_b + num_buffer_load_a_scale + num_buffer_load_b_scale;
215
216 constexpr auto num_buffer_load_stage2 = num_buffer_load_inst_a;
217
218 constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize;
219 constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
220
221 constexpr auto ds_read_a_issue_cycle =
222 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
223 constexpr auto ds_read_a_mfma_rate =
224 math::integer_divide_ceil(mfma_cycle - 8, 2 * ds_read_a_issue_cycle);
225
226 // constexpr auto num_dsread_a_mfma =
227 // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
228
229 constexpr auto num_total_stages = std::max(2, MRepeat);
230
231 if constexpr(num_total_stages > 2)
232 {
233 // Group num_mfma_perstage num_ds_read_a_perstage
234 // since we want to reuse a local register buffer
235 constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
236 constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
237
238 constexpr auto num_ds_read_a_mfma_perstage =
239 math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
240
241 constexpr auto num_ds_read_a_prefetch_stages = 2;
242
243 constexpr auto buffer_load_perstage_more =
244 math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2));
245 constexpr auto buffer_load_perstage_less =
246 math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2));
247 constexpr auto buffer_load_perstage_stage2 =
248 math::integer_divide_floor((num_buffer_load_stage2), 2);
249
250 constexpr auto buffer_load_stages_more =
251 num_buffer_load_stage1 -
252 math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) *
253 ((num_total_stages - 2));
254
255 constexpr auto buffer_load_issue_point_interval_more =
256 num_mfma_perstage / buffer_load_perstage_more;
257 constexpr auto buffer_load_issue_point_interval_less =
258 num_mfma_perstage / buffer_load_perstage_less;
259 constexpr auto buffer_load_issue_point_interval_stage2 =
260 num_mfma_perstage / buffer_load_perstage_stage2;
261
262 // Stage 1
263 // global read more
265 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
266 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
267
268 if constexpr(imfma % buffer_load_issue_point_interval_more == 0)
269 {
270 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
271 }
272
273 if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
274 {
275 __builtin_amdgcn_sched_group_barrier(
276 0x100, ds_read_a_mfma_rate, 0); // DS read
277 }
278 });
279 });
280
281 // global read less
282 static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) {
283 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
284 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
285 if constexpr(imfma % buffer_load_issue_point_interval_less == 0)
286 {
287 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
288 }
289 if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
290 {
291 __builtin_amdgcn_sched_group_barrier(
292 0x100, ds_read_a_mfma_rate, 0); // DS read
293 }
294 });
295 });
296
297 // Stage 2, Sync
298 // lds synchronization, prefetch next loop local A
300 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
301 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
302 if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0)
303 {
304 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
305 }
306 if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
307 {
308 __builtin_amdgcn_sched_group_barrier(
309 0x100, ds_read_a_mfma_rate, 0); // DS read
310 }
311 });
312 });
313 }
314 else
315 {
316 constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b +
319 constexpr auto num_dsread_a_mfma = math::integer_divide_ceil(
320 num_ds_read_inst_a, ds_read_a_mfma_rate); // how many mfma per dsread_a
321
322 // stage 1
323 constexpr auto num_mfma_stage1 = num_mfma_inst - num_dsread_a_mfma;
324
325 constexpr auto mfma_perstage_more =
326 math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total);
327 constexpr auto mfma_perstage_less =
328 math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total);
329
330 constexpr auto mfma_stages_more =
331 num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total;
332
334 if constexpr(i < mfma_stages_more)
335 {
337 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
338 });
339 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
340 }
341 else
342 {
344 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
345 });
346 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
347 }
348 });
349
351 if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more)
352 {
353 static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
354 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
355 });
356 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
357 }
358 else
359 {
360 static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
361 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
362 });
363 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
364 }
365 });
366
368 if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) <
369 mfma_stages_more)
370 {
371 static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
372 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
373 });
374 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
375 }
376 else
377 {
378 static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
379 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
380 });
381 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
382 }
383 });
384
386 if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b +
387 num_buffer_load_a_scale) < mfma_stages_more)
388 {
389 static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
390 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
391 });
392 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
393 }
394 else
395 {
396 static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
397 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
398 });
399 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
400 }
401 });
402
403 // stage 2
405 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
406 if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
407 ds_read_a_mfma_rate)
408 {
409 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
410 }
411 else
412 {
413 __builtin_amdgcn_sched_group_barrier(
414 0x100,
415 num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
416 0); // DS read
417 }
418 });
419 }
420 }
421
422 template <bool HasMainLoop,
423 TailNumber TailNum,
424 typename AGridDesc,
425 typename ABlockDesc,
426 typename ABlockTransfer,
427 typename AGridBuffer,
428 typename ABlockBuffer,
429 typename ABlockTransferStep,
430 typename BGridDesc,
431 typename BBlockDesc,
432 typename BBlockTransfer,
433 typename BGridBuffer,
434 typename BBlockBuffer,
435 typename BBlockTransferStep,
436 typename CThreadBuffer,
437 typename AScaleGridBuffer,
438 typename AScaleGridDesc,
439 typename AScaleThreadTransfer,
440 typename BScaleGridBuffer,
441 typename BScaleGridDesc,
442 typename BScaleThreadTransfer>
443 __device__ void Run(
444 // ABlockCopy
445 const AGridDesc& a_grid_desc,
446 const ABlockDesc& a_block_desc,
447 ABlockTransfer& a_blockwise_copy,
448 const AGridBuffer& a_grid_buf,
449 ABlockBuffer& a_block_bufs,
450 const ABlockTransferStep& a_block_copy_step,
451 // BBlockCopy
452 const BGridDesc& b_grid_desc,
453 const BBlockDesc& b_block_desc,
454 BBlockTransfer& b_blockwise_copy,
455 const BGridBuffer& b_grid_buf,
456 BBlockBuffer& b_block_bufs,
457 const BBlockTransferStep& b_block_copy_step,
458 // CThread
459 CThreadBuffer& c_thread_buf,
460 // A and B scales
461 const AScaleGridDesc& a_scale_grid_desc,
462 AScaleThreadTransfer& a_scale_thread_copy,
463 const AScaleGridBuffer& a_scale_grid_buf,
464 const BScaleGridDesc& b_scale_grid_desc,
465 BScaleThreadTransfer& b_scale_thread_copy,
466 const BScaleGridBuffer& b_scale_grid_buf,
467 index_t num_loop) const
468 {
469 ignore = b_block_bufs;
471 a_thread_desc_.GetElementSpaceSize());
473 b_thread_desc_.GetElementSpaceSize());
474 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
475 constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0);
476
478 a_scale_thread_desc.GetElementSpaceSize());
479
481 b_scale_thread_desc.GetElementSpaceSize());
482
483 StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
484 StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
485
486 // Global prefetch 1
487 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I0));
488 b_blockwise_copy.Run(
489 b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I0));
490
491 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
492 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
493
494 // Prefetch a_scales
495 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
496 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
497 a_scale_thread_copy.Run(a_scale_grid_desc,
498 a_scale_grid_buf,
500 make_tuple(m0, k0, I0),
501 a_scale_thread_bufs(I0));
502
503 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
504 make_multi_index(0, I1, 0));
505 });
506 a_scale_thread_copy.MoveSrcSliceWindow(
507 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
508 });
509
510 // restore row id and advance to the next set of scales
511 a_scale_thread_copy.MoveSrcSliceWindow(
512 a_scale_grid_desc,
513 make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
514
515 // Prefetch b_scales
516 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
517 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
518 b_scale_thread_copy.Run(b_scale_grid_desc,
519 b_scale_grid_buf,
521 make_tuple(n0, k0, I0),
522 b_scale_thread_bufs(I0));
523
524 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
525 make_multi_index(0, I1, 0));
526 });
527 b_scale_thread_copy.MoveSrcSliceWindow(
528 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
529 });
530
531 // restore col id and advance to the next set of scales
532 // NWaves * NPerXDL * NRepeat == NPerBlock
533 b_scale_thread_copy.MoveSrcSliceWindow(
534 b_scale_grid_desc,
535 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
536
537 // Local prefetch 1, sync the async load
538 __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
541 static_for<0, KRepeat, 1>{}([&](auto k) {
542 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
543 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
544 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
545 [&](auto chunk) {
546 constexpr auto a_k_step_chunk =
547 k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
548 a_thread_copy_.Run(
552 a_block_bufs(I0),
556 a_thread_buf);
557 });
558 });
559 });
560
561 // Global prefetch 2
562 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I1));
563 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
564
565 // Initialize C
566 c_thread_buf.Clear();
567 __builtin_amdgcn_sched_barrier(0);
568 constexpr index_t SwitchM = MRepeat - LocalPrefetchStages;
569 // main body
570 if constexpr(HasMainLoop)
571 {
572 // loop over k with the step KPerBlock
573 index_t i = 0;
574 do
575 {
576 auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) {
577 b_blockwise_copy.Run(b_grid_desc,
578 b_grid_buf,
579 b_block_desc,
580 b_block_origin_idx,
581 b_thread_bufs(scale_mem_buf));
582
583 // Prefetch a_scales
584 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
585 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
586 a_scale_thread_copy.Run(a_scale_grid_desc,
587 a_scale_grid_buf,
589 make_tuple(m0, k0, I0),
590 a_scale_thread_bufs(scale_mem_buf));
591
592 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
593 make_multi_index(0, I1, 0));
594 });
595 a_scale_thread_copy.MoveSrcSliceWindow(
596 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
597 });
598
599 // restore row id and advance to the next set of scales
600 a_scale_thread_copy.MoveSrcSliceWindow(
601 a_scale_grid_desc,
602 make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
603
604 // Prefetch b_scales
605 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
606 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
607 b_scale_thread_copy.Run(b_scale_grid_desc,
608 b_scale_grid_buf,
610 make_tuple(n0, k0, I0),
611 b_scale_thread_bufs(scale_mem_buf));
612
613 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
614 make_multi_index(0, I1, 0));
615 });
616 b_scale_thread_copy.MoveSrcSliceWindow(
617 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
618 });
619
620 // restore col id and advance to the next set of scales
621 // NWaves * NPerXDL * NRepeat == NPerBlock
622 b_scale_thread_copy.MoveSrcSliceWindow(
623 b_scale_grid_desc,
624 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
625
626 // a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
627 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
628
629 static_for<0, MRepeat, 1>{}([&](auto m0) {
630 constexpr auto im_major = m0 / MXdlPack;
631 constexpr auto im_minor = m0 % MXdlPack;
632 static_for<0, KRepeat, 1>{}([&](auto k0) {
633 constexpr auto ik_major = k0 / KXdlPack;
634 constexpr auto ik_minor = k0 % KXdlPack;
635 static_for<0, NRepeat, 1>{}([&](auto n0) {
636 constexpr auto in_major = n0 / NXdlPack;
637 constexpr auto in_minor = n0 % NXdlPack;
638
639 constexpr index_t a_scale_offset =
640 a_scale_thread_desc.CalculateOffset(
641 make_tuple(im_major, ik_major, I0));
642 constexpr index_t b_scale_offset =
643 b_scale_thread_desc.CalculateOffset(
644 make_tuple(in_major, ik_major, I0));
645
646 static_assert(0 < ScalesPerXdlopsRunPerThread,
647 "Must have at least one scale per Xdlops "
648 "per Thread.");
649
651 a_scale_thread_vec;
653 b_scale_thread_vec;
654
655 // Pack scale_thread_buf into scale_thread_vec
657 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
658 a_scale_thread_bufs(
659 scale_comp_buf)[Number<a_scale_offset + s>{}];
660 });
661
663 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
664 b_scale_thread_bufs(
665 scale_comp_buf)[Number<b_scale_offset + s>{}];
666 });
667
670
671 static_for<0, KPack, 1>{}([&](auto ik) {
672 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
673 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
674 make_tuple(I0, I0, im_minor, k0, ik))>{}];
675 b_thread_vec.template AsType<ComputeTypeB>()(ik) = b_thread_bufs
676 [scale_comp_buf][Number<b_thread_desc_.CalculateOffset(
677 make_tuple(in_major, I0, in_minor, k0, ik))>{}];
678 });
679
680 using mfma_input_type_a =
681 typename vector_type<ComputeTypeA,
682 xdlops_gemm.K1PerXdlops /
683 APackedSize>::type;
684
685 using mfma_input_type_b =
686 typename vector_type<ComputeTypeB,
687 xdlops_gemm.K1PerXdlops /
688 BPackedSize>::type;
689
690 using mfma_scale_input_type_a =
691 typename vector_type<AScaleDataType,
693 using mfma_scale_input_type_b =
694 typename vector_type<BScaleDataType,
696
697 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
698 make_tuple(im_major, in_major, im_minor, in_minor, 0));
699
700 // MFMA accumulation
701 xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
702 ik_minor * NXdlPack + in_minor>(
703 a_thread_vec.template AsType<mfma_input_type_a>(),
704 a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
705 b_thread_vec.template AsType<mfma_input_type_b>(),
706 b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
707 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
708 });
709 });
710
711 if constexpr(m0.value == SwitchM)
712 {
713 __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
715 a_blockwise_copy.Run(a_grid_desc,
716 a_grid_buf,
717 a_block_desc,
718 a_block_bufs(scale_comp_buf));
719 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
720 }
721
722 constexpr auto lds_buf =
723 m0.value >= SwitchM ? scale_mem_buf : scale_comp_buf;
724
725 static_for<0, KRepeat, 1>{}([&](auto k) {
726 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
727 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
728 static_for<0,
729 xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk),
730 1>{}([&](auto chunk) {
731 constexpr auto a_k_step_chunk =
732 k_step +
733 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
734 a_thread_copy_.Run(
737 (MRepeat / MXdlPack)>{},
738 I0,
740 I0,
742 a_block_bufs(Number<lds_buf>{}),
745 I0,
747 k,
749 a_thread_buf);
750 });
751 });
752 });
753
755 __builtin_amdgcn_sched_barrier(0);
756 };
757
758 LoopFunc(I0, I1);
759 LoopFunc(I1, I0);
760
761 i += 2;
762 } while(i < (num_loop - 2));
763 }
764
765 // tail
766 if constexpr(TailNum == TailNumber::Even)
767 {
768 b_blockwise_copy.Run(
769 b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I1));
770
771 // Prefetch a_scales
772 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
773 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
774 a_scale_thread_copy.Run(a_scale_grid_desc,
775 a_scale_grid_buf,
777 make_tuple(m0, k0, I0),
778 a_scale_thread_bufs(I1));
779
780 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
781 make_multi_index(0, I1, 0));
782 });
783 a_scale_thread_copy.MoveSrcSliceWindow(
784 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
785 });
786
787 // Prefetch b_scales
788 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
789 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
790 b_scale_thread_copy.Run(b_scale_grid_desc,
791 b_scale_grid_buf,
793 make_tuple(n0, k0, I0),
794 b_scale_thread_bufs(I1));
795
796 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
797 make_multi_index(0, I1, 0));
798 });
799 b_scale_thread_copy.MoveSrcSliceWindow(
800 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
801 });
802
803 static_for<0, MRepeat, 1>{}([&](auto m0) {
804 constexpr auto im_major = m0 / MXdlPack;
805 constexpr auto im_minor = m0 % MXdlPack;
806 static_for<0, KRepeat, 1>{}([&](auto k0) {
807 constexpr auto ik_major = k0 / KXdlPack;
808 constexpr auto ik_minor = k0 % KXdlPack;
809 static_for<0, NRepeat, 1>{}([&](auto n0) {
810 constexpr auto in_major = n0 / NXdlPack;
811 constexpr auto in_minor = n0 % NXdlPack;
812
813 constexpr index_t a_scale_offset =
814 a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
815 constexpr index_t b_scale_offset =
816 b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
817
818 static_assert(0 < ScalesPerXdlopsRunPerThread,
819 "Must have at least one scale per Xdlops "
820 "per Thread.");
821
824
825 // Pack scale_thread_buf into scale_thread_vec
827 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
828 a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
829 });
830
832 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
833 b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
834 });
835
838
839 static_for<0, KPack, 1>{}([&](auto ik) {
840 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
841 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
842 make_tuple(I0, I0, im_minor, k0, ik))>{}];
843 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
844 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
845 make_tuple(in_major, I0, in_minor, k0, ik))>{}];
846 });
847
848 using mfma_input_type_a =
849 typename vector_type<ComputeTypeA,
850 xdlops_gemm.K1PerXdlops / APackedSize>::type;
851
852 using mfma_input_type_b =
853 typename vector_type<ComputeTypeB,
854 xdlops_gemm.K1PerXdlops / BPackedSize>::type;
855
856 using mfma_scale_input_type_a =
858 using mfma_scale_input_type_b =
860
861 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
862 make_tuple(im_major, in_major, im_minor, in_minor, 0));
863
864 // MFMA accumulation
865 xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
866 ik_minor * NXdlPack + in_minor>(
867 a_thread_vec.template AsType<mfma_input_type_a>(),
868 a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
869 b_thread_vec.template AsType<mfma_input_type_b>(),
870 b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
871 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
872 });
873 });
874 if constexpr(m0.value == SwitchM)
875 {
876 __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
878 }
879
880 constexpr auto lds_buf = m0.value >= SwitchM ? I1 : I0;
881
882 static_for<0, KRepeat, 1>{}([&](auto k) {
883 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
884 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
885 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
886 [&](auto chunk) {
887 constexpr auto a_k_step_chunk =
888 k_step +
889 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
890 a_thread_copy_.Run(
893 (MRepeat / MXdlPack)>{},
894 I0,
896 I0,
898 a_block_bufs(Number<lds_buf>{}),
902 a_thread_buf);
903 });
904 });
905 });
906
907 static_for<0, MRepeat, 1>{}([&](auto m0) {
908 constexpr auto im_major = m0 / MXdlPack;
909 constexpr auto im_minor = m0 % MXdlPack;
910 static_for<0, KRepeat, 1>{}([&](auto k0) {
911 constexpr auto ik_major = k0 / KXdlPack;
912 constexpr auto ik_minor = k0 % KXdlPack;
913 static_for<0, NRepeat, 1>{}([&](auto n0) {
914 constexpr auto in_major = n0 / NXdlPack;
915 constexpr auto in_minor = n0 % NXdlPack;
916
917 constexpr index_t a_scale_offset =
918 a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
919 constexpr index_t b_scale_offset =
920 b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
921
922 static_assert(0 < ScalesPerXdlopsRunPerThread,
923 "Must have at least one scale per Xdlops "
924 "per Thread.");
925
928
929 // Pack scale_thread_buf into scale_thread_vec
931 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
932 a_scale_thread_bufs(I1)[Number<a_scale_offset + s>{}];
933 });
934
936 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
937 b_scale_thread_bufs(I1)[Number<b_scale_offset + s>{}];
938 });
939
942
943 static_for<0, KPack, 1>{}([&](auto ik) {
944 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
945 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
946 make_tuple(I0, I0, im_minor, k0, ik))>{}];
947 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
948 b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
949 make_tuple(in_major, I0, in_minor, k0, ik))>{}];
950 });
951
952 using mfma_input_type_a =
953 typename vector_type<ComputeTypeA,
954 xdlops_gemm.K1PerXdlops / APackedSize>::type;
955
956 using mfma_input_type_b =
957 typename vector_type<ComputeTypeB,
958 xdlops_gemm.K1PerXdlops / BPackedSize>::type;
959
960 using mfma_scale_input_type_a =
962 using mfma_scale_input_type_b =
964
965 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
966 make_tuple(im_major, in_major, im_minor, in_minor, 0));
967
968 // MFMA accumulation
969 xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
970 ik_minor * NXdlPack + in_minor>(
971 a_thread_vec.template AsType<mfma_input_type_a>(),
972 a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
973 b_thread_vec.template AsType<mfma_input_type_b>(),
974 b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
975 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
976 });
977 });
978 if constexpr(m0.value < (MRepeat - LocalPrefetchStages))
979 {
980 static_for<0, KRepeat, 1>{}([&](auto k) {
981 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
982 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
983 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
984 [&](auto chunk) {
985 constexpr auto a_k_step_chunk =
986 k_step +
987 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
988 a_thread_copy_.Run(
991 (MRepeat / MXdlPack)>{},
992 I0,
994 I0,
996 a_block_bufs(I1),
999 I0,
1001 k,
1003 a_thread_buf);
1004 });
1005 });
1006 }
1007 });
1008 }
1009 else if constexpr(TailNum == TailNumber::Odd)
1010 {
1011 static_for<0, MRepeat, 1>{}([&](auto m0) {
1012 constexpr auto im_major = m0 / MXdlPack;
1013 constexpr auto im_minor = m0 % MXdlPack;
1014 static_for<0, KRepeat, 1>{}([&](auto k0) {
1015 constexpr auto ik_major = k0 / KXdlPack;
1016 constexpr auto ik_minor = k0 % KXdlPack;
1017 static_for<0, NRepeat, 1>{}([&](auto n0) {
1018 constexpr auto in_major = n0 / NXdlPack;
1019 constexpr auto in_minor = n0 % NXdlPack;
1020
1021 constexpr index_t a_scale_offset =
1022 a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
1023 constexpr index_t b_scale_offset =
1024 b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
1025
1026 static_assert(0 < ScalesPerXdlopsRunPerThread,
1027 "Must have at least one scale per Xdlops "
1028 "per Thread.");
1029
1032
1033 // Pack scale_thread_buf into scale_thread_vec
1035 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
1036 a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
1037 });
1038
1040 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
1041 b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
1042 });
1043
1046
1047 static_for<0, KPack, 1>{}([&](auto ik) {
1048 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1049 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1050 make_tuple(I0, I0, im_minor, k0, ik))>{}];
1051 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1052 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
1053 make_tuple(in_major, I0, in_minor, k0, ik))>{}];
1054 });
1055
1056 using mfma_input_type_a =
1057 typename vector_type<ComputeTypeA,
1058 xdlops_gemm.K1PerXdlops / APackedSize>::type;
1059
1060 using mfma_input_type_b =
1061 typename vector_type<ComputeTypeB,
1062 xdlops_gemm.K1PerXdlops / BPackedSize>::type;
1063
1064 using mfma_scale_input_type_a =
1066 using mfma_scale_input_type_b =
1068
1069 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
1070 make_tuple(im_major, in_major, im_minor, in_minor, 0));
1071
1072 // MFMA accumulation
1073 xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
1074 ik_minor * NXdlPack + in_minor>(
1075 a_thread_vec.template AsType<mfma_input_type_a>(),
1076 a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
1077 b_thread_vec.template AsType<mfma_input_type_b>(),
1078 b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
1079 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1080 });
1081 });
1082 if constexpr(m0.value < (MRepeat - LocalPrefetchStages))
1083 {
1084 static_for<0, KRepeat, 1>{}([&](auto k) {
1085 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
1086 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
1087 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
1088 [&](auto chunk) {
1089 constexpr auto a_k_step_chunk =
1090 k_step +
1091 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
1092 a_thread_copy_.Run(
1095 (MRepeat / MXdlPack)>{},
1096 I0,
1098 I0,
1100 a_block_bufs(I0),
1102 make_tuple(I0,
1103 I0,
1105 k,
1107 a_thread_buf);
1108 });
1109 });
1110 }
1111 });
1112 }
1113 }
1114
1115 // Length: A[ARegBuf, MWave, MXdlPack, KRepeat, KPack]
1116 // Order: 1 0 3 2 4
1117 static constexpr auto ARegBuf = 2;
1120
1124 decltype(a_thread_desc_),
1127 4,
1128 A_K1,
1129 A_K1>;
1131
1132 // TODO: make this field protected when a_scale_thread_copy_ is moved
1133 // here
1136 Number<KRepeat / KXdlPack>{},
1138
1139 // TODO: make this field protected when b_scale_thread_copy_ is moved
1140 // here
1143 Number<KRepeat / KXdlPack>{},
1145
1146 protected:
1147 // using Base::a_thread_copy_;
1148 // using Base::a_thread_desc_;
1149 using Base::b_thread_copy_;
1150 using Base::b_thread_desc_;
1151 using Base::c_thread_desc_;
1152};
1153
1154} // namespace ck
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops,(packed_size_v< ComputeTypeA > > 1||packed_size_v< ComputeTypeB > > 1)> HotLoopInstList
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:88
__host__ __device__ BlockwiseGemmXdlops_mx_pipeline_base(Tuple5 a_origin=CalculateAThreadOriginDataIndex(), Tuple5 b_origin=CalculateBThreadOriginDataIndex())
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:204
BlockwiseGemmXdlops_mx_pipeline_base< ThreadBlockSize, ADataType, BDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp:102
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_bufs, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_bufs, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, const AScaleGridDesc &a_scale_grid_desc, AScaleThreadTransfer &a_scale_thread_copy, const AScaleGridBuffer &a_scale_grid_buf, const BScaleGridDesc &b_scale_grid_desc, BScaleThreadTransfer &b_scale_thread_copy, const BScaleGridBuffer &b_scale_grid_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp:443
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeTypeA, decltype(a_block_desc_m0_m1_m2_m3_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, 1, KThreadChunk >, Sequence< 0, 1, 2, 3, 4 >, 4, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp:1121
Definition blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp:38
Definition utility/sequence.hpp:43
Definition threadwise_tensor_slice_transfer.hpp:1260
Unsigned representation of a conventional biased Float32 exponent.
Definition utility/e8m0.hpp:26
Definition functional2.hpp:33
Definition dtype_vector.hpp:10