add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp Source File

add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp Source File

Composable Kernel: add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp Source File
add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck_tile {
9
10// X = A + B, Y = Rmsnorm2d(X), QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale)
11template <typename ADataType_,
12 typename BDataType_,
13 typename GammaDataType_,
14 typename ComputeDataType_,
15 typename XDataType_,
16 typename YScaleDataType_,
17 typename QYDataType_,
18 typename BlockShape_,
19 bool kPadN_,
20 bool kSaveX_,
21 bool kThreePass_>
23{
32
33 static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
34 static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
35
36 static constexpr bool kPadN = kPadN_;
37 static constexpr bool kSaveX = kSaveX_;
38 static constexpr bool kThreePass = kThreePass_;
39};
40
41} // namespace ck_tile
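namespace ck_tile (link text lost in extraction; presumably the namespace opened at listing line 8 — verify) — Definition tile/core/algorithm/cluster_descriptor.hpp:13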
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
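(class template whose body opens at listing line 23, right after the template parameter list; its name was lost in extraction) — Definition add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:23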
static constexpr bool kPadN
Definition add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:36
remove_cvref_t< ADataType_ > ADataType
Definition add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:24
static constexpr bool kNeedCrossLaneSync
Definition add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:33
static constexpr bool kNeedCrossWarpSync
Definition add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:34
static constexpr bool kThreePass
Definition add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:38
remove_cvref_t< QYDataType_ > QYDataType
Definition add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:30
remove_cvref_t< ComputeDataType_ > ComputeDataType
Definition add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:27
static constexpr bool kSaveX
Definition add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:37
remove_cvref_t< GammaDataType_ > GammaDataType
Definition add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:26
remove_cvref_t< BlockShape_ > BlockShape
Definition add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:31
remove_cvref_t< BDataType_ > BDataType
Definition add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:25
remove_cvref_t< XDataType_ > XDataType
Definition add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:28
remove_cvref_t< YScaleDataType_ > YScaleDataType
Definition add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:29