device_conv_fwd.hpp Source File

device_conv_fwd.hpp Source File#

Composable Kernel: device_conv_fwd.hpp Source File
device_conv_fwd.hpp
Go to the documentation of this file.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <memory>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_base.hpp"
10namespace ck {
11namespace tensor_operation {
12namespace device {
13
14template <ck::index_t NumDimSpatial,
15 typename InLayout,
16 typename WeiLayout,
17 typename OutLayout,
18 typename InDataType,
19 typename WeiDataType,
20 typename OutDataType,
21 typename InElementwiseOperation,
22 typename WeiElementwiseOperation,
23 typename OutElementwiseOperation>
25{
26 virtual std::unique_ptr<BaseArgument>
27 MakeArgumentPointer(const void* p_in,
28 const void* p_wei,
29 void* p_out,
33 std::vector<ck::index_t> input_spatial_lengths,
34 std::vector<ck::index_t> filter_spatial_lengths,
35 std::vector<ck::index_t> output_spatial_lengths,
36 std::vector<ck::index_t> conv_filter_strides,
37 std::vector<ck::index_t> conv_filter_dilations,
38 std::vector<ck::index_t> input_left_pads,
39 std::vector<ck::index_t> input_right_pads,
40 InElementwiseOperation in_element_op,
41 WeiElementwiseOperation wei_element_op,
42 OutElementwiseOperation out_element_op) = 0;
43
44 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
45};
46
47} // namespace device
48} // namespace tensor_operation
49} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_conv_fwd.hpp:25
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in, const void *p_wei, void *p_out, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op)=0