block_softmax_2d.hpp Source File

block_softmax_2d.hpp Source File

Composable Kernel: block_softmax_2d.hpp Source File
block_softmax_2d.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9#define _BLOCK_SOFTMAX_USE_UNPACK2 0
10
11namespace ck_tile {
12
13/*
14simple 2d softmax implementation, along row (dim=1)
15requirement:
16 1). each row is within a warp
17 2). data type must be a dword
18*/
19template <typename Problem_, typename Policy_ = void>
21{
24
25 using DataType = typename Problem::DataType;
26
27 template <typename DistributedTensor, index_t dim = 1>
29 operator()(const DistributedTensor& x, DistributedTensor& y, number<dim> = {})
30 {
31 const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
32 const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
33#if _BLOCK_SOFTMAX_USE_UNPACK2
34 const auto f_max3 = [](auto e0, auto e1, auto e2) {
35 float rtn;
36 asm volatile("v_max3_f32 %0, %1, %2, %3" : "=v"(rtn) : "v"(e0), "v"(e1), "v"(e2));
37 return rtn;
38 };
39 const auto f_sum3 = [](auto e0, auto e1, auto e2) { return e0 + e1 + e2; };
40#endif
41
42 // compute row max
43 auto reduce_row_max = BlockReduce2D{x, -numeric<DataType>::infinity()};
44#if _BLOCK_SOFTMAX_USE_UNPACK2
45 auto row_max = reduce_row_max(f_max3, f_max, sequence<1, 2>{});
46#else
47 auto row_max = reduce_row_max(f_max);
48#endif
49 sweep_tile<DistributedTensor>([&](auto idx) {
50 constexpr auto row_id = make_tuple(idx[number<0>{}]);
51 y(idx) = exp(x[idx] - row_max[row_id]);
52 });
53
54 // compute row sum
55 auto reduce_row_sum = BlockReduce2D<decltype(y)>{y, DataType{0}};
56#if _BLOCK_SOFTMAX_USE_UNPACK2
57 auto row_sum = reduce_row_sum(f_sum3, f_sum, sequence<1, 2>{});
58#else
59 auto row_sum = reduce_row_sum(f_sum);
60#endif
61 // reciprocal
62 auto r = make_static_distributed_tensor<DataType>(row_sum.get_tile_distribution());
63 sweep_tile(row_sum, [&](auto idx) { r(idx) = DataType{1} / row_sum(idx); });
64
65 // scale
66 sweep_tile<DistributedTensor>([&](auto idx) {
67 constexpr auto row_id = make_tuple(idx[number<0>{}]);
68 y(idx) = y(idx) * r(row_id);
69 });
70 }
71
72 template <typename DistributedTensor, index_t dim = 1>
73 CK_TILE_DEVICE decltype(auto) operator()(const DistributedTensor& x, number<dim> = {})
74 {
75 auto y = DistributedTensor{}; // distributed tensor
76 operator()(x, y, number<dim>{});
77 return y;
78 }
79};
80
81} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
ck_tile
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr void sweep_tile(const F &f, UnpacksPerXDim={})
Definition sweep_tile.hpp:231
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition bfloat16.hpp:419
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T &, const typename T::DataType &) -> BlockReduce2D< T >
ck_tile::BlockSoftmax2D
Definition block_softmax_2d.hpp:21
remove_cvref_t< Problem_ > Problem
Definition block_softmax_2d.hpp:22
CK_TILE_DEVICE void operator()(const DistributedTensor &x, DistributedTensor &y, number< dim >={})
Definition block_softmax_2d.hpp:29
remove_cvref_t< Policy_ > Policy
Definition block_softmax_2d.hpp:23
typename Problem::DataType DataType
Definition block_softmax_2d.hpp:25
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38