math.hpp Source File

math.hpp Source File

Composable Kernel: math.hpp Source File
utility/math.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck/ck.hpp"
8#include "number.hpp"
9#include "type.hpp"
10#include "enable_if.hpp"
11
12namespace ck {
13namespace math {
14
15template <typename T, T s>
16struct scales
17{
18 __host__ __device__ constexpr T operator()(T a) const { return s * a; }
19};
20
21template <typename T>
22struct plus
23{
24 __host__ __device__ constexpr T operator()(T a, T b) const { return a + b; }
25};
26
27template <typename T>
28struct minus
29{
30 __host__ __device__ constexpr T operator()(T a, T b) const { return a - b; }
31};
32
34{
35 template <typename A, typename B>
36 __host__ __device__ constexpr auto operator()(const A& a, const B& b) const
37 {
38 return a * b;
39 }
40};
41
42template <typename T>
44{
45 __host__ __device__ constexpr T operator()(T a, T b) const { return a >= b ? a : b; }
46};
47
48template <typename T>
50{
51 __host__ __device__ constexpr T operator()(T a, T b) const { return a <= b ? a : b; }
52};
53
54template <typename T>
56{
57 __host__ __device__ constexpr T operator()(T a, T b) const
58 {
59 static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");
60
61 return (a + b - Number<1>{}) / b;
62 }
63};
64
65template <typename X, typename Y>
66__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
67{
68 return x / y;
69}
70
71template <typename X, typename Y>
72__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
73{
74 return (x + y - Number<1>{}) / y;
75}
76
77template <typename X, typename Y>
78__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
79{
80 return y * integer_divide_ceil(x, y);
81}
82
83template <typename T>
84__host__ __device__ constexpr T max(T x)
85{
86 return x;
87}
88
89template <typename T>
90__host__ __device__ constexpr T max(T x, T y)
91{
92 return x > y ? x : y;
93}
94
95template <index_t X>
96__host__ __device__ constexpr index_t max(Number<X>, index_t y)
97{
98 return X > y ? X : y;
99}
100
101template <index_t Y>
102__host__ __device__ constexpr index_t max(index_t x, Number<Y>)
103{
104 return x > Y ? x : Y;
105}
106
107template <typename X, typename... Ys>
108__host__ __device__ constexpr auto max(X x, Ys... ys)
109{
110 static_assert(sizeof...(Ys) > 0, "not enough argument");
111
112 return max(x, max(ys...));
113}
114
115template <typename T>
116__host__ __device__ constexpr T min(T x)
117{
118 return x;
119}
120
121template <typename T>
122__host__ __device__ constexpr T min(T x, T y)
123{
124 return x < y ? x : y;
125}
126
127template <index_t X>
128__host__ __device__ constexpr index_t min(Number<X>, index_t y)
129{
130 return X < y ? X : y;
131}
132
133template <index_t Y>
134__host__ __device__ constexpr index_t min(index_t x, Number<Y>)
135{
136 return x < Y ? x : Y;
137}
138
139template <typename X, typename... Ys>
140__host__ __device__ constexpr auto min(X x, Ys... ys)
141{
142 static_assert(sizeof...(Ys) > 0, "not enough argument");
143
144 return min(x, min(ys...));
145}
146
147template <typename T>
148__host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T& upperbound)
149{
150 return min(max(x, lowerbound), upperbound);
151}
152
153// greatest common divisor, aka highest common factor
154__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
155{
156 if(x < 0)
157 {
158 return gcd(-x, y);
159 }
160 else if(y < 0)
161 {
162 return gcd(x, -y);
163 }
164 else if(x == y || x == 0)
165 {
166 return y;
167 }
168 else if(y == 0)
169 {
170 return x;
171 }
172 else if(x > y)
173 {
174 return gcd(x % y, y);
175 }
176 else
177 {
178 return gcd(x, y % x);
179 }
180}
181
182template <index_t X, index_t Y>
183__host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
184{
185 constexpr auto r = gcd(X, Y);
186
187 return Number<r>{};
188}
189
190template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
191__host__ __device__ constexpr auto gcd(X x, Ys... ys)
192{
193 return gcd(x, gcd(ys...));
194}
195
196// least common multiple
197template <typename X, typename Y>
198__host__ __device__ constexpr auto lcm(X x, Y y)
199{
200 return (x * y) / gcd(x, y);
201}
202
203template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
204__host__ __device__ constexpr auto lcm(X x, Ys... ys)
205{
206 return lcm(x, lcm(ys...));
207}
208
209template <typename T>
210struct equal
211{
212 __host__ __device__ constexpr bool operator()(T x, T y) const { return x == y; }
213};
214
215template <typename T>
216struct less
217{
218 __host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; }
219};
220
221template <index_t X>
222__host__ __device__ constexpr auto next_power_of_two()
223{
224 // TODO: X need to be 2 ~ 0x7fffffff. 0, 1, or larger than 0x7fffffff will compile fail
225 constexpr index_t Y = 1 << (32 - __builtin_clz(X - 1));
226 return Y;
227}
228
229template <index_t X>
230__host__ __device__ constexpr auto next_power_of_two(Number<X> x)
231{
232 // TODO: X need to be 2 ~ 0x7fffffff. 0, 1, or larger than 0x7fffffff will compile fail
233 constexpr index_t Y = 1 << (32 - __builtin_clz(x.value - 1));
234 return Number<Y>{};
235}
236
237} // namespace math
238} // namespace ck
Definition utility/math.hpp:13
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
Definition utility/math.hpp:154
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto next_power_of_two()
Definition utility/math.hpp:222
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
__host__ __device__ constexpr T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition utility/math.hpp:148
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
std::enable_if< B, T > enable_if
Definition enable_if.hpp:24
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
static constexpr T value
Definition utility/integral_constant.hpp:21
Definition type.hpp:177
Definition utility/math.hpp:211
__host__ __device__ constexpr bool operator()(T x, T y) const
Definition utility/math.hpp:212
Definition utility/math.hpp:56
__host__ __device__ constexpr T operator()(T a, T b) const
Definition utility/math.hpp:57
Definition utility/math.hpp:217
__host__ __device__ constexpr bool operator()(T x, T y) const
Definition utility/math.hpp:218
Definition utility/math.hpp:44
__host__ __device__ constexpr T operator()(T a, T b) const
Definition utility/math.hpp:45
Definition utility/math.hpp:50
__host__ __device__ constexpr T operator()(T a, T b) const
Definition utility/math.hpp:51
Definition utility/math.hpp:29
__host__ __device__ constexpr T operator()(T a, T b) const
Definition utility/math.hpp:30
Definition utility/math.hpp:34
__host__ __device__ constexpr auto operator()(const A &a, const B &b) const
Definition utility/math.hpp:36
Definition utility/math.hpp:23
__host__ __device__ constexpr T operator()(T a, T b) const
Definition utility/math.hpp:24
Definition utility/math.hpp:17
__host__ __device__ constexpr T operator()(T a) const
Definition utility/math.hpp:18