blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp Source File

blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp Source File
blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Compute optimized pipeline
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 1
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MPerXDL,
32 index_t NPerXDL,
33 index_t MRepeat,
34 index_t NRepeat,
35 index_t KPacks>
39
40template <index_t BlockSize,
41 typename ADataType,
42 typename BDataType,
43 typename ComputeDataType,
44 typename AccDataType,
45 typename ATileDesc,
46 typename BTileDesc,
47 typename AMmaTileDesc,
48 typename BMmaTileDesc,
49 index_t ABlockTransferSrcScalarPerVector,
50 index_t BBlockTransferSrcScalarPerVector,
51 index_t MPerBlock,
52 index_t NPerBlock,
53 index_t KPerBlock,
54 index_t MPerXDL,
55 index_t NPerXDL,
56 index_t MRepeat,
57 index_t NRepeat,
58 index_t KPack
59 // ,bool TransposeC //disable transposec right now...
60 >
62 BlockSize,
63 ADataType,
64 BDataType,
65 ComputeDataType,
66 AccDataType,
67 ATileDesc,
68 BTileDesc,
69 AMmaTileDesc,
70 BMmaTileDesc,
71 ABlockTransferSrcScalarPerVector,
72 BBlockTransferSrcScalarPerVector,
73 MPerBlock,
74 NPerBlock,
75 KPerBlock,
76 MPerXDL,
77 NPerXDL,
78 MRepeat,
79 NRepeat,
80 KPack>
82 ADataType,
83 BDataType,
84 ComputeDataType,
85 AccDataType,
86 ATileDesc,
87 BTileDesc,
88 AMmaTileDesc,
89 BMmaTileDesc,
90 ABlockTransferSrcScalarPerVector,
91 BBlockTransferSrcScalarPerVector,
92 MPerBlock,
93 NPerBlock,
94 KPerBlock,
95 MPerXDL,
96 NPerXDL,
97 MRepeat,
98 NRepeat,
99 KPack>
100
101{
103 ADataType,
104 BDataType,
105 ComputeDataType,
106 AccDataType,
107 ATileDesc,
108 BTileDesc,
109 AMmaTileDesc,
110 BMmaTileDesc,
111 ABlockTransferSrcScalarPerVector,
112 BBlockTransferSrcScalarPerVector,
113 MPerBlock,
114 NPerBlock,
115 KPerBlock,
116 MPerXDL,
117 NPerXDL,
118 MRepeat,
119 NRepeat,
120 KPack>;
121 using Base::A_K1;
122 using Base::B_K1;
123 using Base::I0;
124 using Base::I1;
125 using Base::KGroup;
126 using Base::KRepeat;
127 using Base::xdlops_gemm;
128 using typename Base::HotLoopInstList;
129
142
143 using Base::AMmaKStride;
144 using Base::BMmaKStride;
146 using Base::MWaves;
147 using Base::WaveSize;
148
149 static constexpr index_t PrefetchStages = 2;
150 static constexpr index_t PrefillStages = 1;
151 static constexpr index_t GlobalBufferNum = 2;
152
153 template <typename TileDesc_M0_M1_M2_K>
154 __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
155 {
156 constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
157 constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
158 constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
159 constexpr index_t K2 = KPack / KGroup;
160 constexpr index_t K1 = WaveSize / NPerXDL;
161 constexpr index_t K0 = KRepeat * KGroup;
162
164 TileDesc_M0_M1_M2_K{},
172 }
173
174 static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
176
177 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
178 {
179 return num_loop > PrefetchStages;
180 }
181
182 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
183 {
184 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
185 }
186
187 __device__ static constexpr auto HotLoopScheduler()
188 {
189 constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
190 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
191 constexpr auto num_buffer_load_inst_b =
193 constexpr auto mfma_interleave = MPerXDL == 32 ? 1 : 2;
194 // B global
196 ignore = i;
197 if constexpr(MPerBlock >= 128 && NPerBlock >= 64)
198 {
199 __builtin_amdgcn_sched_group_barrier(0x008, 2 * mfma_interleave, 0);
200 }
201 else
202 {
203 __builtin_amdgcn_sched_group_barrier(0x008, mfma_interleave, 0);
204 }
205 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
206 // if constexpr(i.value < num_buffer_load_inst_a) {
207 // __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
208 // __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
209 // __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
210 // __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
211 // }
212 });
213
214 // A global
216 ignore = i;
217 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
218 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
219 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
220 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
221 });
222
223 // A local
224 static_for<0, MPerXDL == 32 ? num_ds_read_inst_a / 2 : num_ds_read_inst_a, 1>{}(
225 [&](auto i) {
226 ignore = i;
227 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
228 __builtin_amdgcn_sched_group_barrier(0x100, MPerXDL == 32 ? 2 : 1, 0); // DS read
229 });
230 }
231
232 template <bool HasMainLoop,
233 TailNumber TailNum,
234 typename AGridDesc,
235 typename ABlockDesc,
236 typename ABlockTransfer,
237 typename AGridBuffer,
238 typename ABlockBuffer,
239 typename ABlockTransferStep,
240 typename BGridDesc,
241 typename BBlockTransfer,
242 typename BGridBuffer,
243 typename BBlockBuffer,
244 typename BBlockTransferStep,
245 typename CThreadBuffer>
246 __device__ void Run(const AGridDesc& a_grid_desc,
247 const ABlockDesc& a_block_desc,
248 ABlockTransfer& a_blockwise_copy,
249 const AGridBuffer& a_grid_buf,
250 ABlockBuffer& a_block_buf,
251 const ABlockTransferStep& a_block_copy_step,
252 const BGridDesc& b_grid_desc,
253 BBlockTransfer& b_blockwise_copy,
254 BBlockTransfer& b_blockwise_copy_up,
255 const BGridBuffer& b_grid_buf,
256 const BGridBuffer& b_grid_buf_up,
257 BBlockBuffer& b_block_buf,
258 const BBlockTransferStep& b_block_copy_step,
259 CThreadBuffer& c_thread_buf,
260 CThreadBuffer& c_thread_buf_up,
261 index_t num_loop) const
262 {
263 ignore = b_block_buf;
264 __builtin_amdgcn_sched_barrier(0);
266 a_thread_desc_.GetElementSpaceSize());
268 b_thread_desc_.GetElementSpaceSize());
269
270 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
271 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs_up;
272 constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
273
274 // Global prefetch A1 B1
275 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
276 b_blockwise_copy.Run(b_grid_desc,
277 b_grid_buf,
279 b_block_origin_idx,
280 b_thread_bufs(I0));
281 b_blockwise_copy_up.Run(b_grid_desc,
282 b_grid_buf_up,
284 b_block_origin_idx,
285 b_thread_bufs_up(I0));
286
287 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
288 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
289 b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
290 __builtin_amdgcn_sched_barrier(0);
291
292 // // Local prefill A1
293 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
294
295 // // Global prefetch A2
296 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
297 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
298
299 // Local prefetch A1
301 static_for<0, MRepeat, 1>{}([&](auto m0) {
302 static_for<0, KRepeat, 1>{}([&](auto k0) {
303 static_for<0, KGroup, 1>{}([&](auto kg0) {
306 a_block_buf,
308 make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
309 a_thread_buf);
310 });
311 });
312 });
313
314 // Initialize C
315 c_thread_buf.Clear();
316 c_thread_buf_up.Clear();
317
318 __builtin_amdgcn_sched_barrier(0);
319
320 // main body
321 if constexpr(HasMainLoop)
322 {
323 index_t i = 0;
324 do
325 {
326 auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
327 b_blockwise_copy.Run(b_grid_desc,
328 b_grid_buf,
330 b_block_origin_idx,
331 b_thread_bufs(local_read_buf));
332 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
333
334 b_blockwise_copy_up.Run(b_grid_desc,
335 b_grid_buf_up,
337 b_block_origin_idx,
338 b_thread_bufs_up(local_read_buf));
339 b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
341 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
342
343 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
344 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
345 static_for<0, MRepeat, 1>{}([&](auto m0) {
346 static_for<0, NRepeat, 1>{}([&](auto n0) {
347 static_for<0, KRepeat, 1>{}([&](auto k0) {
351
352 static_for<0, KPack, 1>{}([&](auto ik) {
353 a_thread_vec.template AsType<ComputeDataType>()(ik) =
354 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
355 make_tuple(m0, I0, I0, k0, I0, ik))>{}];
356 b_thread_vec.template AsType<ComputeDataType>()(ik) =
357 b_thread_bufs[mfma_reg_buf]
358 [Number<b_thread_desc_.CalculateOffset(
359 make_tuple(n0, I0, k0, ik))>{}];
360 b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
361 b_thread_bufs_up[mfma_reg_buf]
362 [Number<b_thread_desc_.CalculateOffset(
363 make_tuple(n0, I0, k0, ik))>{}];
364 });
365 using mfma_input_type =
366 typename vector_type<ComputeDataType,
367 xdlops_gemm.K1PerXdlops>::type;
368
369 constexpr index_t c_offset =
370 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
371
372 xdlops_gemm.Run(
373 a_thread_vec.template AsType<mfma_input_type>(),
374 b_thread_vec.template AsType<mfma_input_type>(),
375 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
376
377 xdlops_gemm.Run(
378 a_thread_vec.template AsType<mfma_input_type>(),
379 b_thread_vec_up.template AsType<mfma_input_type>(),
380 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
381 });
382 });
383 });
384
386
387 static_for<0, MRepeat, 1>{}([&](auto m0) {
388 static_for<0, KRepeat, 1>{}([&](auto k0) {
389 static_for<0, KGroup, 1>{}([&](auto kg0) {
390 a_thread_copy_.Run(
393 a_block_buf,
395 make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
396 a_thread_buf);
397 });
398 });
399 });
400
402 __builtin_amdgcn_sched_barrier(0);
403 };
404
405 LoopFunc(I0, I1);
406 LoopFunc(I1, I0);
407
408 i += 2;
409 } while(i < (num_loop - 2));
410 }
411 // tail
412 if constexpr(TailNum == TailNumber::Even)
413 {
414 b_blockwise_copy.Run(b_grid_desc,
415 b_grid_buf,
417 b_block_origin_idx,
418 b_thread_bufs(I1));
419
420 b_blockwise_copy_up.Run(b_grid_desc,
421 b_grid_buf_up,
423 b_block_origin_idx,
424 b_thread_bufs_up(I1));
426 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
427
428 static_for<0, MRepeat, 1>{}([&](auto m0) {
429 static_for<0, NRepeat, 1>{}([&](auto n0) {
430 static_for<0, KRepeat, 1>{}([&](auto k0) {
434
435 static_for<0, KPack, 1>{}([&](auto ik) {
436 a_thread_vec.template AsType<ComputeDataType>()(ik) =
437 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
438 make_tuple(m0, I0, I0, k0, I0, ik))>{}];
439 b_thread_vec.template AsType<ComputeDataType>()(ik) =
440 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
441 make_tuple(n0, I0, k0, ik))>{}];
442 b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
443 b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
444 make_tuple(n0, I0, k0, ik))>{}];
445 });
446
447 using mfma_input_type =
448 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
449
450 constexpr index_t c_offset =
451 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
452
453 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
454 b_thread_vec.template AsType<mfma_input_type>(),
455 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
456
457 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
458 b_thread_vec_up.template AsType<mfma_input_type>(),
459 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
460 });
461 });
462 });
463
465
466 static_for<0, MRepeat, 1>{}([&](auto m0) {
467 static_for<0, KRepeat, 1>{}([&](auto k0) {
468 static_for<0, KGroup, 1>{}([&](auto kg0) {
469 a_thread_copy_.Run(
472 a_block_buf,
474 make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
475 a_thread_buf);
476 });
477 });
478 });
479
480 __builtin_amdgcn_sched_barrier(0);
481
482 static_for<0, MRepeat, 1>{}([&](auto m0) {
483 static_for<0, NRepeat, 1>{}([&](auto n0) {
484 static_for<0, KRepeat, 1>{}([&](auto k0) {
488
489 static_for<0, KPack, 1>{}([&](auto ik) {
490 a_thread_vec.template AsType<ComputeDataType>()(ik) =
491 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
492 make_tuple(m0, I0, I0, k0, I0, ik))>{}];
493 b_thread_vec.template AsType<ComputeDataType>()(ik) =
494 b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
495 make_tuple(n0, I0, k0, ik))>{}];
496 b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
497 b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
498 make_tuple(n0, I0, k0, ik))>{}];
499 });
500
501 using mfma_input_type =
502 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
503
504 constexpr index_t c_offset =
505 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
506
507 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
508 b_thread_vec.template AsType<mfma_input_type>(),
509 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
510 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
511 b_thread_vec_up.template AsType<mfma_input_type>(),
512 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
513 });
514 });
515 });
516 // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
517 // latency
518 // __builtin_amdgcn_sched_barrier(0);
519 }
520 else if constexpr(TailNum == TailNumber::Odd)
521 {
522 static_for<0, MRepeat, 1>{}([&](auto m0) {
523 static_for<0, NRepeat, 1>{}([&](auto n0) {
524 static_for<0, KRepeat, 1>{}([&](auto k0) {
528
529 static_for<0, KPack, 1>{}([&](auto ik) {
530 a_thread_vec.template AsType<ComputeDataType>()(ik) =
531 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
532 make_tuple(m0, I0, I0, k0, I0, ik))>{}];
533 b_thread_vec.template AsType<ComputeDataType>()(ik) =
534 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
535 make_tuple(n0, I0, k0, ik))>{}];
536 b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
537 b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
538 make_tuple(n0, I0, k0, ik))>{}];
539 });
540
541 using mfma_input_type =
542 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
543
544 constexpr index_t c_offset =
545 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
546
547 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
548 b_thread_vec.template AsType<mfma_input_type>(),
549 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
550 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
551 b_thread_vec_up.template AsType<mfma_input_type>(),
552 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
553 });
554 });
555 });
556 }
557 }
558
559 protected:
560 // MRepeat MWave MLane KRepeat KLane KPack
561 // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack
564
566 ComputeDataType,
568 decltype(a_thread_desc_),
569 Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
571 5,
572 A_K1,
573 A_K1>;
574
576
579
580 static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
581};
582
583} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmXdlops_pipeline_base.
Definition blockwise_gemm_pipeline_xdlops_base.hpp:222
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:280
static constexpr index_t MWaves
Definition blockwise_gemm_pipeline_xdlops_base.hpp:44
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:239
static constexpr auto c_thread_desc_
Definition blockwise_gemm_pipeline_xdlops_base.hpp:378
static constexpr auto xdlops_gemm
Definition blockwise_gemm_pipeline_xdlops_base.hpp:54
static constexpr index_t KGroup
Definition blockwise_gemm_pipeline_xdlops_base.hpp:67
static constexpr auto I1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:37
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:266
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:294
static constexpr index_t AMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:60
static __device__ auto CalculateAThreadOriginDataIndex6D()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:136
static constexpr index_t WaveSize
Definition blockwise_gemm_pipeline_xdlops_base.hpp:46
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:253
static constexpr index_t B_K1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:51
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops > HotLoopInstList
Definition blockwise_gemm_pipeline_xdlops_base.hpp:82
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:111
static constexpr auto I0
Definition blockwise_gemm_pipeline_xdlops_base.hpp:36
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:160
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:189
static constexpr index_t KRepeat
Definition blockwise_gemm_pipeline_xdlops_base.hpp:64
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:359
static constexpr index_t A_K1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:50
static constexpr index_t BMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:61
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:341
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:307
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:324
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, BBlockTransfer &b_blockwise_copy, BBlockTransfer &b_blockwise_copy_up, const BGridBuffer &b_grid_buf, const BGridBuffer &b_grid_buf_up, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, CThreadBuffer &c_thread_buf_up, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp:246
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp:102
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeDataType, decltype(a_block_desc_m0_m1_m2_k0_k1_k2), decltype(a_thread_desc_), Sequence< 1, 1, 1, 1, 1, KPack/KGroup >, Sequence< 0, 1, 2, 3, 4, 5 >, 5, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp:565
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp:37
Definition utility/sequence.hpp:43
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition functional2.hpp:33
Definition dtype_vector.hpp:10