gridwise_gemm_waveletmodel.hpp Source File

gridwise_gemm_waveletmodel.hpp Source File

Composable Kernel: gridwise_gemm_waveletmodel.hpp Source File
gridwise_gemm_waveletmodel.hpp
Go to the documentation of this file.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once

#include "ck/utility/common_header.hpp"

8namespace ck {
9
10template <typename TileLoadThreadGroup, index_t NumGemmKPrefetchStage>
12
13// 1-stage prefetch
14template <typename TileLoadThreadGroup>
15struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
16{
17 __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */)
18 {
19 // TODO: improve applicability
20 return true;
21 }
22
23 __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
24 {
25 return num_loop > 1;
26 }
27
28 template <bool HasMainLoop,
29 typename AGridDesc,
30 typename ABlockDesc,
31 typename ABlockTransfer,
32 typename AGridBuffer,
33 typename ABlockBuffer,
34 typename ABlockTransferStep,
35 typename BGridDesc,
36 typename BBlockDesc,
37 typename BBlockTransfer,
38 typename BGridBuffer,
39 typename BBlockBuffer,
40 typename BBlockTransferStep>
41 static __device__ void RunLoadWavePipeline(const AGridDesc& a_grid_desc,
42 const ABlockDesc& a_block_desc,
43 ABlockTransfer& a_blockwise_copy,
44 const AGridBuffer& a_grid_buf,
45 ABlockBuffer& a_block_buf,
46 const ABlockTransferStep& a_block_copy_step,
47 const BGridDesc& b_grid_desc,
48 const BBlockDesc& b_block_desc,
49 BBlockTransfer& b_blockwise_copy,
50 const BGridBuffer& b_grid_buf,
51 BBlockBuffer& b_block_buf,
52 const BBlockTransferStep& b_block_copy_step,
53 index_t num_loop)
54 {
55 // global read 0
56 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
57 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
58
59 // move to 1
60 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
61 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
62
63 // LDS write 0
64 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
65 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
66
67 if constexpr(HasMainLoop)
68 {
69 index_t i = 0;
70
71 do
72 {
73 // sync for Load threads()
75 // global read i + 1
76 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
77 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
78
79 // move to i + 2
80 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
81 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
82
83 // sync with math threads()
85
86 // LDS write i+1
87 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
88 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
89
90 ++i;
91 } while(i < (num_loop - 1));
92 }
93
94 // tail
95 {
97 // GEMM num_loop - 1
98 }
99 }
100};
101
102template <typename TileMathThreadGroup, index_t NumGemmKPrefetchStage>
104// 1- stage prefetch
105template <typename TileMathThreadGroup>
106struct GridwiseGemmMathWave<TileMathThreadGroup, 1>
107{
108
109 __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
110
111 __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
112 {
113 return num_loop > 1;
114 }
115
116 template <bool HasMainLoop,
117 typename ABlockBuffer,
118 typename BBlockBuffer,
119 typename BlockwiseGemm,
120 typename CThreadBuffer>
121 static __device__ void RunMathWavePipeline(ABlockBuffer& a_block_buf,
122 BBlockBuffer& b_block_buf,
123 const BlockwiseGemm& block_gemm,
124 CThreadBuffer& c_thread_buf,
125 index_t num_loop)
126 {
127 // Initialize C
128 c_thread_buf.Clear();
129
130 // main body
131 if constexpr(HasMainLoop)
132 {
133 index_t i = 0;
134
135 do
136 {
138
139 // GEMM i
140 block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
141
143 ++i;
144 } while(i < (num_loop - 1));
145 }
146
147 // tail
148 {
150
151 // GEMM num_loop - 1
152 block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
153 }
154 }
155};
156
157} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__device__ void block_sync_lds()
Definition synchronization.hpp:16
static __device__ void RunLoadWavePipeline(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, index_t num_loop)
Definition gridwise_gemm_waveletmodel.hpp:41
__host__ static __device__ constexpr bool IsSupported(index_t)
Definition gridwise_gemm_waveletmodel.hpp:17
__host__ static __device__ constexpr bool CalculateHasMainLoop(index_t num_loop)
Definition gridwise_gemm_waveletmodel.hpp:23
Definition gridwise_gemm_waveletmodel.hpp:11
static __device__ void RunMathWavePipeline(ABlockBuffer &a_block_buf, BBlockBuffer &b_block_buf, const BlockwiseGemm &block_gemm, CThreadBuffer &c_thread_buf, index_t num_loop)
Definition gridwise_gemm_waveletmodel.hpp:121
__host__ static __device__ constexpr bool IsSupported(index_t)
Definition gridwise_gemm_waveletmodel.hpp:109
__host__ static __device__ constexpr bool CalculateHasMainLoop(index_t num_loop)
Definition gridwise_gemm_waveletmodel.hpp:111
Definition gridwise_gemm_waveletmodel.hpp:103