18template <
typename ThreadGroup,
19 typename ElementwiseOperation,
21 typename SliceLengths,
22 typename ThreadClusterLengths,
23 typename ThreadClusterArrangeOrder,
30 typename DimAccessOrder,
33 bool ThreadTransferSrc0ResetCoordinateAfterRun,
34 bool ThreadTransferSrc1ResetCoordinateAfterRun,
35 bool ThreadTransferDstResetCoordinateAfterRun>
45 const Index& src0_block_slice_origin,
46 const Src1Desc& src1_desc,
47 const Index& src1_block_slice_origin,
48 const DstDesc& dst_desc,
49 const Index& dst_block_slice_origin,
50 const ElementwiseOperation& element_op)
51 : threadwise_transfer_(src0_desc,
63 nDim == ThreadClusterLengths::Size() &&
64 nDim == ThreadClusterArrangeOrder::Size() &&
65 nDim == DimAccessOrder::Size(),
66 "wrong! nDim not consistent");
70 "wrong! threads should be mapped to cover entire slicing window");
72 static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
73 "wrong! ThreadGroup::GetNumOfThread() too small");
75 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
76 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
78 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
79 make_multi_index(ThreadGroup::GetThreadId()));
81 const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
83 threadwise_transfer_.SetSrc0SliceOrigin(
84 src0_desc, src0_block_slice_origin + thread_data_idx_begin);
85 threadwise_transfer_.SetSrc1SliceOrigin(
86 src1_desc, src1_block_slice_origin + thread_data_idx_begin);
87 threadwise_transfer_.SetDstSliceOrigin(dst_desc,
88 dst_block_slice_origin + thread_data_idx_begin);
92 template <
typename Src0Buffer,
typename Src1Buffer,
typename DstBuffer>
93 __device__
void Run(
const Src0Desc& src0_desc,
94 const Src0Buffer& src0_buf,
95 const Src1Desc& src1_desc,
96 const Src1Buffer& src1_buf,
97 const DstDesc& dst_desc,
100 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
101 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
103 threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf);
109 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
110 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
112 threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
118 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
119 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
121 threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
127 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
128 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
130 threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
135 static constexpr auto thread_cluster_desc_ =
138 using ThreadwiseTransfer =
139 ThreadwiseTensorSliceTransfer_v6r2<Src0Data,
145 ElementwiseOperation,
151 ThreadTransferSrc0ResetCoordinateAfterRun,
152 ThreadTransferSrc1ResetCoordinateAfterRun,
153 ThreadTransferDstResetCoordinateAfterRun>;
155 ThreadwiseTransfer threadwise_transfer_;
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
__host__ __device__ constexpr auto make_zero_multi_index()
Definition array_multi_index.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v6r2.hpp:125
MultiIndex< nDim > Index
Definition thread_group_tensor_slice_transfer_v6r2.hpp:42
static constexpr index_t nDim
Definition thread_group_tensor_slice_transfer_v6r2.hpp:38
static constexpr auto thread_slice_lengths
Definition thread_group_tensor_slice_transfer_v6r2.hpp:40
__device__ constexpr ThreadGroupTensorSliceTransfer_v6r2(const Src0Desc &src0_desc, const Index &src0_block_slice_origin, const Src1Desc &src1_desc, const Index &src1_block_slice_origin, const DstDesc &dst_desc, const Index &dst_block_slice_origin, const ElementwiseOperation &element_op)
Definition thread_group_tensor_slice_transfer_v6r2.hpp:44
__device__ void MoveSrc1SliceWindow(const Src1Desc &src1_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v6r2.hpp:116
__device__ void MoveSrc0SliceWindow(const Src0Desc &src0_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v6r2.hpp:107
__device__ void Run(const Src0Desc &src0_desc, const Src0Buffer &src0_buf, const Src1Desc &src1_desc, const Src1Buffer &src1_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition thread_group_tensor_slice_transfer_v6r2.hpp:93