9#if defined(__gfx942__) || defined(__gfx950__)
14template <index_t MPerWave, index_t NPerWave>
20 template <
class FloatC>
21 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
23 reg_c.template AsType<float32_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
24 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<0>{}], 1, 0, 0);
25 reg_c.template AsType<float32_t>()(
Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
26 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<1>{}], 1, 1, 0);
33 template <
class FloatC>
34 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
36 reg_c.template AsType<float32_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
37 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<0>{}], 1, 0, 0);
41template <index_t MPerWave, index_t NPerWave>
47 template <
class FloatC>
48 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
50 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x2f32(
51 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
55template <index_t MPerWave, index_t NPerWave>
61 template <
class FloatC>
62 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
64 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f32(
65 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
69template <index_t MPerWave, index_t NPerWave>
75 template <
class FloatC>
76 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
78 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x1f32(
79 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 2, 0, 0);
83template <index_t MPerWave, index_t NPerWave>
89 template <
class FloatC>
90 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
92 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
93 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 4, 0, 0);
100 template <
class FloatC>
101 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
103 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
104 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 4, 0, 0);
105 reg_c.template AsType<float4_t>()(
Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
106 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<1>{}], 4, 1, 0);
111template <index_t MPerWave, index_t NPerWave>
117 template <
class FloatC>
120 reg_c.template AsType<float32_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
121 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<0>{}], 1, 0, 0);
122 reg_c.template AsType<float32_t>()(
Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
123 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<1>{}], 1, 1, 0);
130 template <
class FloatC>
133 reg_c.template AsType<float32_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
134 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<0>{}], 1, 0, 0);
138template <index_t MPerWave, index_t NPerWave>
144 template <
class FloatC>
147#if defined(__gfx950__)
148 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_f16(
149 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
158template <index_t MPerWave, index_t NPerWave>
164 template <
class FloatC>
167#if defined(__gfx950__)
168 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_f16(
169 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
178template <index_t MPerWave, index_t NPerWave>
184 template <
class FloatC>
187 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8f16(
188 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
192template <index_t MPerWave, index_t NPerWave>
198 template <
class FloatC>
201 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
202 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
206template <index_t MPerWave, index_t NPerWave>
212 template <
class FloatC>
215 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f16(
216 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 2, 0, 0);
220template <index_t MPerWave, index_t NPerWave>
226 template <
class FloatC>
229 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
230 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 4, 0, 0);
237 template <
class FloatC>
240 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
241 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 4, 0, 0);
242 reg_c.template AsType<float4_t>()(
Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
243 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<1>{}], 4, 1, 0);
248template <index_t MPerWave, index_t NPerWave>
254 template <
class FloatC>
257#if defined(__gfx950__)
258 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16(
259 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
268template <index_t MPerWave, index_t NPerWave>
274 template <
class FloatC>
277#if defined(__gfx950__)
278 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16(
279 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
288template <index_t MPerWave, index_t NPerWave>
294 template <
class FloatC>
297 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
298 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
302template <index_t MPerWave, index_t NPerWave>
308 template <
class FloatC>
311 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
312 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
316template <index_t MPerWave, index_t NPerWave>
322 template <
class FloatC>
325 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
326 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
330template <index_t MPerWave, index_t NPerWave>
336 template <
class FloatC>
339 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8bf16(
340 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
344template <index_t MPerWave, index_t NPerWave>
350 template <
class FloatC>
353 reg_c.template AsType<int32x16_t>()(
Number<0>{}) =
356 reg_c.template AsType<int32x16_t>()[
Number<0>{}],
363template <index_t MPerWave, index_t NPerWave>
369 template <
class FloatC>
372 reg_c.template AsType<int32x4_t>()(
Number<0>{}) =
375 reg_c.template AsType<int32x4_t>()[
Number<0>{}],
382template <index_t MPerWave, index_t NPerWave>
388 template <
class FloatC>
391#if defined(__gfx950__)
392 reg_c.template AsType<int32x16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x32_i8(
393 reg_a, reg_b, reg_c.template AsType<int32x16_t>()[
Number<0>{}], 0, 0, 0);
402template <index_t MPerWave, index_t NPerWave>
408 template <
class FloatC>
411#if defined(__gfx950__)
412 reg_c.template AsType<int32x4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x64_i8(
413 reg_a, reg_b, reg_c.template AsType<int32x4_t>()[
Number<0>{}], 0, 0, 0);
422template <index_t MPerWave, index_t NPerWave>
428 template <
class FloatC>
431 reg_c.template AsType<int32x16_t>()(
Number<0>{}) =
434 reg_c.template AsType<int32x16_t>()[
Number<0>{}],
441template <index_t MPerWave, index_t NPerWave>
447 template <
class FloatC>
450 reg_c.template AsType<int32x4_t>()(
Number<0>{}) =
453 reg_c.template AsType<int32x4_t>()[
Number<0>{}],
460template <index_t MPerWave, index_t NPerWave>
466 template <
class FloatC>
467 __device__
static void Run(
const double& reg_a,
const double& reg_b, FloatC& reg_c)
469#if defined(__gfx90a__) || defined(__gfx94__)
470 reg_c.template AsType<double4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64(
471 reg_a, reg_b, reg_c.template AsType<double4_t>()[
Number<0>{}], 0, 0, 0);
480template <index_t MPerWave, index_t NPerWave>
492 template <
class FloatC>
493 __device__
static void Run(
const f8x32_t& reg_a,
const f8x32_t& reg_b, FloatC& reg_c)
495#if defined(__gfx950__)
496 reg_c.template AsType<float16_t>()(
Number<0>{}) =
497 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
500 reg_c.template AsType<float16_t>()[
Number<0>{}],
514 template <
class FloatC>
515 __device__
static void Run(
const bf8x32_t& reg_a,
const bf8x32_t& reg_b, FloatC& reg_c)
517#if defined(__gfx950__)
518 reg_c.template AsType<float16_t>()(
Number<0>{}) =
519 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
522 reg_c.template AsType<float16_t>()[
Number<0>{}],
536 template <
class FloatC>
537 __device__
static void Run(
const bf8x32_t& reg_a,
const f8x32_t& reg_b, FloatC& reg_c)
539#if defined(__gfx950__)
540 reg_c.template AsType<float16_t>()(
Number<0>{}) =
541 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
544 reg_c.template AsType<float16_t>()[
Number<0>{}],
558 template <
class FloatC>
559 __device__
static void Run(
const f8x32_t& reg_a,
const bf8x32_t& reg_b, FloatC& reg_c)
561#if defined(__gfx950__)
562 reg_c.template AsType<float16_t>()(
Number<0>{}) =
563 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
566 reg_c.template AsType<float16_t>()[
Number<0>{}],
580 template <
class FloatC>
583#if defined(__gfx950__)
590 reg_c.template AsType<float16_t>()(
Number<0>{}) =
591 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
592 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
593 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
594 reg_c.template AsType<float16_t>()[
Number<0>{}],
608 template <
class FloatC>
611#if defined(__gfx950__)
618 reg_c.template AsType<float16_t>()(
Number<0>{}) =
619 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
620 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
621 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
622 reg_c.template AsType<float16_t>()[
Number<0>{}],
636 template <
class FloatC>
639#if defined(__gfx950__)
646 reg_c.template AsType<float16_t>()(
Number<0>{}) =
647 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
648 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
649 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
650 reg_c.template AsType<float16_t>()[
Number<0>{}],
665template <index_t MPerWave, index_t NPerWave, index_t OpselA, index_t OpselB>
668template <index_t OpselA, index_t OpselB>
671 template <
class FloatC>
672 __device__
static void Run(
const f8x32_t& reg_a,
674 const f8x32_t& reg_b,
678#if defined(__gfx950__)
680 reg_c.template AsType<float16_t>()(
Number<0>{}) =
681 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
684 reg_c.template AsType<float16_t>()[
Number<0>{}],
708 template <
class FloatC>
709 __device__
static void Run(
const bf8x32_t& reg_a,
711 const bf8x32_t& reg_b,
715#if defined(__gfx950__)
717 reg_c.template AsType<float16_t>()(
Number<0>{}) =
718 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
721 reg_c.template AsType<float16_t>()[
Number<0>{}],
745 template <
class FloatC>
746 __device__
static void Run(
const bf8x32_t& reg_a,
748 const f8x32_t& reg_b,
752#if defined(__gfx950__)
754 reg_c.template AsType<float16_t>()(
Number<0>{}) =
755 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
758 reg_c.template AsType<float16_t>()[
Number<0>{}],
782 template <
class FloatC>
789#if defined(__gfx950__)
796 reg_c.template AsType<float16_t>()(
Number<0>{}) =
797 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
798 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
799 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
800 reg_c.template AsType<float16_t>()[
Number<0>{}],
816 template <
class FloatC>
823#if defined(__gfx950__)
830 reg_c.template AsType<float16_t>()(
Number<0>{}) =
831 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
832 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
833 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
834 reg_c.template AsType<float16_t>()[
Number<0>{}],
850 template <
class FloatC>
857#if defined(__gfx950__)
864 reg_c.template AsType<float16_t>()(
Number<0>{}) =
865 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
866 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
867 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
868 reg_c.template AsType<float16_t>()[
Number<0>{}],
885template <index_t MPerWave, index_t NPerWave, index_t OpselA, index_t OpselB>
888template <index_t OpselA, index_t OpselB>
891 template <
class FloatC>
892 __device__
static void Run(
const f8x32_t& reg_a,
894 const f8x32_t& reg_b,
898#if defined(__gfx950__)
900 reg_c.template AsType<float4_t>()(
Number<0>{}) =
901 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
904 reg_c.template AsType<float4_t>()[
Number<0>{}],
920 template <
class FloatC>
921 __device__
static void Run(
const bf8x32_t& reg_a,
923 const bf8x32_t& reg_b,
927#if defined(__gfx950__)
929 reg_c.template AsType<float4_t>()(
Number<0>{}) =
930 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
933 reg_c.template AsType<float4_t>()[
Number<0>{}],
949 template <
class FloatC>
950 __device__
static void Run(
const f8x32_t& reg_a,
952 const bf8x32_t& reg_b,
956#if defined(__gfx950__)
958 reg_c.template AsType<float4_t>()(
Number<0>{}) =
959 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
962 reg_c.template AsType<float4_t>()[
Number<0>{}],
978 template <
class FloatC>
979 __device__
static void Run(
const bf8x32_t& reg_a,
981 const f8x32_t& reg_b,
985#if defined(__gfx950__)
987 reg_c.template AsType<float4_t>()(
Number<0>{}) =
988 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
991 reg_c.template AsType<float4_t>()[
Number<0>{}],
1007 template <
class FloatC>
1014#if defined(__gfx950__)
1020 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1021 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1022 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1023 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1024 reg_c.template AsType<float4_t>()[
Number<0>{}],
1040 template <
class FloatC>
1047#if defined(__gfx950__)
1050 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[
Number<0>{}][0]),
1051 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[
Number<0>{}][1]),
1052 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[
Number<0>{}][2]),
1053 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[
Number<1>{}][0]),
1054 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[
Number<1>{}][1]),
1055 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[
Number<1>{}][2]),
1059 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[
Number<0>{}][0]),
1060 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[
Number<0>{}][1]),
1061 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[
Number<0>{}][2]),
1062 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[
Number<1>{}][0]),
1063 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[
Number<1>{}][1]),
1064 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[
Number<1>{}][2]),
1068 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1069 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1072 reg_c.template AsType<float4_t>()[
Number<0>{}],
1088 template <
class FloatC>
1095#if defined(__gfx950__)
1101 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1102 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1103 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1104 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1105 reg_c.template AsType<float4_t>()[
Number<0>{}],
1121 template <
class FloatC>
1128#if defined(__gfx950__)
1131 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[
Number<0>{}][0]),
1132 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[
Number<0>{}][1]),
1133 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[
Number<0>{}][2]),
1134 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[
Number<1>{}][0]),
1135 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[
Number<1>{}][1]),
1136 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[
Number<1>{}][2]),
1140 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[
Number<0>{}][0]),
1141 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[
Number<0>{}][1]),
1142 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[
Number<0>{}][2]),
1143 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[
Number<1>{}][0]),
1144 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[
Number<1>{}][1]),
1145 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[
Number<1>{}][2]),
1149 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1150 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1153 reg_c.template AsType<float4_t>()[
Number<0>{}],
1169 template <
class FloatC>
1176#if defined(__gfx950__)
1180 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1181 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1182 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
1183 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
1184 reg_c.template AsType<float4_t>()[
Number<0>{}],
1201template <index_t MPerWave, index_t NPerWave>
1213 template <
class FloatC>
1214 __device__
static void Run(
const f8x32_t& reg_a,
const f8x32_t& reg_b, FloatC& reg_c)
1216#if defined(__gfx950__)
1218 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1219 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1222 reg_c.template AsType<float4_t>()[
Number<0>{}],
1236 template <
class FloatC>
1237 __device__
static void Run(
const bf8x32_t& reg_a,
const bf8x32_t& reg_b, FloatC& reg_c)
1239#if defined(__gfx950__)
1241 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1242 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1245 reg_c.template AsType<float4_t>()[
Number<0>{}],
1259 template <
class FloatC>
1260 __device__
static void Run(
const bf8x32_t& reg_a,
const f8x32_t& reg_b, FloatC& reg_c)
1262#if defined(__gfx950__)
1264 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1265 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1268 reg_c.template AsType<float4_t>()[
Number<0>{}],
1282 template <
class FloatC>
1283 __device__
static void Run(
const f8x32_t& reg_a,
const bf8x32_t& reg_b, FloatC& reg_c)
1285#if defined(__gfx950__)
1287 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1288 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1291 reg_c.template AsType<float4_t>()[
Number<0>{}],
1305 template <
class FloatC>
1308#if defined(__gfx950__)
1314 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1315 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1316 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
1317 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
1318 reg_c.template AsType<float4_t>()[
Number<0>{}],
1332 template <
class FloatC>
1335#if defined(__gfx950__)
1341 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1342 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1343 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1344 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1345 reg_c.template AsType<float4_t>()[
Number<0>{}],
1359 template <
class FloatC>
1362#if defined(__gfx950__)
1368 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1369 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1370 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1371 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1372 reg_c.template AsType<float4_t>()[
Number<0>{}],
1387template <index_t MPerWave, index_t NPerWave>
1393 template <
class FloatC>
1394 __device__
static void Run(
const f8x8_t& reg_a,
const f8x8_t& reg_b, FloatC& reg_c)
1396#if defined(__gfx94__)
1397 reg_c.template AsType<float16_t>()(
Number<0>{}) =
1398 __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
1401 reg_c.template AsType<float16_t>()[
Number<0>{}],
1419template <index_t MPerWave, index_t NPerWave>
1425 template <
class FloatC>
1426 __device__
static void Run(
const f8x8_t& reg_a,
const f8x8_t& reg_b, FloatC& reg_c)
1428#if defined(__gfx94__)
1429 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
1432 reg_c.template AsType<float4_t>()[
Number<0>{}],
1450template <index_t MPerWave, index_t NPerWave>
1456 template <
class FloatC>
1457 __device__
static void Run(
const bf8x8_t& reg_a,
const bf8x8_t& reg_b, FloatC& reg_c)
1459#if defined(__gfx94__)
1460 reg_c.template AsType<float16_t>()(
Number<0>{}) =
1461 __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
1464 reg_c.template AsType<float16_t>()[
Number<0>{}],
1482template <index_t MPerWave, index_t NPerWave>
1488 template <
class FloatC>
1489 __device__
static void Run(
const bf8x8_t& reg_a,
const bf8x8_t& reg_b, FloatC& reg_c)
1491#if defined(__gfx94__)
1492 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
1495 reg_c.template AsType<float4_t>()[
Number<0>{}],
1513template <index_t MPerWave, index_t NPerWave>
1519 template <
class FloatC>
1520 __device__
static void Run(
const f8x8_t& reg_a,
const bf8x8_t& reg_b, FloatC& reg_c)
1522#if defined(__gfx94__)
1523 reg_c.template AsType<float16_t>()(
Number<0>{}) =
1524 __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
1527 reg_c.template AsType<float16_t>()[
Number<0>{}],
1545template <index_t MPerWave, index_t NPerWave>
1551 template <
class FloatC>
1552 __device__
static void Run(
const f8x8_t& reg_a,
const bf8x8_t& reg_b, FloatC& reg_c)
1554#if defined(__gfx94__)
1555 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
1558 reg_c.template AsType<float4_t>()[
Number<0>{}],
1576template <index_t MPerWave, index_t NPerWave>
1582 template <
class FloatC>
1583 __device__
static void Run(
const bf8x8_t& reg_a,
const f8x8_t& reg_b, FloatC& reg_c)
1585#if defined(__gfx94__)
1586 reg_c.template AsType<float16_t>()(
Number<0>{}) =
1587 __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
1590 reg_c.template AsType<float16_t>()[
Number<0>{}],
1608template <index_t MPerWave, index_t NPerWave>
1614 template <
class FloatC>
1615 __device__
static void Run(
const bf8x8_t& reg_a,
const f8x8_t& reg_b, FloatC& reg_c)
1617#if defined(__gfx94__)
1618 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
1621 reg_c.template AsType<float4_t>()[
Number<0>{}],
1640template <index_t MPerWave, index_t NPerWave>
1646 template <
class FloatC>
1649#if defined(__gfx94__)
1650 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32(
1651 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
1660template <index_t MPerWave, index_t NPerWave>
1666 template <
class FloatC>
1669#if defined(__gfx94__)
1670 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4_xf32(
1671 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
typename vector_type< int8_t, 8 >::type int8x8_t
Definition dtype_vector.hpp:2178
typename vector_type< bhalf_t, 8 >::type bhalf8_t
Definition dtype_vector.hpp:2162
typename vector_type< int8_t, 4 >::type int8x4_t
Definition dtype_vector.hpp:2177
integral_constant< index_t, N > Number
Definition number.hpp:12
typename vector_type< int32_t, 4 >::type int32x4_t
Definition dtype_vector.hpp:2168
typename vector_type< half_t, 8 >::type half8_t
Definition dtype_vector.hpp:2155
typename vector_type< int32_t, 8 >::type int32x8_t
Definition dtype_vector.hpp:2170
typename vector_type< float, 2 >::type float2_t
Definition dtype_vector.hpp:2145
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
typename vector_type< int8_t, 16 >::type int8x16_t
Definition dtype_vector.hpp:2179
typename vector_type< bhalf_t, 4 >::type bhalf4_t
Definition dtype_vector.hpp:2161
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
typename vector_type< bf6x32_pk_t, 1 >::type bf6x32_t
Definition dtype_vector.hpp:2273
typename vector_type< int32_t, 6 >::type int32x6_t
Definition dtype_vector.hpp:2169
typename vector_type< f6x32_pk_t, 1 >::type f6x32_t
Definition dtype_vector.hpp:2268
typename vector_type< bhalf_t, 2 >::type bhalf2_t
Definition dtype_vector.hpp:2160
typename vector_type< bf6x16_pk_t, 2 >::type bf6x16x2_t
Definition dtype_vector.hpp:2272
typename vector_type< f4x2_pk_t, 16 >::type f4x32_t
Definition dtype_vector.hpp:2262
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
typename vector_type< half_t, 4 >::type half4_t
Definition dtype_vector.hpp:2154
typename vector_type< f6x16_pk_t, 2 >::type f6x16x2_t
Definition dtype_vector.hpp:2267
signed int int32_t
Definition stdint.h:123
static __device__ void Run(const bf6x32_t &reg_a, const bf6x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1360
static __device__ void Run(const f6x32_t &reg_a, const f6x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1333
static __device__ void Run(const f8x32_t &reg_a, const bf8x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1283
static __device__ void Run(const bf8x32_t &reg_a, const bf8x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1237
static __device__ void Run(const bf8x32_t &reg_a, const f8x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1260
static __device__ void Run(const f4x32_t &reg_a, const f4x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1306
static __device__ void Run(const f8x32_t &reg_a, const f8x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1214
Definition amd_xdlops.hpp:1202
static __device__ void Run(const bhalf4_t &reg_a, const bhalf4_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:309
Definition amd_xdlops.hpp:303
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:199
Definition amd_xdlops.hpp:193
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:76
Definition amd_xdlops.hpp:70
static __device__ void Run(const bhalf8_t &reg_a, const bhalf8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:275
Definition amd_xdlops.hpp:269
static __device__ void Run(const bf8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1489
Definition amd_xdlops.hpp:1483
static __device__ void Run(const bf8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1615
Definition amd_xdlops.hpp:1609
static __device__ void Run(const half8_t &reg_a, const half8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:165
Definition amd_xdlops.hpp:159
static __device__ void Run(const f8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1552
Definition amd_xdlops.hpp:1546
static __device__ void Run(const f8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1426
Definition amd_xdlops.hpp:1420
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:213
Definition amd_xdlops.hpp:207
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:62
Definition amd_xdlops.hpp:56
static __device__ void Run(const bhalf2_t &reg_a, const bhalf2_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:337
Definition amd_xdlops.hpp:331
static __device__ void Run(const float2_t &reg_a, const float2_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1647
Definition amd_xdlops.hpp:1641
static __device__ void Run(const bhalf8_t &reg_a, const bhalf8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:255
Definition amd_xdlops.hpp:249
static __device__ void Run(const bf8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1457
Definition amd_xdlops.hpp:1451
static __device__ void Run(const bf8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1583
Definition amd_xdlops.hpp:1577
static __device__ void Run(const half8_t &reg_a, const half8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:145
Definition amd_xdlops.hpp:139
static __device__ void Run(const f8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1520
Definition amd_xdlops.hpp:1514
static __device__ void Run(const f8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1394
Definition amd_xdlops.hpp:1388
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:34
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:21
Definition amd_xdlops.hpp:15
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:48
Definition amd_xdlops.hpp:42
static __device__ void Run(const bhalf2_t &reg_a, const bhalf2_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:323
Definition amd_xdlops.hpp:317
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:131
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:118
Definition amd_xdlops.hpp:112
static __device__ void Run(const float2_t &reg_a, const float2_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1667
Definition amd_xdlops.hpp:1661
static __device__ void Run(const bf8x32_t &reg_a, const f8x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:537
static __device__ void Run(const f8x32_t &reg_a, const bf8x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:559
static __device__ void Run(const bf6x32_t &reg_a, const bf6x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:637
static __device__ void Run(const f6x32_t &reg_a, const f6x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:609
static __device__ void Run(const f8x32_t &reg_a, const f8x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:493
static __device__ void Run(const f4x32_t &reg_a, const f4x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:581
static __device__ void Run(const bf8x32_t &reg_a, const bf8x32_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:515
Definition amd_xdlops.hpp:481
static __device__ void Run(const bhalf4_t &reg_a, const bhalf4_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:295
Definition amd_xdlops.hpp:289
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:185
Definition amd_xdlops.hpp:179
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:90
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:101
Definition amd_xdlops.hpp:84
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:227
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:238
Definition amd_xdlops.hpp:221
static __device__ void Run(const double &reg_a, const double &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:467
Definition amd_xdlops.hpp:461
static __device__ void Run(const int8x4_t &reg_a, const int8x4_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:370
Definition amd_xdlops.hpp:364
static __device__ void Run(const int8x8_t &reg_a, const int8x8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:448
Definition amd_xdlops.hpp:442
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:409
Definition amd_xdlops.hpp:403
static __device__ void Run(const int8x8_t &reg_a, const int8x8_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:429
Definition amd_xdlops.hpp:423
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:389
Definition amd_xdlops.hpp:383
static __device__ void Run(const int8x4_t &reg_a, const int8x4_t &reg_b, FloatC &reg_c)
Definition amd_xdlops.hpp:351
Definition amd_xdlops.hpp:345
static __device__ void Run(const f6x16x2_t &reg_a, const int32_t scale_a, const f6x16x2_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1041
static __device__ void Run(const f4x32_t &reg_a, const int32_t scale_a, const f4x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1170
static __device__ void Run(const f6x32_t &reg_a, const int32_t scale_a, const f6x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1008
static __device__ void Run(const f8x32_t &reg_a, const int32_t &scale_a, const f8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:892
static __device__ void Run(const bf8x32_t &reg_a, const int32_t &scale_a, const f8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:979
static __device__ void Run(const f8x32_t &reg_a, const int32_t &scale_a, const bf8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:950
static __device__ void Run(const bf6x16x2_t &reg_a, const int32_t scale_a, const bf6x16x2_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1122
static __device__ void Run(const bf6x32_t &reg_a, const int32_t scale_a, const bf6x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:1089
static __device__ void Run(const bf8x32_t &reg_a, const int32_t &scale_a, const bf8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:921
Definition amd_xdlops.hpp:886
static __device__ void Run(const f8x32_t &reg_a, const int32_t &scale_a, const f8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:672
static __device__ void Run(const f6x32_t &reg_a, const int32_t scale_a, const f6x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:783
static __device__ void Run(const bf8x32_t &reg_a, const int32_t &scale_a, const bf8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:709
static __device__ void Run(const bf8x32_t &reg_a, const int32_t &scale_a, const f8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:746
static __device__ void Run(const f4x32_t &reg_a, const int32_t scale_a, const f4x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:851
static __device__ void Run(const bf6x32_t &reg_a, const int32_t scale_a, const bf6x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition amd_xdlops.hpp:817
Definition amd_xdlops.hpp:666
Definition functional2.hpp:33
Definition dtype_vector.hpp:10