add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp Source File

add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp Source File

Composable Kernel: add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp Source File
add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8#include <string>
9#include <type_traits>
10
11namespace ck_tile {
12
// Three-pass fused "add + RMSNorm + row-wise dynamic quantization" forward pipeline.
// NOTE(review): this text is a Doxygen listing extract. Listing lines 7, 14, 16-17, 19-25,
// 42, 95, 108, 165-166, 185, 213, 242-243 and 263 are missing. As a result, the struct
// name, the type aliases and several lambda bodies are not visible here.
// Template parameters:
//   Problem_ - compile-time problem description. The code below reads BlockShape::Block_N,
//              kSaveX, kNeedCrossWarpSync, kPadN and the data types from it.
//   Policy_  - supplies the tile distributions, block-reduce helpers and the shared-memory
//              size. Defaults to AddRmsnorm2dRdquantFwdPipelineDefaultPolicy.
13template <typename Problem_, typename Policy_ = AddRmsnorm2dRdquantFwdPipelineDefaultPolicy>
15{
18
26
// Compile-time configuration.
// kHasGamma: true unless Problem::GammaDataType is ck_tile::null_type.
// NOTE(review): kHasGamma is not consulted in the visible operator(); gamma is always loaded.
27 static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
// kSaveX: pass 1 stores the combined x to x_window_, and passes 2 and 3 reload it
// instead of recomputing it from a and b.
28 static constexpr bool kSaveX = Problem::kSaveX;
29
// kNeedCrossWarpSync: only affects the kernel tag below ("block per row" vs "warp per row").
30 static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
31 static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
32 static constexpr bool kPadN = Problem::kPadN;
// UseMax3: allows the v_max3_f32 inline-asm absmax reduction. It is used only when
// ComputeDataType is float and the per-thread row length is even.
33 static constexpr bool UseMax3 = true; // TODO - Move to trait
34
// Short kernel tag: "bpr_tp" when a cross-warp sync is needed, otherwise "wpr_tp".
35 static constexpr const char* name = []() {
36 if constexpr(kNeedCrossWarpSync)
37 return "bpr_tp"; // block per row
38 else
39 return "wpr_tp"; // warp per row
40 }();
41
// Shared-memory requirement, forwarded to the policy.
// (The signature on listing line 42 is missing from this extract. The cross-reference lists
// it as `static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()`.)
43 {
44 return Policy::template GetSmemSize<Problem>();
45 }
46
// Runs the three passes over one block tile of rows.
//   a_window_, b_window_ - inputs, combined elementwise into x. The combining lambda
//                          body is missing here; presumably a + b (TODO confirm).
//   gamma_window_        - per-column multiplier (indexed by the column index j_idx).
//                          It is moved with a single-component offset.
//   x_window_            - output for x when kSaveX. Otherwise it is used as-is
//                          (presumably a null window, so the moves on it are no-ops).
//   yscale_window        - receives the per-row scale derived from the row absmax.
//   qy_window            - receives the quantized output, one Block_N tile per step.
//   epsilon              - added to mean(x^2) before the square root.
//   row_size             - number of columns in a row.
//   smem                 - scratch for both cross-warp reductions (size: GetSmemSize()).
// The passes:
//   Pass 1 (left-to-right): accumulate sum(x^2).
//   Pass 2 (right-to-left): compute y and its absmax.
//   Pass 3 (left-to-right): recompute y and quantize it.
47 template <typename AWindow,
48 typename BWindow,
49 typename GammaWindow,
50 typename XWindow,
51 typename YScaleWindow,
52 typename QYWindow>
53 CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
54 const BWindow& b_window_,
55 const GammaWindow& gamma_window_,
56 XWindow& x_window_,
57 YScaleWindow& yscale_window,
58 QYWindow& qy_window,
59 ComputeDataType epsilon,
60 ck_tile::index_t row_size,
61 void* smem) const
62 {
// Re-create the a/b/(x)/gamma windows with the policy's tile distributions.
63 auto a_window =
64 make_tile_window(a_window_, Policy::template MakeABXBlockTileDistribution<Problem>());
65 auto b_window =
66 make_tile_window(b_window_, Policy::template MakeABXBlockTileDistribution<Problem>());
67 auto x_window = [&]() {
68 if constexpr(kSaveX)
69 return make_tile_window(x_window_,
70 Policy::template MakeABXBlockTileDistribution<Problem>());
71 else
72 return x_window_;
73 }();
74 auto gamma_window = make_tile_window(
75 gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
76
// Reduction operators used by the three passes.
77 auto reduce_square_sum_func = ReduceOp::SquareAdd{};
78 auto reduce_sum_func = ReduceOp::Add{};
79 auto reduce_absmax_func = ReduceOp::AbsMax{};
// Computes max(acc_, |v_0_|, |v_1_|) with a single v_max3_f32 and always returns float.
// acc_ is not wrapped in abs(). This assumes the running absmax accumulator is already
// non-negative (confirm against ReduceOp::AbsMax's identity value).
80 auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) {
81 float rtn;
82 asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
83 : "=v"(rtn)
84 : "v"(acc_), "v"(v_0_), "v"(v_1_));
85 return rtn;
86 };
87 auto reduce_max_func = ReduceOp::Max{};
88 auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
89 auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
90 auto block_reduce2d_cross_warp_sync =
91 Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
92
// Number of Block_N-wide column tiles per row. The initializer (listing line 95) is
// missing here; presumably integer_divide_ceil(row_size, Block_N) (TODO confirm).
93 static constexpr index_t Block_N = Problem::BlockShape::Block_N;
94 index_t num_n_tile_iteration =
96
// ---- Pass 1: build x from a and b, optionally store it, and accumulate the per-row sum
// of squares. The accumulator uses the row-reduced (Y) tile layout of x's tile.
97 using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(a_window)));
98 auto square_sum = block_reduce2d.template MakeYBlockTile<XTensorType>();
99 set_tile(square_sum, reduce_square_sum_func.GetIdentityValue<ComputeDataType>());
100
// The loop counter starts from amd_wave_read_first_lane(0), presumably to keep it
// wave-uniform.
101 for(int iN = amd_wave_read_first_lane(0); iN < num_n_tile_iteration; ++iN)
102 {
103 const auto a = load_tile(a_window);
104 const auto b = load_tile(b_window);
105
// Elementwise combine of a and b. The lambda body (listing line 108) is missing here.
106 auto x = tile_elementwise_in(
107 [&](const auto& a_, const auto& b_) {
109 },
110 a,
111 b);
112
113 if constexpr(kSaveX)
114 store_tile(x_window, cast_tile<XDataType>(x));
115
116 block_reduce2d(x, square_sum, reduce_square_sum_func);
// Advance to the next column tile. x_window is moved even when !kSaveX.
117 move_tile_window(x_window, {0, Block_N});
118 move_tile_window(a_window, {0, Block_N});
119 move_tile_window(b_window, {0, Block_N});
120 }
121
// Finish the row reduction: first within the warp, then across warps through smem.
122 block_reduce2d_sync(square_sum, reduce_sum_func);
123 block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
124
// One value per row: inv_rms = 1 / sqrt(sum(x^2) / row_size + epsilon).
125 auto inv_rms = tile_elementwise_in(
126 [&](const auto& v_) {
127 return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
128 },
129 square_sum);
130
131 // reverse read x to reuse cache
// Column offset of the right-most tile. Both branches equal
// (ceil(row_size / Block_N) - 1) * Block_N.
132 ck_tile::index_t stride_to_right_most_window =
133 row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
134
// Pass 1 left the moved windows one tile past the last tile (assuming
// num_n_tile_iteration == ceil(row_size / Block_N)), so step back one tile.
// Only the windows read in pass 2 are rewound: x when kSaveX, otherwise a and b.
// gamma was not moved in pass 1, so it jumps straight to the right-most tile.
135 if constexpr(kSaveX)
136 move_tile_window(x_window, {0, -Block_N});
137 else
138 {
139 move_tile_window(a_window, {0, -Block_N});
140 move_tile_window(b_window, {0, -Block_N});
141 }
142 move_tile_window(gamma_window, {stride_to_right_most_window});
143
// Per-row absmax accumulator, using the same Y-tile layout as square_sum.
144 using YTensorType = XTensorType;
145 auto absmax = block_reduce2d.template MakeYBlockTile<YTensorType>();
146 set_tile(absmax, reduce_absmax_func.GetIdentityValue<ComputeDataType>());
147
148 // rmsnorm computation + absmax(threadwise reduce)
// Block-wide barrier before reloading the x stored in pass 1. Presumably this ensures
// those stores are complete before the loads (confirm the intended visibility guarantee).
149 if constexpr(kSaveX)
150 __syncthreads();
151
// ---- Pass 2 (right-to-left): y = x * inv_rms[row] * gamma[col], then the per-thread
// absmax of y.
152 for(int iN = amd_wave_read_first_lane(0); iN < num_n_tile_iteration; ++iN)
153 {
154 auto x = [&]() {
155 if constexpr(kSaveX)
156 {
157 return load_tile(x_window);
158 }
159 else
160 {
161 const auto a = load_tile(a_window);
162 const auto b = load_tile(b_window);
// Recompute x from a and b. The lambda body (listing lines 165-166) is missing here.
163 return tile_elementwise_in(
164 [&](const auto& a_, const auto& b_) {
167 },
168 a,
169 b);
170 }
171 }();
172
173 auto gamma = load_tile(gamma_window);
174 auto y = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
175
// Listing line 185 is missing. Presumably it writes y_ into y(idx) (TODO confirm).
176 sweep_tile(y, [&](auto idx) {
177 constexpr auto i_idx = make_tuple(idx[number<0>{}]);
178 constexpr auto j_idx = make_tuple(idx[number<1>{}]);
179
180 const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
181
182 const auto x_ = type_convert<ComputeDataType>(x[idx]);
183 auto y_ = x_ * inv_rms[i_idx] * gamma_;
184
186 });
187
// Per-thread element count along the row (dimension 1 of the ys-to-d lengths).
// The max3 path passes sequence<1, 2>{}, presumably consuming y two elements at a time,
// and is only taken for float compute with an even count.
188 constexpr auto x_size_per_row =
189 x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
190 if constexpr(UseMax3 && std::is_same_v<ComputeDataType, float> &&
191 x_size_per_row % 2 == 0)
192 block_reduce2d(y, absmax, reduce_absmax3_func, sequence<1, 2>{});
193 else
194 block_reduce2d(y, absmax, reduce_absmax_func);
195
// Step the windows read in this pass one tile to the left.
196 if constexpr(kSaveX)
197 move_tile_window(x_window, {0, -Block_N});
198 else
199 {
200 move_tile_window(a_window, {0, -Block_N});
201 move_tile_window(b_window, {0, -Block_N});
202 }
203 move_tile_window(gamma_window, {-Block_N});
204 }
205
206 // compute absmax, cross-lane->cross-warp
// NOTE(review): smem is reused here after the square-sum cross-warp reduction. When
// kSaveX is false there is no explicit barrier between the two uses in this function;
// confirm that the cross-warp helper synchronizes before writing smem.
207 block_reduce2d_sync(absmax, reduce_max_func);
208 block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func);
209
210 // ex: yscale = absmax / 127 if int8
// The yscale lambda body (listing line 213) is missing from this extract.
211 auto yscale = tile_elementwise_in(
212 [&](const auto& v_) {
214 },
215 absmax);
216 store_tile(yscale_window, cast_tile<YScaleDataType>(yscale));
217
218 // quantize y to qy
219 // recompute rmsnorm, try to save y in the future
// After pass 2 the windows sit one tile to the left of the first tile. Step right once
// to return to column 0.
220 if constexpr(kSaveX)
221 move_tile_window(x_window, {0, Block_N});
222 else
223 {
224 move_tile_window(a_window, {0, Block_N});
225 move_tile_window(b_window, {0, Block_N});
226 }
227 move_tile_window(gamma_window, {Block_N});
228
// ---- Pass 3 (left-to-right): recompute y, divide it by the per-row yscale, and store
// the result as QYDataType.
229 for(int iN = amd_wave_read_first_lane(0); iN < num_n_tile_iteration; ++iN)
230 {
231 auto x = [&]() {
232 if constexpr(kSaveX)
233 {
234 return load_tile(x_window);
235 }
236 else
237 {
238 const auto a = load_tile(a_window);
239 const auto b = load_tile(b_window);
// Recompute x from a and b. The lambda body (listing lines 242-243) is missing here.
240 return tile_elementwise_in(
241 [&](const auto& a_, const auto& b_) {
244 },
245 a,
246 b);
247 }
248 }();
249
250 auto gamma = load_tile(gamma_window);
251 auto y = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
252 auto qy = make_static_distributed_tensor<QYDataType>(y.get_tile_distribution());
253
// Listing line 263 is missing. Presumably it converts qy_ and writes it into qy(idx)
// (TODO confirm).
254 sweep_tile(y, [&](auto idx) {
255 constexpr auto i_idx = make_tuple(idx[number<0>{}]);
256 constexpr auto j_idx = make_tuple(idx[number<1>{}]);
257
258 const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
259
260 const auto x_ = type_convert<ComputeDataType>(x[idx]);
261 auto y_ = x_ * inv_rms[i_idx] * gamma_;
262 auto qy_ = y_ / yscale[i_idx];
264 });
265
266 store_tile(qy_window, qy);
267
// Advance every window used in this pass, including qy_window, by one tile to the right.
268 if constexpr(kSaveX)
269 move_tile_window(x_window, {0, Block_N});
270 else
271 {
272 move_tile_window(a_window, {0, Block_N});
273 move_tile_window(b_window, {0, Block_N});
274 }
275 move_tile_window(gamma_window, {Block_N});
276 move_tile_window(qy_window, {0, Block_N});
277 }
278 }
279};
280} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr void sweep_tile(const F &f, UnpacksPerXDim={})
Definition sweep_tile.hpp:231
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition bfloat16.hpp:413
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp:15
ck_tile::remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp:22
ck_tile::remove_cvref_t< typename Problem::QYDataType > QYDataType
Definition add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp:25
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp:42
ck_tile::remove_cvref_t< typename Problem::BDataType > BDataType
Definition add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp:20
ck_tile::remove_cvref_t< Problem_ > Problem
Definition add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp:16
ck_tile::remove_cvref_t< Policy_ > Policy
Definition add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp:17
static constexpr bool kSaveX
Definition add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp:28
static constexpr const char * name
Definition add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp:35
static constexpr bool UseMax3
Definition add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp:33
ck_tile::remove_cvref_t< typename Problem::GammaDataType > GammaDataType
Definition add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp:21
ck_tile::remove_cvref_t< typename Problem::YScaleDataType > YScaleDataType
Definition add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp:24
static constexpr bool kPadN
Definition add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp:32
CK_TILE_DEVICE auto operator()(const AWindow &a_window_, const BWindow &b_window_, const GammaWindow &gamma_window_, XWindow &x_window_, YScaleWindow &yscale_window, QYWindow &qy_window, ComputeDataType epsilon, ck_tile::index_t row_size, void *smem) const
Definition add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp:53
static constexpr bool kNeedCrossWarpSync
Definition add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp:30
static constexpr bool kHasGamma
Definition add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp:27
static constexpr bool kPadM
Definition add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp:31
ck_tile::remove_cvref_t< typename Problem::ADataType > ADataType
Definition add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp:19
ck_tile::remove_cvref_t< typename Problem::XDataType > XDataType
Definition add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp:23
Definition reduce_operator.hpp:101
Definition reduce_operator.hpp:14
Definition reduce_operator.hpp:65
Definition reduce_operator.hpp:40
static CK_TILE_HOST_DEVICE constexpr T max()
Definition tile/core/numeric/numeric.hpp:26
Definition unary_element_function.hpp:56
Definition tile/core/container/sequence.hpp:49