6#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
20#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
21#define GET_OBJECT_NAME_IMLP \
22 std::optional<std::string> GetObjectName() const override \
24 std::string str = __PRETTY_FUNCTION__; \
25 static std::regex obj_name_expr{"<std::string> (.*)::GetObjectName"}; \
27 if(!std::regex_search(str, match, obj_name_expr)) \
31 return std::string(match[1]) + ';'; \
34#define GET_TEMPLATE_INFO_IMPL \
35 std::optional<std::string> GetTemplateInfo() const override \
37 std::string str = __PRETTY_FUNCTION__; \
38 static std::regex template_expr{"\\[(.*)\\]"}; \
40 if(!std::regex_search(str, match, template_expr)) \
42 return std::nullopt; \
44 return std::string(match[1]); \
47#define REGISTER_EXTRA_PRINTING_METHODS GET_OBJECT_NAME_IMLP GET_TEMPLATE_INFO_IMPL
57static constexpr auto GetNXdlPerWave2()
59 constexpr index_t Waves = IsWave64 ? BlockSize_ / 64 : BlockSize_ / 32;
60 constexpr index_t MWaves = MPerBlock_ / (MXdlPerWave_ * MPerXDL_);
61 static_assert(MWaves > 0);
63 constexpr index_t NWaves = Waves / MWaves;
64 if constexpr(NWaves == 0)
70 if constexpr(NPerBlock_ % (NPerXDL_ * NWaves) == 0)
72 return NPerBlock_ / (NWaves * NPerXDL_);
81#define GET_NXDL_PER_WAVE_IMPL \
82 template <bool IsWave64> \
83 static constexpr auto GetNXdlPerWave() \
85 return GetNXdlPerWave2<BlockSize, \
94#define INVOKER_RUN_IMPL \
95 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \
97 if(get_warp_size() == 64) \
99 if constexpr(NXdlPerWave64 > 0) \
101 return RunImp<GridwiseGemm64>(arg, stream_config); \
106 if constexpr(NXdlPerWave32 > 0) \
108 return RunImp<GridwiseGemm32>(arg, stream_config); \
114#define INVOKER_RUN3_IMPL \
115 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \
117 if(get_warp_size() == 64) \
119 if constexpr(NXdlPerWave64 > 0) \
121 return RunImp<GridwiseGemm64>(arg, stream_config); \
126 if constexpr(NXdlPerWave32 > 0) \
128 return RunImp<GridwiseGemm32>( \
129 reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg), \
145__device__
static bool constexpr IsValidGemmCompilationParameter()
147#if defined(__gfx11__) || defined(__gfx12__)
148 if constexpr(MPerXdl != 16 || NPerXdl != 16)
154#if defined(__gfx11__)
157 constexpr bool SupportMemOp =
160 if constexpr(SupportMemOp ==
false)
165 if constexpr(MXdlPerWave > 0 && NXdlPerWave > 0)
167 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
168 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
169 if constexpr(MWaves > 0 && NWaves > 0)
171 constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
178#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_) \
179 template <InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = \
180 InMemoryDataOperationEnum::Set> \
181 __device__ static bool constexpr IsValidCompilationParameter() \
183 return ck::tensor_operation::device::IsValidGemmCompilationParameter< \
192 CGlobalMemoryDataOperation_>(); \
195#ifndef CK_CODE_GEN_RTC
227#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
// Default object-name hook: returns std::nullopt, i.e. no name is reported
// unless a derived type overrides this method. The GET_OBJECT_NAME_IMLP macro
// defined earlier in this file generates such an override, extracting the name
// from __PRETTY_FUNCTION__ with a regex.
234 virtual std::optional<std::string>
GetObjectName()
const {
return std::nullopt; }
240 std::ostringstream oss;
242 oss << std::hex <<
typeid(*this).hash_code();
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
Definition ck/stream_config.hpp:10
Definition device_base.hpp:197
BaseArgument(const BaseArgument &)=default
virtual ~BaseArgument()
Definition device_base.hpp:202
void * p_workspace_
Definition device_base.hpp:204
BaseArgument & operator=(const BaseArgument &)=default
BaseInvoker & operator=(const BaseInvoker &)=default
virtual ~BaseInvoker()
Definition device_base.hpp:218
virtual float Run(const BaseArgument *, const StreamConfig &=StreamConfig{})
Definition device_base.hpp:213
BaseInvoker(const BaseInvoker &)=default
virtual void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const
Definition device_base.hpp:249
virtual std::string GetInstanceString() const
Definition device_base.hpp:230
virtual std::optional< std::string > GetTemplateInfo() const
Definition device_base.hpp:236
virtual bool IsSupportedArgument(const BaseArgument *)
Definition device_base.hpp:228
virtual size_t GetWorkSpaceSize(const BaseArgument *) const
Definition device_base.hpp:247
virtual std::string GetTypeString() const
Definition device_base.hpp:229
BaseOperator(const BaseOperator &)=default
virtual std::string GetTypeIdHashCode() const
Definition device_base.hpp:238
BaseOperator & operator=(const BaseOperator &)=default
virtual std::optional< std::string > GetObjectName() const
Definition device_base.hpp:234
virtual std::string GetTypeIdName() const
Definition device_base.hpp:232
virtual ~BaseOperator()
Definition device_base.hpp:257