device_contraction_utils.hpp Source File

device_contraction_utils.hpp Source File

Composable Kernel: device_contraction_utils.hpp Source File
device_contraction_utils.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <sstream>
#include <stdexcept>
#include <vector>
9
10#include "ck/ck.hpp"
11
12namespace ck {
13namespace tensor_operation {
14namespace device {
15
32template <index_t NumDim1, index_t NumDim2>
33auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<index_t>& strides)
34{
35 if(lengths.size() != NumDim1 + NumDim2)
36 {
37 std::ostringstream err;
38 err << "Incorrect number of lengths in " << "device_contraction_utils.hpp" << ":"
39 << __LINE__ << ", in function: " << __func__;
40 throw std::runtime_error(err.str());
41 }
42 if(strides.size() != NumDim1 + NumDim2)
43 {
44 std::ostringstream err;
45 err << "Incorrect number of strides in " << "device_contraction_utils.hpp" << ":"
46 << __LINE__ << ", in function: " << __func__;
47 throw std::runtime_error(err.str());
48 }
49
50 // Determine the beginning and end idx of the group representing the FCD.
51 index_t begin_idx, end_idx, continous_dim, consecutive_stride = 1;
52 if(strides[NumDim1 - 1] == 1 && strides[NumDim1 + NumDim2 - 1] == 1)
53 {
54 // MZ or KZ are ones
55 bool dims1_are_ones = true;
56 for(index_t dim_idx = 0; dim_idx < NumDim1; dim_idx++)
57 {
58 if(lengths[dim_idx] != 1)
59 {
60 dims1_are_ones = false;
61 }
62 }
63
64 if(dims1_are_ones)
65 {
66 begin_idx = NumDim1;
67 end_idx = NumDim1 + NumDim2 - 1;
68 continous_dim = 1;
69 }
70 else
71 {
72 begin_idx = 0;
73 end_idx = NumDim1 - 1;
74 continous_dim = 0;
75 }
76 }
77 else if(strides[NumDim1 - 1] == 1)
78 {
79 begin_idx = 0;
80 end_idx = NumDim1 - 1;
81 continous_dim = 0;
82 }
83 else if(strides[NumDim1 + NumDim2 - 1] == 1)
84 {
85 begin_idx = NumDim1;
86 end_idx = NumDim1 + NumDim2 - 1;
87 continous_dim = 1;
88 }
89 else
90 {
91 // The dimension consecutive in memory is not the last dimension of any group, so only
92 // one element can be read/written at once.
93 consecutive_stride = 1;
94 continous_dim = 0;
95 return make_tuple(continous_dim, consecutive_stride);
96 }
97
98 for(index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx)
99 {
100 if(strides[dim_idx] == consecutive_stride)
101 {
102 consecutive_stride *= lengths[dim_idx];
103 }
104 else
105 {
106 break;
107 }
108 }
109 const index_t max_subsequent_elems = consecutive_stride;
110 return make_tuple(continous_dim, max_subsequent_elems);
111}
112
113} // namespace device
114} // namespace tensor_operation
115} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
auto CalculateMaxRead(const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
Definition device_contraction_utils.hpp:33
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211