device_put_element_impl.hpp Source File

device_put_element_impl.hpp Source File

Composable Kernel: device_put_element_impl.hpp Source File
device_put_element_impl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
16
17namespace ck {
18namespace tensor_operation {
19namespace device {
20
21// output[indices] = input
22template <typename InDataType,
23 typename IndexDataType,
24 typename OutDataType,
25 typename ElementwiseOperation,
27 ck::index_t InVectorSize>
29 : public DevicePutElement<InDataType, IndexDataType, OutDataType, ElementwiseOperation, MemOp>
30{
31 template <typename Desc_M>
32 static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
33 {
34 constexpr auto I0 = Number<0>{};
35
36 const auto m = desc_m.GetLength(I0);
37 const index_t loop_step = gridSize * blockSize * InVectorSize;
38 const auto pad = math::integer_least_multiple(m, loop_step) - m;
39 const auto desc_m_pad =
44 return desc_m_pad;
45 }
46
47 static auto MakeDescriptor_M(index_t length, index_t gridSize, index_t blockSize)
48 {
49 const auto desc_m = make_naive_tensor_descriptor_packed(make_tuple(length));
50 return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
51 }
52
53 using InGrid1dDesc = decltype(MakeDescriptor_M(1, 1, 1));
54
56 InDataType,
57 IndexDataType,
58 OutDataType,
59 ElementwiseOperation,
60 MemOp,
61 InVectorSize>;
62
63 struct Argument : public BaseArgument
64 {
65 Argument(const InDataType* p_input,
66 const IndexDataType* p_indices,
67 OutDataType* p_output,
68 index_t input_length,
69 ElementwiseOperation elementwise_op)
70 : p_input_{p_input},
71 p_indices_{p_indices},
72 p_output_{p_output},
73 input_length_raw_{input_length},
74 elementwise_op_{elementwise_op},
75 blockSize_{256}
76 {
77 }
78
79 const InDataType* p_input_;
80 const IndexDataType* p_indices_;
81 OutDataType* p_output_;
83 ElementwiseOperation elementwise_op_;
85 };
86
87 struct Invoker : public BaseInvoker
88 {
89 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
90 {
91 index_t gridSize = getAvailableComputeUnitCount(stream_config);
92 InGrid1dDesc in_grid_desc =
94
95 const auto kernel = kernel_put_element_1d<GridwisePutElement,
97 InDataType,
98 IndexDataType,
99 OutDataType,
100 ElementwiseOperation>;
101
102 float elapsed_time = launch_and_time_kernel(stream_config,
103 kernel,
104 dim3(gridSize),
105 dim3(arg.blockSize_),
106 0,
107 in_grid_desc,
108 arg.p_input_,
109 arg.p_indices_,
110 arg.p_output_,
111 arg.elementwise_op_);
112 return elapsed_time;
113 }
114
115 float Run(const BaseArgument* p_arg,
116 const StreamConfig& stream_config = StreamConfig{}) override
117 {
118 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
119 }
120 };
121
122 bool IsSupportedArgument(const BaseArgument* p_arg) override
123 {
124 const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
125
126 if(pArg->input_length_raw_ % InVectorSize != 0)
127 {
128 return false;
129 }
130 return true;
131 }
132
133 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_input,
134 const void* p_indices,
135 void* p_output,
136 index_t input_length,
137 index_t,
138 ElementwiseOperation elementwise_op) override
139 {
140 return std::make_unique<Argument>(static_cast<const InDataType*>(p_input),
141 static_cast<const IndexDataType*>(p_indices),
142 static_cast<OutDataType*>(p_output),
143 input_length,
144 elementwise_op);
145 }
146
147 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
148 {
149 return std::make_unique<Invoker>(Invoker{});
150 }
151};
152
153} // namespace device
154} // namespace tensor_operation
155} // namespace ck
auto pad(ck::index_t mpb, ck::index_t npb, ck::index_t kpb, ck::tensor_operation::device::GemmSpecialization gemm, CDesc_MRaw_NRaw conv)
Definition helper.hpp:70
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_put_element_1d(const InGrid1dDesc in_grid_1d_desc, const InDataType *__restrict__ p_in_global, const IndexDataType *__restrict__ p_indices_global, OutDataType *__restrict__ p_out_global, const ElementwiseOperation elementwise_op)
Definition gridwise_put_element_1d.hpp:17
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
Definition ck/stream_config.hpp:10
Definition gridwise_put_element_1d.hpp:36
Definition utility/sequence.hpp:43
Definition device_base.hpp:197
Definition device_put_element.hpp:22
Definition device_put_element_impl.hpp:64
const IndexDataType * p_indices_
Definition device_put_element_impl.hpp:80
Argument(const InDataType *p_input, const IndexDataType *p_indices, OutDataType *p_output, index_t input_length, ElementwiseOperation elementwise_op)
Definition device_put_element_impl.hpp:65
ElementwiseOperation elementwise_op_
Definition device_put_element_impl.hpp:83
index_t blockSize_
Definition device_put_element_impl.hpp:84
index_t input_length_raw_
Definition device_put_element_impl.hpp:82
OutDataType * p_output_
Definition device_put_element_impl.hpp:81
const InDataType * p_input_
Definition device_put_element_impl.hpp:79
Definition device_put_element_impl.hpp:88
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_put_element_impl.hpp:115
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_put_element_impl.hpp:89
Definition device_put_element_impl.hpp:30
decltype(MakeDescriptor_M(1, 1, 1)) InGrid1dDesc
Definition device_put_element_impl.hpp:53
GridwisePutElement_1D< InGrid1dDesc, InDataType, IndexDataType, OutDataType, ElementwiseOperation, MemOp, InVectorSize > GridwisePutElement
Definition device_put_element_impl.hpp:55
static auto MakeDescriptor_M(index_t length, index_t gridSize, index_t blockSize)
Definition device_put_element_impl.hpp:47
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_put_element_impl.hpp:122
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
Definition device_put_element_impl.hpp:32
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_put_element_impl.hpp:147
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_input, const void *p_indices, void *p_output, index_t input_length, index_t, ElementwiseOperation elementwise_op) override
Definition device_put_element_impl.hpp:133