reduction_functions_accumulate.hpp Source File

reduction_functions_accumulate.hpp Source File#

Composable Kernel: reduction_functions_accumulate.hpp Source File
reduction_functions_accumulate.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
10
11namespace ck {
12namespace detail {
13
14// Check for NaN; guarantee NaNs are NOT propagated to result (i.e., ignore NaNs)
15template <typename ReduceOperation, typename AccDataType>
17{
18 __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
19 {
20 if(!ck::math::isnan(currVal))
21 {
22 ReduceOperation{}(accuVal, currVal);
23 }
24 };
25};
26
// Primary template, declared but not defined here: only the
// PropagateNan == false and PropagateNan == true specializations that follow
// provide Calculate().
template <bool PropagateNan, typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck;

30// Does not check for NaN; does not guarantee NaNs be propagated to result
31// e.g., given that max(a, b) = a > b ? a : b
32// then max(NaN, 1) returns 1
33// max(1, NaN) returns NaN
34// since any comparison involving NaNs returns false
35template <typename ReduceOperation, typename AccDataType>
36struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
37{
38 // cppcheck-suppress constParameter
39 __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
40 {
41 ReduceOperation{}(accuVal, currVal);
42 };
43};
44
45// Check for NaN; guarantees NaNs be propagated to result
46template <typename ReduceOperation, typename AccDataType>
47struct AccumulateWithNanCheck<true, ReduceOperation, AccDataType>
48{
49 __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
50 {
51 using ck::math::isnan;
52
53 if(isnan(currVal))
54 {
55 accuVal = currVal;
56 }
57 else
58 {
59 ReduceOperation{}(accuVal, currVal);
60 };
61 };
62};
63
// Primary template, declared but not defined here: only the
// PropagateNan == false and PropagateNan == true specializations that follow
// provide Calculate().
template <bool PropagateNan, typename ReduceOperation, typename AccDataType, typename IndexDataType>
struct AccumulateWithIndexAndNanCheck;

67template <typename ReduceOperation, typename AccDataType, typename IndexDataType>
68struct AccumulateWithIndexAndNanCheck<false, ReduceOperation, AccDataType, IndexDataType>
69{
70 __host__ __device__ static inline void
71 // cppcheck-suppress constParameter
72 Calculate(AccDataType& accuVal,
73 AccDataType currVal,
74 IndexDataType& accuIndex,
75 IndexDataType currIndex)
76 {
77 bool changed = false;
78
79 ReduceOperation{}(accuVal, currVal, changed);
80
81 if(changed)
82 accuIndex = currIndex;
83 };
84};
85
86template <typename ReduceOperation, typename AccDataType, typename IndexDataType>
87struct AccumulateWithIndexAndNanCheck<true, ReduceOperation, AccDataType, IndexDataType>
88{
89 // The method is called when the ReduceOperation is indexable and the user asked for indices
90 __host__ __device__ static inline void Calculate(AccDataType& accuVal,
91 AccDataType currVal,
92 IndexDataType& accuIndex,
93 IndexDataType currIndex)
94 {
95 using ck::math::isnan;
96
97 if(isnan(currVal))
98 {
99 accuVal = currVal;
100 accuIndex = currIndex;
101 }
102 else
103 {
104 bool changed = false;
105
106 ReduceOperation{}(accuVal, currVal, changed);
107
108 if(changed)
109 accuIndex = currIndex;
110 }
111 };
112};
113
114} // namespace detail
115} // namespace ck
Definition threadwise_tensor_slice_transfer_util.hpp:15
Definition ck.hpp:268
__host__ static __device__ void Calculate(AccDataType &accuVal, AccDataType currVal, IndexDataType &accuIndex, IndexDataType currIndex)
Definition reduction_functions_accumulate.hpp:72
__host__ static __device__ void Calculate(AccDataType &accuVal, AccDataType currVal, IndexDataType &accuIndex, IndexDataType currIndex)
Definition reduction_functions_accumulate.hpp:90
Definition reduction_functions_accumulate.hpp:65
__host__ static __device__ void Calculate(AccDataType &accuVal, AccDataType currVal)
Definition reduction_functions_accumulate.hpp:39
__host__ static __device__ void Calculate(AccDataType &accuVal, AccDataType currVal)
Definition reduction_functions_accumulate.hpp:49
Definition reduction_functions_accumulate.hpp:28
Definition reduction_functions_accumulate.hpp:17
static __device__ void Calculate(AccDataType &accuVal, AccDataType currVal)
Definition reduction_functions_accumulate.hpp:18