device_normalization_bwd_gamma_beta.hpp Source File

device_normalization_bwd_gamma_beta.hpp Source File

Composable Kernel: device_normalization_bwd_gamma_beta.hpp Source File
device_normalization_bwd_gamma_beta.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
#include <iostream>
#include <memory>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_base.hpp"

11namespace ck {
12namespace tensor_operation {
13namespace device {
14template <typename DYDataType,
15 typename XDataType,
16 typename MeanInvStdDataType,
17 typename DGammaDataType,
18 typename DBetaDataType,
19 index_t Rank,
20 index_t NumReduceDim>
22{
23 virtual std::unique_ptr<BaseArgument>
24 MakeArgumentPointer(const std::vector<index_t> inLengths,
25 const std::vector<index_t> dyStrides,
26 const std::vector<index_t> xStrides,
27 const std::vector<index_t> meanStrides,
28 const std::vector<index_t> invStdStrides,
29 const std::vector<index_t> outLengths,
30 const std::vector<index_t> dgammaStrides,
31 const std::vector<index_t> dbetaStrides,
32 const std::vector<index_t> reduceDims,
33 const void* p_dy,
34 const void* p_x,
35 const void* p_mean,
36 const void* p_invStd,
37 void* p_dgamma,
38 void* p_dbeta) = 0;
39
40 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
41};
42
43template <typename DYDataType,
44 typename XDataType,
45 typename MeanInvStdDataType,
46 typename DGammaDataType,
47 typename DBetaDataType,
48 index_t Rank,
49 index_t NumReduceDim>
51 std::unique_ptr<DeviceNormalizationBwdGammaBeta<DYDataType,
52 XDataType,
53 MeanInvStdDataType,
54 DGammaDataType,
55 DBetaDataType,
56 Rank,
57 NumReduceDim>>;
58
59} // namespace device
60} // namespace tensor_operation
61} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
std::unique_ptr< DeviceNormalizationBwdGammaBeta< DYDataType, XDataType, MeanInvStdDataType, DGammaDataType, DBetaDataType, Rank, NumReduceDim > > DeviceNormalizationBwdGammaBetaPtr
Definition device_normalization_bwd_gamma_beta.hpp:50
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_normalization_bwd_gamma_beta.hpp:22
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::vector< index_t > inLengths, const std::vector< index_t > dyStrides, const std::vector< index_t > xStrides, const std::vector< index_t > meanStrides, const std::vector< index_t > invStdStrides, const std::vector< index_t > outLengths, const std::vector< index_t > dgammaStrides, const std::vector< index_t > dbetaStrides, const std::vector< index_t > reduceDims, const void *p_dy, const void *p_x, const void *p_mean, const void *p_invStd, void *p_dgamma, void *p_dbeta)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0