multi_index_transform_helper.hpp Source File

multi_index_transform_helper.hpp Source File

Composable Kernel: multi_index_transform_helper.hpp Source File
multi_index_transform_helper.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
8
9namespace ck {
10
11template <typename LowLength>
12__host__ __device__ constexpr auto make_pass_through_transform(const LowLength& low_length)
13{
14 return PassThrough<LowLength>{low_length};
15}
16
17template <typename LowLength, typename LeftPad, typename RightPad, bool SkipIsValidCheck = false>
18__host__ __device__ constexpr auto
19make_pad_transform(const LowLength& low_length,
20 const LeftPad& left_pad,
21 const RightPad& right_pad,
23{
24 return Pad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{low_length, left_pad, right_pad};
25}
26
27template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
28__host__ __device__ constexpr auto make_left_pad_transform(
29 const LowLength& low_length,
30 const LeftPadLength& left_pad,
32{
33 return LeftPad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad};
34}
35
36template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
37__host__ __device__ constexpr auto make_right_pad_transform(
38 const LowLength& low_length,
39 const RightPadLength& right_pad,
41{
42 return RightPad<LowLength, RightPadLength, SkipIsValidCheck>{low_length, right_pad};
43}
44
45template <typename UpLengths,
46 typename Coefficients,
47 typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
48__host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths,
49 const Coefficients& coefficients)
50{
51 return Embed<UpLengths, Coefficients>{up_lengths, coefficients};
52}
53
54template <typename LowLengths>
55__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
56{
57#if CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
59#else
60 return make_merge_transform_v1_carry_check(low_lengths);
61#endif
62}
63
64template <typename LowLengths>
65__host__ __device__ constexpr auto
66make_merge_transform_v1_carry_check(const LowLengths& low_lengths)
67{
68 return Merge_v1_carry_check<LowLengths>{low_lengths};
69}
70
71template <typename LowLengths>
72__host__ __device__ constexpr auto
73make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
74{
75#if 1
76 return Merge_v2_magic_division<LowLengths>{low_lengths};
77#else
78 return Merge_v2r2_magic_division<LowLengths>{low_lengths};
79#endif
80}
81
82template <typename LowLengths>
83__host__ __device__ constexpr auto
84make_merge_transform_v3_division_mod(const LowLengths& low_lengths)
85{
86 return Merge_v3_division_mod<LowLengths>{low_lengths};
87}
88
89template <typename UpLengths, bool Use24BitIntegerCalculation = false>
90__host__ __device__ constexpr auto make_unmerge_transform(
91 const UpLengths& up_lengths,
93{
94 return UnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
95}
96
97__host__ __device__ constexpr auto make_conv_bwd_data_out_transform(index_t N,
98 index_t Ho,
99 index_t Wo,
100 index_t K,
101 [[maybe_unused]] index_t YDot,
102 index_t XDot,
103 index_t HTilde,
104 index_t WTilde,
105 index_t ConvDilationH,
106 index_t ConvDilationW,
107 index_t HTildeSlice,
108 index_t WTildeSlice,
109 index_t YDotSlice,
110 index_t XDotSlice,
111 index_t IHTildeSliceBegin,
112 index_t IWTildeSliceBegin,
113 index_t GcdStrideDilationH,
114 index_t GcdStrideDilationW,
115 index_t K0,
116 index_t K1,
117 index_t MPerBlock,
118 index_t GemmKPerBlock)
119{
120 // Calculate padding
121 const auto MRaw = N * HTildeSlice * WTildeSlice;
122 const auto MPadded = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
123 const auto MPad = MPadded - MRaw;
124
125 const auto KRaw = YDotSlice * XDotSlice * K;
126 const auto KPadded = math::integer_divide_ceil(KRaw, GemmKPerBlock) * GemmKPerBlock;
127 const auto KPad = KPadded - KRaw;
128
130 Ho,
131 Wo,
132 K,
133 XDot,
134 HTilde,
135 WTilde,
136 WTildeSlice,
137 HTildeSlice * WTildeSlice,
138 IHTildeSliceBegin,
139 IWTildeSliceBegin,
140 -ConvDilationH / GcdStrideDilationH,
141 -ConvDilationW / GcdStrideDilationW,
142 XDotSlice * K,
143 K0,
144 MPadded,
145 K1,
146 MPad,
147 KPad};
148}
149
150template <typename LowerIndex>
151__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx)
152{
153 return Freeze<LowerIndex>{low_idx};
154}
155
156template <typename UpperIndex>
157__host__ __device__ constexpr auto make_insert_transform(const UpperIndex& up_idx)
158{
159 return Insert<UpperIndex>{up_idx};
160}
161
162template <typename LowLength, typename SliceBegin, typename SliceEnd>
163__host__ __device__ constexpr auto make_slice_transform(const LowLength& low_length,
164 const SliceBegin& slice_begin,
165 const SliceEnd& slice_end)
166{
167 return Slice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
168}
169
170template <typename VectorSize, typename UpLength>
171__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size,
172 const UpLength& up_length)
173{
174 return Vectorize<VectorSize, UpLength>{vector_size, up_length};
175}
176
177template <typename Modulus, typename UpLength>
178__host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus,
179 const UpLength& up_length)
180{
181 return Modulo<Modulus, UpLength>{modulus, up_length};
182}
183
184template <typename LowLengths>
185__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths& low_lengths)
186{
187 return Xor<LowLengths, true /*ApplyModulo*/>{low_lengths};
188}
189
190template <typename LowLengths>
191__host__ __device__ constexpr auto make_xor_transform(const LowLengths& low_lengths)
192{
193 return Xor<LowLengths, false /*ApplyModulo*/>{low_lengths};
194}
195} // namespace ck
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition multi_index_transform_helper.hpp:163
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
__host__ __device__ constexpr auto make_insert_transform(const UpperIndex &up_idx)
Definition multi_index_transform_helper.hpp:157
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize &vector_size, const UpLength &up_length)
Definition multi_index_transform_helper.hpp:171
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_merge_transform_v2_magic_division(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:73
std::enable_if< B, T > enable_if
Definition enable_if.hpp:24
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__host__ __device__ constexpr auto make_conv_bwd_data_out_transform(index_t N, index_t Ho, index_t Wo, index_t K, index_t YDot, index_t XDot, index_t HTilde, index_t WTilde, index_t ConvDilationH, index_t ConvDilationW, index_t HTildeSlice, index_t WTildeSlice, index_t YDotSlice, index_t XDotSlice, index_t IHTildeSliceBegin, index_t IWTildeSliceBegin, index_t GcdStrideDilationH, index_t GcdStrideDilationW, index_t K0, index_t K1, index_t MPerBlock, index_t GemmKPerBlock)
Definition multi_index_transform_helper.hpp:97
__host__ __device__ constexpr auto make_left_pad_transform(const LowLength &low_length, const LeftPadLength &left_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:28
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_xor_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:191
__host__ __device__ constexpr auto make_modulo_transform(const Modulus &modulus, const UpLength &up_length)
Definition multi_index_transform_helper.hpp:178
__host__ __device__ constexpr auto make_merge_transform_v1_carry_check(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:66
Transformation struct for convolution backward data output indices to GEMM indices.
Definition multi_index_transform.hpp:1565
Definition multi_index_transform.hpp:385
Definition multi_index_transform.hpp:1750
Definition multi_index_transform.hpp:1816
Definition multi_index_transform.hpp:196
Definition multi_index_transform.hpp:481
Definition multi_index_transform.hpp:1036
Definition multi_index_transform.hpp:1188
Definition multi_index_transform.hpp:1338
Definition multi_index_transform.hpp:2065
Definition multi_index_transform.hpp:13
Definition multi_index_transform.hpp:284
Definition multi_index_transform.hpp:1968
Definition multi_index_transform.hpp:1882
Definition multi_index_transform.hpp:2149
Definition utility/integral_constant.hpp:20