18template <
typename ThreadGroup,
19 typename ElementwiseOperation,
21 typename SliceLengths,
22 typename ThreadClusterLengths,
23 typename ThreadClusterArrangeOrder,
32 typename DimAccessOrder,
35 bool ThreadTransferSrc0ResetCoordinateAfterRun,
36 bool ThreadTransferSrc1ResetCoordinateAfterRun,
37 bool ThreadTransferSrc2ResetCoordinateAfterRun,
38 bool ThreadTransferDstResetCoordinateAfterRun>
48 const Index& src0_block_slice_origin,
49 const Src1Desc& src1_desc,
50 const Index& src1_block_slice_origin,
51 const Src2Desc& src2_desc,
52 const Index& src2_block_slice_origin,
53 const DstDesc& dst_desc,
54 const Index& dst_block_slice_origin,
55 const ElementwiseOperation& element_op)
56 : threadwise_transfer_(src0_desc,
71 nDim == ThreadClusterLengths::Size() &&
72 nDim == ThreadClusterArrangeOrder::Size() &&
73 nDim == DimAccessOrder::Size(),
74 "wrong! nDim not consistent");
78 "wrong! threads should be mapped to cover entire slicing window");
80 static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
81 "wrong! ThreadGroup::GetNumOfThread() too small");
83 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
84 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
86 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
87 make_multi_index(get_thread_local_1d_id()));
89 const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
91 threadwise_transfer_.SetSrc0SliceOrigin(
92 src0_desc, src0_block_slice_origin + thread_data_idx_begin);
93 threadwise_transfer_.SetSrc1SliceOrigin(
94 src1_desc, src1_block_slice_origin + thread_data_idx_begin);
95 threadwise_transfer_.SetSrc2SliceOrigin(
96 src2_desc, src2_block_slice_origin + thread_data_idx_begin);
97 threadwise_transfer_.SetDstSliceOrigin(dst_desc,
98 dst_block_slice_origin + thread_data_idx_begin);
102 template <
typename Src0Buffer,
typename Src1Buffer,
typename Src2Buffer,
typename DstBuffer>
103 __device__
void Run(
const Src0Desc& src0_desc,
104 const Src0Buffer& src0_buf,
105 const Src1Desc& src1_desc,
106 const Src1Buffer& src1_buf,
107 const Src2Desc& src2_desc,
108 const Src2Buffer& src2_buf,
109 const DstDesc& dst_desc,
112 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
113 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
115 threadwise_transfer_.Run(
116 src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf, dst_desc, dst_buf);
122 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
123 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
125 threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
131 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
132 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
134 threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
140 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
141 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
143 threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step);
149 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
150 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
152 threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
157 static constexpr auto thread_cluster_desc_ =
160 using ThreadwiseTransfer =
161 ThreadwiseTensorSliceTransfer_v6r3<Src0Data,
169 ElementwiseOperation,
175 ThreadTransferSrc0ResetCoordinateAfterRun,
176 ThreadTransferSrc1ResetCoordinateAfterRun,
177 ThreadTransferSrc2ResetCoordinateAfterRun,
178 ThreadTransferDstResetCoordinateAfterRun>;
180 ThreadwiseTransfer threadwise_transfer_;
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
__host__ __device__ constexpr auto make_zero_multi_index()
Definition array_multi_index.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
static constexpr index_t nDim
Definition thread_group_tensor_slice_transfer_v6r3.hpp:41
__device__ void MoveSrc2SliceWindow(const Src2Desc &src2_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v6r3.hpp:138
__device__ void MoveSrc1SliceWindow(const Src1Desc &src1_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v6r3.hpp:129
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v6r3.hpp:147
__device__ void MoveSrc0SliceWindow(const Src0Desc &src0_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v6r3.hpp:120
__device__ constexpr ThreadGroupTensorSliceTransfer_v6r3(const Src0Desc &src0_desc, const Index &src0_block_slice_origin, const Src1Desc &src1_desc, const Index &src1_block_slice_origin, const Src2Desc &src2_desc, const Index &src2_block_slice_origin, const DstDesc &dst_desc, const Index &dst_block_slice_origin, const ElementwiseOperation &element_op)
Definition thread_group_tensor_slice_transfer_v6r3.hpp:47
MultiIndex< nDim > Index
Definition thread_group_tensor_slice_transfer_v6r3.hpp:45
static constexpr auto thread_slice_lengths
Definition thread_group_tensor_slice_transfer_v6r3.hpp:43
__device__ void Run(const Src0Desc &src0_desc, const Src0Buffer &src0_buf, const Src1Desc &src1_desc, const Src1Buffer &src1_buf, const Src2Desc &src2_desc, const Src2Buffer &src2_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition thread_group_tensor_slice_transfer_v6r3.hpp:103