blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp Source File

blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp Source File#

Composable Kernel: blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp Source File
blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Naive pipeline with lowest resource request per WGP
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 1
14// LocalSharedMemoryBuffer: 1
15
// Primary template of the pipeline. Its body is empty; the declaration that follows
// it in this file takes the same parameter list without the leading
// BlkGemmPipelineVer value and supplies the implementation.
// NOTE(review): listing line 37 (presumably the struct/class head naming this
// template) is absent from this extract, so the name cannot be confirmed from here.
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t ThreadBlockSize,
18 index_t ScaleBlockSize,
19 typename ADataType,
20 typename AScaleDataType,
21 typename BDataType,
22 typename BScaleDataType,
23 typename ATileDesc,
24 typename BTileDesc,
25 typename AMmaTileDesc,
26 typename BMmaTileDesc,
27 index_t ABlockTransferSrcScalarPerVector,
28 index_t BBlockTransferSrcScalarPerVector,
29 index_t MPerBlock,
30 index_t NPerBlock,
31 index_t KPerBlock,
32 index_t MPerXDL,
33 index_t NPerXDL,
34 index_t MRepeat, // MXdlPerWave
35 index_t NRepeat, // NXdlPerWave
36 index_t KPack>
38{
39};
40
41template <index_t ThreadBlockSize,
42 index_t ScaleBlockSize,
43 typename ADataType,
44 typename AScaleDataType,
45 typename BDataType,
46 typename BScaleDataType,
47 typename ATileDesc,
48 typename BTileDesc,
49 typename AMmaTileDesc,
50 typename BMmaTileDesc,
51 index_t ABlockTransferSrcScalarPerVector,
52 index_t BBlockTransferSrcScalarPerVector,
53 index_t MPerBlock,
54 index_t NPerBlock,
55 index_t KPerBlock,
56 index_t MPerXDL,
57 index_t NPerXDL,
58 index_t MRepeat, // MXdlPerWave
59 index_t NRepeat, // NXdlPerWave
60 index_t KPack>
62 ThreadBlockSize,
63 ScaleBlockSize,
64 ADataType,
65 AScaleDataType,
66 BDataType,
67 BScaleDataType,
68 ATileDesc,
69 BTileDesc,
70 AMmaTileDesc,
71 BMmaTileDesc,
72 ABlockTransferSrcScalarPerVector,
73 BBlockTransferSrcScalarPerVector,
74 MPerBlock,
75 NPerBlock,
76 KPerBlock,
77 MPerXDL,
78 NPerXDL,
79 MRepeat,
80 NRepeat,
81 KPack>
83 ADataType,
84 BDataType,
85 ATileDesc,
86 BTileDesc,
87 AMmaTileDesc,
88 BMmaTileDesc,
89 ABlockTransferSrcScalarPerVector,
90 BBlockTransferSrcScalarPerVector,
91 MPerBlock,
92 NPerBlock,
93 KPerBlock,
94 MPerXDL,
95 NPerXDL,
96 MRepeat,
97 NRepeat,
98 KPack>
99
100{
101
103 ADataType,
104 BDataType,
105 ATileDesc,
106 BTileDesc,
107 AMmaTileDesc,
108 BMmaTileDesc,
109 ABlockTransferSrcScalarPerVector,
110 BBlockTransferSrcScalarPerVector,
111 MPerBlock,
112 NPerBlock,
113 KPerBlock,
114 MPerXDL,
115 NPerXDL,
116 MRepeat,
117 NRepeat,
118 KPack>;
119 using Base::I0;
120 using Base::I1;
121 using Base::KRepeat;
122 using Base::MWaves;
123 using Base::NWaves;
124 using Base::WaveSize;
125 using Base::xdlops_gemm;
126 using typename Base::HotLoopInstList;
127
136 using Base::GetWaveIdx;
139
142
143 using Base::AMmaKStride;
144 using Base::APackedSize;
145 using Base::BMmaKStride;
146 using Base::BPackedSize;
147 using Base::KThreadChunk;
148
149 using Base::KXdlPack;
150 using Base::MXdlPack;
151 using Base::NXdlPack;
152
153 using AccType = typename Base::AccType;
154 using Tuple5 = typename Base::Tuple5;
157
// Pipeline depth constants. PrefetchStages is the threshold BlockHasHotloop()
// compares num_loop against; the other two are not referenced in the visible code.
158 static constexpr index_t PrefetchStages = 2;
159 static constexpr index_t PrefillStages = 1;
160 static constexpr index_t GlobalBufferNum = 1;
161
162 static constexpr auto ScalesPerKBlockSize =
163 KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
164
165 //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
166 static constexpr auto ScalesPerXdlopsRun =
167 (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
168
169 //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
170 static constexpr auto ScalesPerXdlopsRunPerThread =
171 ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
172
// Size ratio between the scale storage type and a single mx_scale_t, i.e. how many
// mx_scale_t values fit in one AScaleDataType / BScaleDataType element.
// NOTE(review): listing line 173 is absent from this extract.
174 static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
175 static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
// The K*M (resp. K*N) xdl pack product must be a whole multiple of the scale pack
// size; otherwise compilation stops with the messages below.
176 static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
177 "A scale pack data type too large!");
178 static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
179 "B scale pack data type too large!");
182
183 __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
184 {
185 return num_loop > PrefetchStages;
186 }
187
188 __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
189 {
190 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
191 }
192
// Emits a compile-time sequence of __builtin_amdgcn_sched_group_barrier hints.
// Going by the inline labels, the groups are MFMA (0x008), DS write (0x200),
// VMEM read (0x020) and DS read (0x100). "stage 1" spreads the MFMAs that are not
// reserved for DS reads across the buffer-load slots; "stage 2" pairs each
// remaining MFMA with a group of DS reads.
// NOTE(review): several listing lines inside this function are absent from this
// extract (199-200, 203-204, 249, 274, 299, 316, 335, 351). They look like the tails
// of two conditional expressions and the static_for headers that introduce `i`;
// confirm against the real header before relying on the notes below.
193 __device__ static constexpr auto HotLoopScheduler()
194 {
195 // A/B split schedule
196 // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
197 constexpr auto num_ds_read_inst_a =
198 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
// NOTE(review): the rest of the expression above (listing lines 199-200) is missing here.
201 constexpr auto num_ds_read_inst_b =
202 HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
// NOTE(review): the rest of the expression above (listing lines 203-204) is missing here.
205
// Every B-side count below is doubled relative to its A-side counterpart.
// NOTE(review): presumably because Run() drives two B operands (b_* and b_*_up) — confirm.
206 constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
207 constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num * 2;
208
209 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
210 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * 2;
211
212 constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack;
213 constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack * 2;
214
215 constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize * 2;
216
217 constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
// Issue cost assigned to one DS read: 8 when the read is 16 bytes wide, 4 otherwise.
218 constexpr auto ds_read_a_issue_cycle =
219 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
220 constexpr auto ds_read_b_issue_cycle =
221 HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
222
// DS reads grouped with one MFMA: ceil((mfma_cycle - 4) / (2 * issue_cycle)).
223 constexpr auto ds_read_a_mfma_rate =
224 (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
225 constexpr auto ds_read_b_mfma_rate =
226 (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
227
// MFMAs needed to cover all DS reads at that rate: ceil(num_ds_read_inst / rate).
228 constexpr auto num_dsread_a_mfma =
229 (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
230 constexpr auto num_dsread_b_mfma =
231 (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
232
233 // stage 1
// MFMAs left over once those paired with DS reads (stage 2) are set aside.
234 constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
235 constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b +
236 num_buffer_load_a_scale + num_buffer_load_b_scale;
237
// MFMAs per buffer-load slot, rounded up and rounded down.
238 constexpr auto mfma_perstage_more =
239 math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total);
240 constexpr auto mfma_perstage_less =
241 math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total);
242
// Remainder of the floor division: slots with an index below this value issue
// mfma_perstage_more MFMAs, the rest issue mfma_perstage_less.
243 constexpr auto mfma_stages_more =
244 num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total;
245
246 constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
247 constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
248
// First section: each slot issues MFMAs, with one DS write after each of the first
// num_dswrite_per_issue_a MFMAs, then one VMEM read.
// NOTE(review): the loop header that introduces `i` (listing line 249) is missing here.
250 if constexpr(i < mfma_stages_more)
251 {
252 static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) {
253 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
254 if constexpr(imfma < num_dswrite_per_issue_a)
255 {
256 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
257 }
258 });
259 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
260 }
261 else
262 {
263 static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) {
264 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
265 if constexpr(imfma < num_dswrite_per_issue_a)
266 {
267 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
268 }
269 });
270 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
271 }
272 });
273
// Second section: the slot index is offset by num_buffer_load_inst_a.
// NOTE(review): the loop header (listing line 274) is missing here.
275 if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more)
276 {
277 static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) {
278 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// NOTE(review): this branch gates the DS write on num_dswrite_per_issue_a, while the
// else branch of this same section uses num_dswrite_per_issue_b. Looks like a
// copy/paste slip — confirm which count is intended.
279 if constexpr(imfma < num_dswrite_per_issue_a)
280 {
281 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
282 }
283 });
284 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
285 }
286 else
287 {
288 static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) {
289 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
290 if constexpr(imfma < num_dswrite_per_issue_b)
291 {
292 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
293 }
294 });
295 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
296 }
297 });
298
// Third section: slot index offset by the A and B buffer-load counts; only MFMAs and
// one VMEM read per slot, no DS writes.
// NOTE(review): the loop header (listing line 299) is missing here.
300 if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) < mfma_stages_more)
301 {
302 static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
303 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
304 });
305 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
306 }
307 else
308 {
309 static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
310 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
311 });
312 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
313 }
314 });
315
// Fourth section: slot index additionally offset by num_buffer_load_a_scale; same
// MFMA + VMEM read pattern as the third section.
// NOTE(review): the loop header (listing line 316) is missing here.
317 if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b +
318 num_buffer_load_a_scale) < mfma_stages_more)
319 {
320 static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
321 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
322 });
323 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
324 }
325 else
326 {
327 static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
328 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
329 });
330 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
331 }
332 });
333
334 // stage 2
// A side: one MFMA, then a DS-read group. A full group of ds_read_a_mfma_rate reads is
// requested while at least that many reads remain after this step; otherwise the
// group is the remainder left over from (num_dsread_a_mfma - 1) full groups.
// NOTE(review): the loop header (listing line 335) is missing here.
336 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
337 if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
338 ds_read_a_mfma_rate)
339 {
340 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
341 }
342 else
343 {
344 __builtin_amdgcn_sched_group_barrier(0x100,
345 num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
346 ds_read_a_mfma_rate,
347 0); // DS read
348 }
349 });
350
// B side: same pairing as above, using the B read count and rate.
// NOTE(review): the loop header (listing line 351) is missing here.
352 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
353 if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
354 ds_read_b_mfma_rate)
355 {
356 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
357 }
358 else
359 {
360 __builtin_amdgcn_sched_group_barrier(0x100,
361 num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
362 ds_read_b_mfma_rate,
363 0); // DS read
364 }
365 });
366 }
367
368 template <bool HasMainLoop,
369 TailNumber TailNum,
370 typename AGridDesc,
371 typename ABlockDesc,
372 typename ABlockTransfer,
373 typename AGridBuffer,
374 typename ABlockBuffer,
375 typename ABlockTransferStep,
376 typename BGridDesc,
377 typename BBlockDesc,
378 typename BBlockTransfer,
379 typename BGridBuffer,
380 typename BBlockBuffer,
381 typename BBlockTransferStep,
382 typename CThreadBuffer,
383 typename AScaleGridBuffer,
384 typename AScaleGridDesc,
385 typename AScaleThreadTransfer,
386 typename BScaleGridBuffer,
387 typename BScaleGridDesc,
388 typename BScaleThreadTransfer>
389 __device__ void Run(
390 // A
391 const AGridDesc& a_grid_desc,
392 const ABlockDesc& a_block_desc,
393 ABlockTransfer& a_blockwise_copy,
394 const AGridBuffer& a_grid_buf,
395 ABlockBuffer& a_block_buf,
396 const ABlockTransferStep& a_block_copy_step,
397 // B0/B1
398 const BGridDesc& b_grid_desc,
399 const BBlockDesc& b_block_desc,
400 BBlockTransfer& b_blockwise_copy,
401 BBlockTransfer& b_blockwise_copy_up,
402 const BGridBuffer& b_grid_buf,
403 const BGridBuffer& b_grid_buf_up,
404 BBlockBuffer& b_block_buf,
405 BBlockBuffer& b_block_buf_up,
406 const BBlockTransferStep& b_block_copy_step,
407 // C
408 CThreadBuffer& c_thread_buf,
409 CThreadBuffer& c_thread_buf_up,
410 // A scale
411 const AScaleGridDesc& a_scale_grid_desc,
412 AScaleThreadTransfer& a_scale_thread_copy,
413 const AScaleGridBuffer& a_scale_grid_buf,
414 // B0/B1 scale
415 const BScaleGridDesc& b_scale_grid_desc,
416 BScaleThreadTransfer& b_scale_thread_copy,
417 BScaleThreadTransfer& b_scale_thread_copy_up,
418 const BScaleGridBuffer& b_scale_grid_buf,
419 const BScaleGridBuffer& b_scale_grid_buf_up,
420 index_t num_loop) const
421 {
423 a_thread_desc_.GetElementSpaceSize());
425 b_thread_desc_.GetElementSpaceSize());
427 b_thread_desc_.GetElementSpaceSize());
428
430 a_scale_thread_desc.GetElementSpaceSize());
432 b_scale_thread_desc.GetElementSpaceSize());
434 b_scale_thread_desc.GetElementSpaceSize());
435
436 StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
437 StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
438 StaticallyIndexedArray<decltype(b_scale_thread_buf_up), Number<2>{}> b_scale_thread_bufs_up;
439
440 // Global prefetch 1
441 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
442 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
443 b_blockwise_copy_up.RunRead(b_grid_desc, b_grid_buf_up);
444
445 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
446 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
447 b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
448
449 // Prefetch a_scales
450 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
451 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
452 a_scale_thread_copy.Run(a_scale_grid_desc,
453 a_scale_grid_buf,
455 make_tuple(m0, k0, I0),
456 a_scale_thread_bufs(I0));
457
458 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
459 make_multi_index(0, I1, 0));
460 });
461 a_scale_thread_copy.MoveSrcSliceWindow(
462 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
463 });
464
465 // restore row id and advance to the next set of scales
466 a_scale_thread_copy.MoveSrcSliceWindow(
467 a_scale_grid_desc,
468 make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
469
470 // Prefetch b_scales
471 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
472 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
473 b_scale_thread_copy.Run(b_scale_grid_desc,
474 b_scale_grid_buf,
476 make_tuple(n0, k0, I0),
477 b_scale_thread_bufs(I0));
478
479 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
480 make_multi_index(0, I1, 0));
481 });
482 b_scale_thread_copy.MoveSrcSliceWindow(
483 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
484 });
485
486 // restore col id and advance to the next set of scales
487 // NWaves * NPerXDL * NRepeat == NPerBlock
488 b_scale_thread_copy.MoveSrcSliceWindow(
489 b_scale_grid_desc,
490 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
491
492 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
493 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
494 b_scale_thread_copy_up.Run(b_scale_grid_desc,
495 b_scale_grid_buf_up,
497 make_tuple(n0, k0, I0),
498 b_scale_thread_bufs_up(I0));
499
500 b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc,
501 make_multi_index(0, I1, 0));
502 });
503 b_scale_thread_copy_up.MoveSrcSliceWindow(
504 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
505 });
506
507 // restore col id and advance to the next set of scales
508 // NWaves * NPerXDL * NRepeat == NPerBlock
509 b_scale_thread_copy_up.MoveSrcSliceWindow(
510 b_scale_grid_desc,
511 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
512
513 // Local prefill 1
514 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
515 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
516 b_blockwise_copy_up.RunWrite(b_block_desc, b_block_buf_up);
517
518 // Global prefetch 2
519 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
520 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
521 b_blockwise_copy_up.RunRead(b_grid_desc, b_grid_buf_up);
522
523 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
524 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
525 b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
526
527 // Local prefetch 1
529 static_for<0, KRepeat, 1>{}([&](auto k) {
530 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
531 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
532 static_for<0, MRepeat, 1>{}([&](auto m0) {
533 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
534 [&](auto chunk) {
535 constexpr auto a_k_step_chunk =
536 k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
539 I0,
541 I0,
543 a_block_buf,
546 I0,
548 k,
550 a_thread_buf);
551 });
552 });
553 static_for<0, NRepeat, 1>{}([&](auto n0) {
554 // read block data in chunks to assemble correct thread vectors
555 static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
556 [&](auto chunk) {
557 constexpr auto b_k_step_chunk =
558 k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
561 I0,
563 I0,
565 b_block_buf,
568 I0,
570 k,
572 b_thread_buf);
573 });
574 });
575 static_for<0, NRepeat, 1>{}([&](auto n0) {
576 // read block data in chunks to assemble correct thread vectors
577 static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
578 [&](auto chunk) {
579 constexpr auto b_k_step_chunk =
580 k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
583 I0,
585 I0,
587 b_block_buf_up,
590 I0,
592 k,
594 b_thread_buf_up);
595 });
596 });
597 });
598
599 // Initialize C
600 c_thread_buf.Clear();
601 c_thread_buf_up.Clear();
602 __builtin_amdgcn_sched_barrier(0);
603
604 // main body
605 if constexpr(HasMainLoop)
606 {
607 // loop over k with the step KPerBlock
608 index_t i = 0;
609 do
610 {
611 auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) {
613
614 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
615 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
616
617 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
618 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
619
620 b_blockwise_copy_up.RunWrite(b_block_desc, b_block_buf_up);
621 b_blockwise_copy_up.RunRead(b_grid_desc, b_grid_buf_up);
622
623 // Prefetch a_scales
624 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
625 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
626 a_scale_thread_copy.Run(a_scale_grid_desc,
627 a_scale_grid_buf,
629 make_tuple(m0, k0, I0),
630 a_scale_thread_bufs(scale_mem_buf));
631
632 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
633 make_multi_index(0, I1, 0));
634 });
635 a_scale_thread_copy.MoveSrcSliceWindow(
636 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
637 });
638
639 // restore row id and advance to the next set of scales
640 a_scale_thread_copy.MoveSrcSliceWindow(
641 a_scale_grid_desc,
642 make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
643
644 // Prefetch b_scales
645 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
646 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
647 b_scale_thread_copy.Run(b_scale_grid_desc,
648 b_scale_grid_buf,
650 make_tuple(n0, k0, I0),
651 b_scale_thread_bufs(scale_mem_buf));
652
653 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
654 make_multi_index(0, I1, 0));
655 });
656 b_scale_thread_copy.MoveSrcSliceWindow(
657 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
658 });
659
660 // restore col id and advance to the next set of scales
661 // NWaves * NPerXDL * NRepeat == NPerBlock
662 b_scale_thread_copy.MoveSrcSliceWindow(
663 b_scale_grid_desc,
664 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
665
666 // Prefetch b_scales_up
667 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
668 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
669 b_scale_thread_copy_up.Run(b_scale_grid_desc,
670 b_scale_grid_buf_up,
672 make_tuple(n0, k0, I0),
673 b_scale_thread_bufs_up(scale_mem_buf));
674
675 b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc,
676 make_multi_index(0, I1, 0));
677 });
678 b_scale_thread_copy_up.MoveSrcSliceWindow(
679 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
680 });
681
682 // restore col id and advance to the next set of scales
683 // NWaves * NPerXDL * NRepeat == NPerBlock
684 b_scale_thread_copy_up.MoveSrcSliceWindow(
685 b_scale_grid_desc,
686 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
687
688 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
689 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
690 b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
691
692 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
693 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
694 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
695 constexpr index_t a_scale_offset =
696 a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
697 constexpr index_t b_scale_offset =
698 b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
699
700 static_assert(0 < ScalesPerXdlopsRunPerThread,
701 "Must have at least one scale per Xdlops "
702 "per Thread.");
703
705 a_scale_thread_vec;
707 b_scale_thread_vec;
709 b_scale_thread_vec_up;
710
711 // Pack scale_thread_buf into scale_thread_vec
713 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
714 a_scale_thread_bufs(
715 scale_comp_buf)[Number<a_scale_offset + s>{}];
716 });
717
719 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
720 b_scale_thread_bufs(
721 scale_comp_buf)[Number<b_scale_offset + s>{}];
722 });
723
725 b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
726 b_scale_thread_bufs_up(
727 scale_comp_buf)[Number<b_scale_offset + s>{}];
728 });
729
730 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
731 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
732 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
733 constexpr auto kxdl = ikxdl + k0 * KXdlPack;
734
737 vector_type<ComputeTypeB, KPack> b_thread_vec_up;
738
739 static_for<0, KPack, 1>{}([&](auto ik) {
740 a_thread_vec.template AsType<ComputeTypeA>()(
741 ik) = a_thread_buf
742 [Number<a_thread_desc_.CalculateOffset(
743 make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
744 b_thread_vec.template AsType<ComputeTypeB>()(
745 ik) = b_thread_buf
746 [Number<b_thread_desc_.CalculateOffset(
747 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
748 b_thread_vec_up.template AsType<ComputeTypeB>()(
749 ik) = b_thread_buf_up
750 [Number<b_thread_desc_.CalculateOffset(
751 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
752 });
753
754 using mfma_input_type_a =
755 typename vector_type<ComputeTypeA,
756 xdlops_gemm.K1PerXdlops /
757 APackedSize>::type;
758
759 using mfma_input_type_b =
760 typename vector_type<ComputeTypeB,
761 xdlops_gemm.K1PerXdlops /
762 BPackedSize>::type;
763
764 using mfma_scale_input_type_a =
765 typename vector_type<AScaleDataType,
767 using mfma_scale_input_type_b =
768 typename vector_type<BScaleDataType,
770
771 constexpr index_t c_offset =
772 c_thread_desc_.CalculateOffset(
773 make_tuple(m0, n0, imxdl, inxdl, 0));
774
775 // MFMA accumulation
776 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
777 ikxdl * NXdlPack + inxdl>(
778 a_thread_vec.template AsType<mfma_input_type_a>(),
779 a_scale_thread_vec
780 .template AsType<mfma_scale_input_type_a>(),
781 b_thread_vec.template AsType<mfma_input_type_b>(),
782 b_scale_thread_vec
783 .template AsType<mfma_scale_input_type_b>(),
784 c_thread_buf.GetVectorTypeReference(
786
787 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
788 ikxdl * NXdlPack + inxdl>(
789 a_thread_vec.template AsType<mfma_input_type_a>(),
790 a_scale_thread_vec
791 .template AsType<mfma_scale_input_type_a>(),
792 b_thread_vec_up
793 .template AsType<mfma_input_type_b>(),
794 b_scale_thread_vec_up
795 .template AsType<mfma_scale_input_type_b>(),
796 c_thread_buf_up.GetVectorTypeReference(
798 });
799 });
800 });
801 });
802 });
803 });
804
805 // k indexes mapping to threads for 32x32x64:
806 // t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc.
807 // t32: |16 --> 31 48 --> 63 | 80 --> 95 112 --> 127 | etc.
808 // k = 0 k = 1
809
810 // k indexes mapping to threads for 16x16x128:
811 // t0 : |0 --> 15 64 --> 79 | 128 --> 143 192 --> 207| etc.
812 // t16: |16 --> 31 80 --> 95 | 144 --> 159 208 --> 223| etc.
813 // t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc.
814 // t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc.
815 // k = 0 k = 1
817 static_for<0, KRepeat, 1>{}([&](auto k) {
818 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
819 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
820 static_for<0, MRepeat, 1>{}([&](auto m0) {
821 static_for<0,
822 xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk),
823 1>{}([&](auto chunk) {
824 constexpr auto a_k_step_chunk =
825 k_step +
826 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
829 I0,
831 I0,
833 a_block_buf,
836 I0,
838 k,
840 a_thread_buf);
841 });
842 });
843 static_for<0, NRepeat, 1>{}([&](auto n0) {
844 // read block data in chunks to assemble correct thread vectors
845 static_for<0,
846 xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk),
847 1>{}([&](auto chunk) {
848 constexpr auto b_k_step_chunk =
849 k_step +
850 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
853 I0,
855 I0,
857 b_block_buf,
860 I0,
862 k,
864 b_thread_buf);
865 });
866 });
867 static_for<0, NRepeat, 1>{}([&](auto n0) {
868 // read block data in chunks to assemble correct thread vectors
869 static_for<0,
870 xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk),
871 1>{}([&](auto chunk) {
872 constexpr auto b_k_step_chunk =
873 k_step +
874 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
877 I0,
879 I0,
881 b_block_buf_up,
884 I0,
886 k,
888 b_thread_buf_up);
889 });
890 });
891 });
892
894 __builtin_amdgcn_sched_barrier(0);
895 };
896
897 LoopFunc(I0, I1);
898 LoopFunc(I1, I0);
899
900 i += 2;
901 } while(i < (num_loop - 2));
902 }
903
904 // tail
905 if constexpr(TailNum == TailNumber::Even)
906 {
907 // Prefetch a_scales
908 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
909 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
910 a_scale_thread_copy.Run(a_scale_grid_desc,
911 a_scale_grid_buf,
913 make_tuple(m0, k0, I0),
914 a_scale_thread_bufs(I1));
915
916 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
917 make_multi_index(0, I1, 0));
918 });
919 a_scale_thread_copy.MoveSrcSliceWindow(
920 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
921 });
922
923 // Prefetch b_scales
924 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
925 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
926 b_scale_thread_copy.Run(b_scale_grid_desc,
927 b_scale_grid_buf,
929 make_tuple(n0, k0, I0),
930 b_scale_thread_bufs(I1));
931
932 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
933 make_multi_index(0, I1, 0));
934 });
935 b_scale_thread_copy.MoveSrcSliceWindow(
936 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
937 });
938
939 // Prefetch b_scales_up
940 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
941 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
942 b_scale_thread_copy_up.Run(b_scale_grid_desc,
943 b_scale_grid_buf_up,
945 make_tuple(n0, k0, I0),
946 b_scale_thread_bufs_up(I1));
947
948 b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc,
949 make_multi_index(0, I1, 0));
950 });
951 b_scale_thread_copy_up.MoveSrcSliceWindow(
952 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
953 });
954
956 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
957 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
958 b_blockwise_copy_up.RunWrite(b_block_desc, b_block_buf_up);
959
960 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
961 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
962 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
963 constexpr index_t a_scale_offset =
964 a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
965 constexpr index_t b_scale_offset =
966 b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
967
968 static_assert(0 < ScalesPerXdlopsRunPerThread,
969 "Must have at least one scale per Xdlops "
970 "per Thread.");
971
975
976 // Pack scale_thread_buf into scale_thread_vec
978 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
979 a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
980 });
981
983 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
984 b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
985 });
986
988 b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
989 b_scale_thread_bufs_up(I0)[Number<b_scale_offset + s>{}];
990 });
991
992 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
993 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
994 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
995 constexpr auto kxdl = ikxdl + k0 * KXdlPack;
996
999 vector_type<ComputeTypeB, KPack> b_thread_vec_up;
1000
1001 static_for<0, KPack, 1>{}([&](auto ik) {
1002 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1003 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1004 make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
1005 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1006 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
1007 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
1008 b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
1009 b_thread_buf_up[Number<b_thread_desc_.CalculateOffset(
1010 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
1011 });
1012
1013 using mfma_input_type_a =
1014 typename vector_type<ComputeTypeA,
1015 xdlops_gemm.K1PerXdlops /
1016 APackedSize>::type;
1017
1018 using mfma_input_type_b =
1019 typename vector_type<ComputeTypeB,
1020 xdlops_gemm.K1PerXdlops /
1021 BPackedSize>::type;
1022
1023 using mfma_scale_input_type_a =
1024 typename vector_type<AScaleDataType,
1026 using mfma_scale_input_type_b =
1027 typename vector_type<BScaleDataType,
1029
1030 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
1031 make_tuple(m0, n0, imxdl, inxdl, 0));
1032
1033 // MFMA accumulation
1034 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
1035 ikxdl * NXdlPack + inxdl>(
1036 a_thread_vec.template AsType<mfma_input_type_a>(),
1037 a_scale_thread_vec
1038 .template AsType<mfma_scale_input_type_a>(),
1039 b_thread_vec.template AsType<mfma_input_type_b>(),
1040 b_scale_thread_vec
1041 .template AsType<mfma_scale_input_type_b>(),
1042 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1043
1044 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
1045 ikxdl * NXdlPack + inxdl>(
1046 a_thread_vec.template AsType<mfma_input_type_a>(),
1047 a_scale_thread_vec
1048 .template AsType<mfma_scale_input_type_a>(),
1049 b_thread_vec_up.template AsType<mfma_input_type_b>(),
1050 b_scale_thread_vec_up
1051 .template AsType<mfma_scale_input_type_b>(),
1052 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
1053 });
1054 });
1055 });
1056 });
1057 });
1058 });
1059
1061
1062 static_for<0, KRepeat, 1>{}([&](auto k) {
1063 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
1064 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
1065 static_for<0, MRepeat, 1>{}([&](auto m0) {
1066 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
1067 [&](auto chunk) {
1068 constexpr auto a_k_step_chunk =
1069 k_step +
1070 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
1073 I0,
1075 I0,
1077 a_block_buf,
1080 I0,
1082 k,
1084 a_thread_buf);
1085 });
1086 });
1087 static_for<0, NRepeat, 1>{}([&](auto n0) {
1088 // read block data in chunks to assemble correct thread vectors
1089 static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
1090 [&](auto chunk) {
1091 constexpr auto b_k_step_chunk =
1092 k_step +
1093 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
1096 I0,
1098 I0,
1100 b_block_buf,
1103 I0,
1105 k,
1107 b_thread_buf);
1108 });
1109 });
1110 static_for<0, NRepeat, 1>{}([&](auto n0) {
1111 // read block data in chunks to assemble correct thread vectors
1112 static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
1113 [&](auto chunk) {
1114 constexpr auto b_k_step_chunk =
1115 k_step +
1116 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
1119 I0,
1121 I0,
1123 b_block_buf_up,
1126 I0,
1128 k,
1130 b_thread_buf_up);
1131 });
1132 });
1133 });
1134
1135 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
1136 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
1137 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
1138 constexpr index_t a_scale_offset =
1139 a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
1140 constexpr index_t b_scale_offset =
1141 b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
1142
1143 static_assert(0 < ScalesPerXdlopsRunPerThread,
1144 "Must have at least one scale per Xdlops "
1145 "per Thread.");
1146
1150
1151 // Pack scale_thread_buf into scale_thread_vec
1153 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
1154 a_scale_thread_bufs(I1)[Number<a_scale_offset + s>{}];
1155 });
1156
1158 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
1159 b_scale_thread_bufs(I1)[Number<b_scale_offset + s>{}];
1160 });
1161
1163 b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
1164 b_scale_thread_bufs_up(I1)[Number<b_scale_offset + s>{}];
1165 });
1166
1167 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
1168 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
1169 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
1170 constexpr auto kxdl = ikxdl + k0 * KXdlPack;
1171
1174 vector_type<ComputeTypeB, KPack> b_thread_vec_up;
1175
1176 static_for<0, KPack, 1>{}([&](auto ik) {
1177 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1178 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1179 make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
1180 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1181 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
1182 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
1183 b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
1184 b_thread_buf_up[Number<b_thread_desc_.CalculateOffset(
1185 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
1186 });
1187
1188 using mfma_input_type_a =
1189 typename vector_type<ComputeTypeA,
1190 xdlops_gemm.K1PerXdlops /
1191 APackedSize>::type;
1192
1193 using mfma_input_type_b =
1194 typename vector_type<ComputeTypeB,
1195 xdlops_gemm.K1PerXdlops /
1196 BPackedSize>::type;
1197
1198 using mfma_scale_input_type_a =
1199 typename vector_type<AScaleDataType,
1201 using mfma_scale_input_type_b =
1202 typename vector_type<BScaleDataType,
1204
1205 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
1206 make_tuple(m0, n0, imxdl, inxdl, 0));
1207
1208 // MFMA accumulation
1209 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
1210 ikxdl * NXdlPack + inxdl>(
1211 a_thread_vec.template AsType<mfma_input_type_a>(),
1212 a_scale_thread_vec
1213 .template AsType<mfma_scale_input_type_a>(),
1214 b_thread_vec.template AsType<mfma_input_type_b>(),
1215 b_scale_thread_vec
1216 .template AsType<mfma_scale_input_type_b>(),
1217 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1218
1219 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
1220 ikxdl * NXdlPack + inxdl>(
1221 a_thread_vec.template AsType<mfma_input_type_a>(),
1222 a_scale_thread_vec
1223 .template AsType<mfma_scale_input_type_a>(),
1224 b_thread_vec_up.template AsType<mfma_input_type_b>(),
1225 b_scale_thread_vec_up
1226 .template AsType<mfma_scale_input_type_b>(),
1227 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
1228 });
1229 });
1230 });
1231 });
1232 });
1233 });
1234 }
1235 else if constexpr(TailNum == TailNumber::Odd)
1236 {
1237 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
1238 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
1239 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
1240 constexpr index_t a_scale_offset =
1241 a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
1242 constexpr index_t b_scale_offset =
1243 b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
1244
1245 static_assert(0 < ScalesPerXdlopsRunPerThread,
1246 "Must have at least one scale per Xdlops "
1247 "per Thread.");
1248
1252
1253 // Pack scale_thread_buf into scale_thread_vec
1255 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
1256 a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
1257 });
1258
1260 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
1261 b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
1262 });
1263
1265 b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
1266 b_scale_thread_bufs_up(I0)[Number<b_scale_offset + s>{}];
1267 });
1268
1269 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
1270 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
1271 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
1272 constexpr auto kxdl = ikxdl + k0 * KXdlPack;
1273
1276 vector_type<ComputeTypeB, KPack> b_thread_vec_up;
1277
1278 static_for<0, KPack, 1>{}([&](auto ik) {
1279 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1280 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1281 make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
1282 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1283 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
1284 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
1285 b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
1286 b_thread_buf_up[Number<b_thread_desc_.CalculateOffset(
1287 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
1288 });
1289
1290 using mfma_input_type_a =
1291 typename vector_type<ComputeTypeA,
1292 xdlops_gemm.K1PerXdlops /
1293 APackedSize>::type;
1294
1295 using mfma_input_type_b =
1296 typename vector_type<ComputeTypeB,
1297 xdlops_gemm.K1PerXdlops /
1298 BPackedSize>::type;
1299
1300 using mfma_scale_input_type_a =
1301 typename vector_type<AScaleDataType,
1303 using mfma_scale_input_type_b =
1304 typename vector_type<BScaleDataType,
1306
1307 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
1308 make_tuple(m0, n0, imxdl, inxdl, 0));
1309
1310 // MFMA accumulation
1311 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
1312 ikxdl * NXdlPack + inxdl>(
1313 a_thread_vec.template AsType<mfma_input_type_a>(),
1314 a_scale_thread_vec
1315 .template AsType<mfma_scale_input_type_a>(),
1316 b_thread_vec.template AsType<mfma_input_type_b>(),
1317 b_scale_thread_vec
1318 .template AsType<mfma_scale_input_type_b>(),
1319 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1320
1321 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
1322 ikxdl * NXdlPack + inxdl>(
1323 a_thread_vec.template AsType<mfma_input_type_a>(),
1324 a_scale_thread_vec
1325 .template AsType<mfma_scale_input_type_a>(),
1326 b_thread_vec_up.template AsType<mfma_input_type_b>(),
1327 b_scale_thread_vec_up
1328 .template AsType<mfma_scale_input_type_b>(),
1329 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
1330 });
1331 });
1332 });
1333 });
1334 });
1335 });
1336 }
1337 }
1338
1339 // TODO: make this field protected once a_scale_thread_copy_ is moved into this class.
1340 // Packed naive tensor descriptor for the A scales. Run() calls CalculateOffset(make_tuple(m0, k0, I0)) on it and uses the result as the base index into a_scale_thread_bufs. NOTE(review): this listing drops source lines 1342 and 1344, so only the KRepeat / KXdlPack extent is visible here - confirm the remaining extents against the real header.
1341 static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
1343 Number<KRepeat / KXdlPack>{},
1345
1346 // TODO: make this field protected once b_scale_thread_copy_ is moved into this class.
1347 // Packed naive tensor descriptor for the B scales. Run() calls CalculateOffset(make_tuple(n0, k0, I0)) on it and uses the same offset as the base index into both b_scale_thread_bufs and b_scale_thread_bufs_up. NOTE(review): this listing drops source lines 1349 and 1351, so only the KRepeat / KXdlPack extent is visible here - confirm the remaining extents against the real header.
1348 static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
1350 Number<KRepeat / KXdlPack>{},
1352
1353 protected: // Bring members of Base into this class's scope so Run() can name them unqualified.
1354 using Base::a_thread_copy_;
1355 using Base::a_thread_desc_; // Run() uses this to compute element offsets into a_thread_buf.
1356 using Base::b_thread_copy_;
1357 using Base::b_thread_desc_; // Run() uses this for offsets into both b_thread_buf and b_thread_buf_up.
1358 using Base::c_thread_desc_; // Run() uses this to compute c_offset, shared by c_thread_buf and c_thread_buf_up.
1359};
1360
1361} // namespace ck
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:33
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops,(packed_size_v< ComputeTypeA > > 1||packed_size_v< ComputeTypeB > > 1)> HotLoopInstList
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:88
__host__ __device__ BlockwiseGemmXdlops_mx_pipeline_base(Tuple5 a_origin=CalculateAThreadOriginDataIndex(), Tuple5 b_origin=CalculateBThreadOriginDataIndex())
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:204
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, BBlockTransfer &b_blockwise_copy_up, const BGridBuffer &b_grid_buf, const BGridBuffer &b_grid_buf_up, BBlockBuffer &b_block_buf, BBlockBuffer &b_block_buf_up, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, CThreadBuffer &c_thread_buf_up, const AScaleGridDesc &a_scale_grid_desc, AScaleThreadTransfer &a_scale_thread_copy, const AScaleGridBuffer &a_scale_grid_buf, const BScaleGridDesc &b_scale_grid_desc, BScaleThreadTransfer &b_scale_thread_copy, BScaleThreadTransfer &b_scale_thread_copy_up, const BScaleGridBuffer &b_scale_grid_buf, const BScaleGridBuffer &b_scale_grid_buf_up, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp:389
BlockwiseGemmXdlops_mx_pipeline_base< ThreadBlockSize, ADataType, BDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp:102
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_bufs, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, BBlockTransfer &b_blockwise_copy_up, const BGridBuffer &b_grid_buf, const BGridBuffer &b_grid_buf_up, BBlockBuffer &b_block_bufs, BBlockBuffer &b_block_bufs_up, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, CThreadBuffer &c_thread_buf_up, const AScaleGridDesc &a_scale_grid_desc, AScaleThreadTransfer &a_scale_thread_copy, const AScaleGridBuffer &a_scale_grid_buf, const BScaleGridDesc &b_scale_grid_desc, BScaleThreadTransfer &b_scale_thread_copy, BScaleThreadTransfer &b_scale_thread_copy_up, const BScaleGridBuffer &b_scale_grid_buf, const BScaleGridBuffer &b_scale_grid_buf_up, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp:367
Definition blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp:38
Unsigned representation of a conventional biased Float32 exponent.
Definition utility/e8m0.hpp:26
Definition functional2.hpp:33
Definition dtype_vector.hpp:10