device_base.hpp Source File

device_base.hpp Source File#

Composable Kernel: device_base.hpp Source File
device_base.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
7#include <string>
8#include <sstream>
9#include <regex>
10#include <optional>
11
12#include "ck/stream_config.hpp"
13#endif
14#include "ck/utility/get_id.hpp"
15
16namespace ck {
17namespace tensor_operation {
18namespace device {
19
20#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
// Expands to an override of GetObjectName() inside the class that uses it.
//
// The generated method searches __PRETTY_FUNCTION__ for the pattern
// "<std::string> (.*)::GetObjectName" and returns the captured text with a
// trailing ';' appended. When the pattern does not occur, the whole
// __PRETTY_FUNCTION__ string is returned unchanged.
//
// NOTE: the exact text of __PRETTY_FUNCTION__ is compiler-specific, so which
// of the two results is produced depends on the toolchain.
// NOTE: the macro name really is spelled "IMLP"; REGISTER_EXTRA_PRINTING_METHODS
// refers to it by that spelling.
#define GET_OBJECT_NAME_IMLP \
    std::optional<std::string> GetObjectName() const override \
    { \
        static std::regex obj_name_expr{"<std::string> (.*)::GetObjectName"}; \
        std::string signature = __PRETTY_FUNCTION__; \
        std::smatch captured; \
        if(std::regex_search(signature, captured, obj_name_expr)) \
        { \
            return std::string(captured[1]) + ';'; \
        } \
        return signature; \
    }
33
// Expands to an override of GetTemplateInfo() inside the class that uses it.
//
// The generated method searches __PRETTY_FUNCTION__ for a bracketed section
// (regex "\[(.*)\]") and returns the text between the outermost '[' and ']'.
// When no bracketed section exists the method returns std::nullopt.
//
// NOTE: whether __PRETTY_FUNCTION__ contains such a section is up to the
// compiler; presumably it is the "[T = ...]" template-argument list.
#define GET_TEMPLATE_INFO_IMPL \
    std::optional<std::string> GetTemplateInfo() const override \
    { \
        static std::regex template_expr{"\\[(.*)\\]"}; \
        std::string signature = __PRETTY_FUNCTION__; \
        std::smatch captured; \
        if(std::regex_search(signature, captured, template_expr)) \
        { \
            return std::string(captured[1]); \
        } \
        return std::nullopt; \
    }

// Convenience macro: pastes both printing overrides (GetObjectName and
// GetTemplateInfo) into a class in one go.
#define REGISTER_EXTRA_PRINTING_METHODS GET_OBJECT_NAME_IMLP GET_TEMPLATE_INFO_IMPL
48#endif
49
50template <index_t BlockSize_,
51 index_t MPerBlock_,
52 index_t NPerBlock_,
53 index_t MPerXDL_,
54 index_t NPerXDL_,
55 index_t MXdlPerWave_,
56 bool IsWave64>
57static constexpr auto GetNXdlPerWave2()
58{
59 constexpr index_t Waves = IsWave64 ? BlockSize_ / 64 : BlockSize_ / 32;
60 constexpr index_t MWaves = MPerBlock_ / (MXdlPerWave_ * MPerXDL_);
61 static_assert(MWaves > 0);
62
63 constexpr index_t NWaves = Waves / MWaves;
64 if constexpr(NWaves == 0)
65 {
66 return 0;
67 }
68 else
69 {
70 if constexpr(NPerBlock_ % (NPerXDL_ * NWaves) == 0)
71 {
72 return NPerBlock_ / (NWaves * NPerXDL_);
73 }
74 else
75 {
76 return 0;
77 }
78 }
79}
80
// Expands to a static member template GetNXdlPerWave<IsWave64>() that forwards
// to GetNXdlPerWave2 with the enclosing class's tuning parameters.
//
// The class using this macro must have the names BlockSize, MPerBlock,
// NPerBlock, MPerXDL, NPerXDL and MXdlPerWave in scope; only IsWave64 is
// supplied by the caller.
#define GET_NXDL_PER_WAVE_IMPL \
    template <bool IsWave64> \
    static constexpr auto GetNXdlPerWave() \
    { \
        return GetNXdlPerWave2<BlockSize, MPerBlock, NPerBlock, \
                               MPerXDL, NPerXDL, MXdlPerWave, \
                               IsWave64>(); \
    }
93
// Expands to an invoker Run(arg, stream_config) method that dispatches on the
// value returned by get_warp_size():
//   * 64        -> RunImp<GridwiseGemm64>, compiled in only if NXdlPerWave64 > 0
//   * otherwise -> RunImp<GridwiseGemm32>, compiled in only if NXdlPerWave32 > 0
// When the selected path is compiled out, Run falls through and returns 0.
//
// The class using this macro must provide Argument, RunImp, GridwiseGemm64,
// GridwiseGemm32, NXdlPerWave64 and NXdlPerWave32.
#define INVOKER_RUN_IMPL \
    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \
    { \
        if(get_warp_size() != 64) \
        { \
            if constexpr(NXdlPerWave32 > 0) \
            { \
                return RunImp<GridwiseGemm32>(arg, stream_config); \
            } \
        } \
        else \
        { \
            if constexpr(NXdlPerWave64 > 0) \
            { \
                return RunImp<GridwiseGemm64>(arg, stream_config); \
            } \
        } \
        return 0; \
    }
113
// Variant of INVOKER_RUN_IMPL. It expands to a Run(arg, stream_config) method
// that dispatches on get_warp_size():
//   * 64        -> RunImp<GridwiseGemm64>(arg, ...), only if NXdlPerWave64 > 0
//   * otherwise -> RunImp<GridwiseGemm32>(...), only if NXdlPerWave32 > 0; here
//                  `arg` is first reinterpreted as
//                  `const GridwiseGemm32::Argument&`.
// When the selected path is compiled out, Run falls through and returns 0.
//
// NOTE(review): the reinterpret_cast presumes Argument and
// GridwiseGemm32::Argument are layout-compatible; this macro does not check it.
#define INVOKER_RUN3_IMPL \
    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \
    { \
        if(get_warp_size() != 64) \
        { \
            if constexpr(NXdlPerWave32 > 0) \
            { \
                const auto& arg32 = \
                    reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg); \
                return RunImp<GridwiseGemm32>(arg32, stream_config); \
            } \
        } \
        else \
        { \
            if constexpr(NXdlPerWave64 > 0) \
            { \
                return RunImp<GridwiseGemm64>(arg, stream_config); \
            } \
        } \
        return 0; \
    }
135
136template <index_t BlockSize,
137 index_t MPerBlock,
138 index_t NPerBlock,
139 index_t MPerXdl,
140 index_t NPerXdl,
141 index_t MXdlPerWave,
142 index_t NXdlPerWave,
143 typename CDataType,
144 InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
145__device__ static bool constexpr IsValidGemmCompilationParameter()
146{
147#if defined(__gfx11__) || defined(__gfx12__)
148 if constexpr(MPerXdl != 16 || NPerXdl != 16)
149 {
150 return false;
151 }
152#endif
153
154#if defined(__gfx11__)
155 constexpr bool SupportMemOp = CGlobalMemoryDataOperation_ == InMemoryDataOperationEnum::Set;
156#else
157 constexpr bool SupportMemOp =
158 sizeof(CDataType) >= 2 || (CGlobalMemoryDataOperation_ == InMemoryDataOperationEnum::Set);
159#endif
160 if constexpr(SupportMemOp == false)
161 {
162 return false;
163 }
164
165 if constexpr(MXdlPerWave > 0 && NXdlPerWave > 0)
166 {
167 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
168 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
169 if constexpr(MWaves > 0 && NWaves > 0)
170 {
171 constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
172 return WaveSize == get_warp_size();
173 }
174 }
175 return false;
176}
177
// Expands to a static member template IsValidCompilationParameter<MemOp>()
// that forwards to ck::tensor_operation::device::IsValidGemmCompilationParameter
// using the enclosing class's tuning parameters.
//
// CDataType_ (macro argument) is passed through as the C data type. The class
// using this macro must have BlockSize, MPerBlock, NPerBlock, MPerXdl, NPerXdl,
// MXdlPerWave and NXdlPerWave in scope. The memory-operation template argument
// defaults to InMemoryDataOperationEnum::Set.
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_) \
    template <InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = \
                  InMemoryDataOperationEnum::Set> \
    __device__ static bool constexpr IsValidCompilationParameter() \
    { \
        return ck::tensor_operation::device::IsValidGemmCompilationParameter< \
            BlockSize, MPerBlock, NPerBlock, \
            MPerXdl, NPerXdl, \
            MXdlPerWave, NXdlPerWave, \
            CDataType_, CGlobalMemoryDataOperation_>(); \
    }
194
195#ifndef CK_CODE_GEN_RTC
// Polymorphic base for operator argument objects.
//
// Carries only a workspace pointer, which BaseOperator::SetWorkSpacePointer
// assigns. The struct does not allocate or free the pointed-to memory.
//
// The declaration header and the copy-assignment operator were absent from
// this listing; both are restored to match the member index of the same page.
struct BaseArgument
{
    BaseArgument()                               = default;
    BaseArgument(const BaseArgument&)            = default;
    BaseArgument& operator=(const BaseArgument&) = default;

    virtual ~BaseArgument() {}

    // Caller-provided scratch buffer; null until explicitly set.
    void* p_workspace_ = nullptr;
};
206
208{
209 BaseInvoker() = default;
210 BaseInvoker(const BaseInvoker&) = default;
212
213 virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{})
214 {
215 return float{0};
216 }
217
218 virtual ~BaseInvoker() {}
219};
220#endif
221
223{
224 BaseOperator() = default;
225 BaseOperator(const BaseOperator&) = default;
227#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
228 virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
229 virtual std::string GetTypeString() const { return ""; }
230 virtual std::string GetInstanceString() const { return ""; }
231
232 virtual std::string GetTypeIdName() const { return typeid(*this).name(); }
233
234 virtual std::optional<std::string> GetObjectName() const { return std::nullopt; }
235
236 virtual std::optional<std::string> GetTemplateInfo() const { return std::nullopt; }
237
238 virtual std::string GetTypeIdHashCode() const
239 {
240 std::ostringstream oss;
241
242 oss << std::hex << typeid(*this).hash_code();
243
244 return oss.str();
245 };
246
247 virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
248
250 void* p_workspace,
251 const StreamConfig& = StreamConfig{}) const
252 {
253 assert(p_arg);
254 p_arg->p_workspace_ = p_workspace;
255 }
256#endif
257 virtual ~BaseOperator() {}
258};
259
260} // namespace device
261} // namespace tensor_operation
262} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
Definition ck/stream_config.hpp:10
Definition device_base.hpp:197
BaseArgument(const BaseArgument &)=default
virtual ~BaseArgument()
Definition device_base.hpp:202
void * p_workspace_
Definition device_base.hpp:204
BaseArgument & operator=(const BaseArgument &)=default
BaseInvoker & operator=(const BaseInvoker &)=default
virtual ~BaseInvoker()
Definition device_base.hpp:218
virtual float Run(const BaseArgument *, const StreamConfig &=StreamConfig{})
Definition device_base.hpp:213
BaseInvoker(const BaseInvoker &)=default
virtual void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const
Definition device_base.hpp:249
virtual std::string GetInstanceString() const
Definition device_base.hpp:230
virtual std::optional< std::string > GetTemplateInfo() const
Definition device_base.hpp:236
virtual bool IsSupportedArgument(const BaseArgument *)
Definition device_base.hpp:228
virtual size_t GetWorkSpaceSize(const BaseArgument *) const
Definition device_base.hpp:247
virtual std::string GetTypeString() const
Definition device_base.hpp:229
BaseOperator(const BaseOperator &)=default
virtual std::string GetTypeIdHashCode() const
Definition device_base.hpp:238
BaseOperator & operator=(const BaseOperator &)=default
virtual std::optional< std::string > GetObjectName() const
Definition device_base.hpp:234
virtual std::string GetTypeIdName() const
Definition device_base.hpp:232
virtual ~BaseOperator()
Definition device_base.hpp:257