topk_softmax_warp_per_row_pipeline.hpp Source File

topk_softmax_warp_per_row_pipeline.hpp Source File

Composable Kernel: topk_softmax_warp_per_row_pipeline.hpp Source File
topk_softmax_warp_per_row_pipeline.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8#include <string>
9#include <type_traits>
10
// TOPK_SOFTMAX_USE_RAW_TILE_WINDOW (defaults to 0; may be predefined by the build):
// when nonzero, the pipeline wraps its input with make_tile_window_linear_raw and
// performs the input load between __builtin_amdgcn_sched_barrier(0) fences; when 0,
// it wraps the input with make_tile_window_linear and loads it with load_tile.
11#ifndef TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
12#define TOPK_SOFTMAX_USE_RAW_TILE_WINDOW 0
13#endif
14
15namespace ck_tile {
16
// NOTE(review): this listing is a Doxygen rendering. Original source lines 18
// (the struct name), 21-22 (type aliases) and 60-61 (raw-path load) are absent
// here. Per this page's symbol list, Problem = remove_cvref_t<Problem_> and
// Policy = remove_cvref_t<Policy_> — restore from the upstream header.
//
// Row-wise top-k gating pipeline. Each row of input logits goes through these steps:
//   1. It is converted to WeightType.
//   2. Columns whose index is >= `experts` are masked to -infinity.
//   3. The row is activated, either with the policy's softmax
//      (Problem::ActivationIsSoftmax) or with an element-wise sigmoid otherwise.
//   4. The policy's topk functor receives the activated tile, the value and
//      index output windows, and k. It presumably writes the k largest values
//      and their expert indices — confirm against the policy.
17template <typename Problem_, typename Policy_ = TopkSoftmaxWarpPerRowPolicy>
19{
20 // TODO: this kernel only supports the warp-per-row layout
23 using WeightType = typename Problem::WeightType;
24
 // Runs the pipeline for this block's rows; returns no value.
 // With Problem::LaunchType == 0 it processes one tile and exits. Otherwise it
 // repeatedly advances by gridDim.x * Problem::RowsPerBlock rows until
 // block_row_id >= rows.
 //
 // input_window : window over the input logits. It is re-wrapped with the
 //                policy's input distribution; the column index is treated as
 //                the expert id.
 // out_window   : window passed to topk as the value output. It is re-wrapped
 //                with the policy's output distribution.
 // idx_window   : window passed to topk as the index output (same distribution).
 // rows         : total row count, used only for the loop-exit test.
 // experts      : number of valid columns; columns >= experts are masked.
 // k            : forwarded unchanged to topk.
 // block_row_id : this block's starting row; advanced locally each iteration.
25 template <typename InputWindow, typename OutputWindow, typename IndexWindow>
26 CK_TILE_DEVICE auto operator()(const InputWindow& input_window,
27 OutputWindow& out_window,
28 IndexWindow& idx_window,
29 index_t rows,
30 index_t experts,
31 index_t k,
32 index_t block_row_id)
33 {
 // Re-create the input window with the policy's input distribution; the raw or
 // non-raw variant is selected at compile time by TOPK_SOFTMAX_USE_RAW_TILE_WINDOW.
34#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
35 auto inp_win = make_tile_window_linear_raw(
36 input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
37#else
38 auto inp_win = make_tile_window_linear(
39 input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
40#endif
 // Re-create the value and index output windows at the caller's origin and
 // lengths, using the policy's output distribution.
41 auto out_win = make_tile_window_linear(out_window.get_bottom_tensor_view(),
42 out_window.get_window_lengths(),
43 out_window.get_window_origin(),
44 Policy::template MakeOutputDistribution<Problem>());
45 auto idx_win = make_tile_window_linear(idx_window.get_bottom_tensor_view(),
46 idx_window.get_window_lengths(),
47 idx_window.get_window_origin(),
48 Policy::template MakeOutputDistribution<Problem>());
49
 // Softmax and top-k functors are supplied by the policy.
50 auto softmax = Policy::template GetSoftmax<Problem>();
51 auto topk = Policy::template GetTopk<Problem>();
52
 // Rows covered by the whole grid in one iteration: the stride between
 // iterations when LaunchType != 0.
53 const index_t grid_rows_per_loop = gridDim.x * Problem::RowsPerBlock;
54
55 while(1)
56 {
57#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
 // sched_barrier(0): the compiler may not schedule any instruction across
 // this point, keeping the raw load isolated.
58 __builtin_amdgcn_sched_barrier(0);
 // NOTE(review): the initializer of `x` (original lines 60-61) is missing from
 // this rendering. It is presumably a raw load from inp_win (load_tile_raw and
 // buffer_load_fence appear in this page's symbol list) — confirm upstream.
59 auto x =
62 __builtin_amdgcn_sched_barrier(0);
63#else
64 auto x = load_tile(inp_win);
65#endif
66 // cast and pad input data
67 auto w = [&]() {
 // Disabled alternative implementation (cast_tile + nested sweep_tile_span),
 // kept for reference only.
68#if 0
69 auto w_ = cast_tile<WeightType>(x);
70
71 constexpr auto span_2d = decltype(w_)::get_distributed_spans();
72 sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
73 sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
74 constexpr auto i_j_idx = make_tuple(idx0, idx1);
75 const auto x_indices = get_x_indices_from_distributed_indices(
76 w_.get_tile_distribution(), i_j_idx);
77 const auto current_expert = x_indices.at(number<1>{});
78 // set to -INF if OOB so that later softmax can work properly
79 w_(i_j_idx) = current_expert >= experts ? -numeric<WeightType>::infinity()
80 : w_(i_j_idx);
81 });
82 });
83 return w_;
84#else
 // Output tile with the same distribution as x, holding WeightType elements.
85 auto w_ = make_static_distributed_tensor<WeightType>(x.get_tile_distribution());
 // Per-element step: convert to WeightType, then mask out-of-range expert
 // columns (column index >= experts) to -inf. In sigmoid mode, also apply
 // 1 / (1 + exp(-v)).
86 auto w_f = [&](auto idx) {
87 w_(idx) = type_convert<WeightType>(x(idx));
88 const auto x_indices =
89 get_x_indices_from_distributed_indices(w_.get_tile_distribution(), idx);
90 const auto current_expert = x_indices.at(number<1>{});
91 w_(idx) =
92 current_expert >= experts ? -numeric<WeightType>::infinity() : w_(idx);
93 if constexpr(!Problem::ActivationIsSoftmax)
94 {
95 // sigmoid can be pre-computed already here if not using softmax
 // NOTE(review): a masked -inf becomes 1/(1+exp(+inf)) = 0 here, not -inf,
 // so masked experts still compete in topk with value 0. This is fine while
 // valid outputs stay > 0; confirm the intent if k can exceed `experts`.
96 w_(idx) = WeightType(1) / (WeightType(1) + exp(-w_(idx)));
97 }
98 };
 // Presumably visits every distributed element index of w_ and calls w_f.
99 tile_sweeper ts{w_, w_f};
100 ts();
101 return w_;
102#endif
103 }();
104
105 if constexpr(Problem::ActivationIsSoftmax)
106 {
107 auto y = softmax(w);
108 topk(y, out_win, idx_win, k);
109 }
110 else
111 {
112 // sigmoid was already pre-computed above, so only do topk now
113 topk(w, out_win, idx_win, k);
114 }
115
116 // check exit
 // LaunchType == 0: a single tile per block. Otherwise advance by the
 // grid-wide row stride and stop once the next start row is past `rows`.
117 if constexpr(Problem::LaunchType == 0)
118 {
119 break;
120 }
121 else
122 {
123 block_row_id += grid_rows_per_loop;
124 if(block_row_id >= rows)
125 break;
126 }
127
 // Shift all three windows down by the same row stride; the column offset stays 0.
128 move_tile_window(inp_win, {grid_rows_per_loop, number<0>{}});
129 move_tile_window(out_win, {grid_rows_per_loop, number<0>{}});
130 move_tile_window(idx_win, {grid_rows_per_loop, number<0>{}});
131 }
132 }
133};
134} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition static_distributed_tensor.hpp:159
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_DEVICE auto make_tile_window_linear_raw(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition tile_window_linear.hpp:1029
CK_TILE_DEVICE constexpr auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition tile_window_linear.hpp:993
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition bfloat16.hpp:419
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE auto load_tile_raw(T &tile, const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Loads a tile of data using inline assembly.
Definition load_tile.hpp:81
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_DEVICE void buffer_load_fence(index_t cnt=0)
Definition tile/core/arch/amd_buffer_addressing.hpp:815
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition topk_softmax_warp_per_row_pipeline.hpp:19
remove_cvref_t< Problem_ > Problem
Definition topk_softmax_warp_per_row_pipeline.hpp:21
CK_TILE_DEVICE auto operator()(const InputWindow &input_window, OutputWindow &out_window, IndexWindow &idx_window, index_t rows, index_t experts, index_t k, index_t block_row_id)
Definition topk_softmax_warp_per_row_pipeline.hpp:26
typename Problem::WeightType WeightType
Definition topk_softmax_warp_per_row_pipeline.hpp:23
remove_cvref_t< Policy_ > Policy
Definition topk_softmax_warp_per_row_pipeline.hpp:22
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition tile/core/container/sequence.hpp:49
Definition sweep_tile.hpp:260