/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/dynamic_akg/dynamic_utils.h"
#include <fstream>
#include <algorithm>
#include <cmath>
#include <map>
#include <numeric>
#include "nlohmann/json.hpp"
#include "utils/ms_utils.h"
#include "kernel/framework_utils.h"
#include "mindspore/ccsrc/include/common/debug/common.h"
#include "plugin/device/gpu/hal/device/gpu_common.h"

namespace mindspore {
namespace kernel {
using std::fstream;
using std::string;
using std::unordered_map;
using std::vector;
namespace {
constexpr auto kV100Device = "v100";
constexpr auto kA100Device = "a100";
constexpr auto kSharedMem = "shared_mem";
constexpr auto kRegMem = "reg_mem";
constexpr auto kLocalPrefix = "Seq";
constexpr auto kHostShapes = "hostShapes";
constexpr auto kDeviceShapes = "deviceShapes";
constexpr auto kRuntimeVars = "runtimeVars";
constexpr auto kTargetInfo = "targetInfo";
constexpr auto kSupportInfo = "SupportInfo";
constexpr auto kReduceSizeStatic = "ReduceSizeStatic";
constexpr auto kDynAlgorithm = "DynAlgorithm";
constexpr auto kEnableAtomic = "EnableAtomic";
constexpr auto kMultiply = "*";

constexpr auto kBlockIdxX = "blockIdx.x";
constexpr auto kBlockIdxY = "blockIdx.y";
constexpr auto kBlockIdxZ = "blockIdx.z";
constexpr auto kThreadIdxX = "threadIdx.x";
constexpr auto kThreadIdxY = "threadIdx.y";
constexpr auto kThreadIdxZ = "threadIdx.z";
constexpr int64_t kRemove = -100000;
constexpr int64_t kKeep = -99999;
const int AKG_KERNEL_MOD_BX_IDX = 0;
const int AKG_KERNEL_MOD_BY_IDX = 1;
const int AKG_KERNEL_MOD_BZ_IDX = 2;
const int AKG_KERNEL_MOD_TX_IDX = 3;
const int AKG_KERNEL_MOD_TY_IDX = 4;
const int AKG_KERNEL_MOD_TZ_IDX = 5;
const int WARP_SIZE = 32;
const int WARP_ALLOC_GRAN = 4;
const int ELEM_BEST_GRID_SIZE = 512;
const int MAX_THREAD_NUM = 1024;
const int k48KB = 49152;
const int k64KB = 65536;
const int kNum80 = 80;
const int kNum108 = 108;
const int kNum512 = 512;
const int kNum1024 = 1024;
const int kNum2 = 2;
const int kNum16 = 16;
const int kNum32 = 32;
const int kNum64 = 64;
const int kNum256 = 256;

const int AKG_DYN_ALGO_REDUCE_X = 10;
const int AKG_DYN_ALGO_REDUCE_Y = 11;
const int AKG_DYN_ALGO_REDUCE_SMALL = 12;
const int AKG_DYN_MARK_THREAD_LOWER_BOUND = 10;
const int AKG_DYN_MARK_THREAD_UPPER_BOUND = 20;
const int AKG_DYN_MARK_SEQ_LOWER_BOUND = 20;
const int AKG_DYN_MARK_SEQ_UPPER_BOUND = 30;
const int AKG_DYN_MARK_ONE = 30;
const int AKG_DYN_MARK_PRODUCT = 40;
const int AKG_DYN_MARK_UNKNOWN = 999;
}  // namespace
struct RuntimeVarCompare {
  bool operator()(const std::pair<uint32_t, RuntimeVarPtr> &a, const std::pair<uint32_t, RuntimeVarPtr> &b) const {
    return a.second->argIndex < b.second->argIndex;
  }
};

std::unordered_map<std::string, int> AkgKernelImplInfo::algo_to_int_ = {{"reduce-x", AKG_DYN_ALGO_REDUCE_X},
                                                                        {"reduce-y", AKG_DYN_ALGO_REDUCE_Y},
                                                                        {"reduce-small", AKG_DYN_ALGO_REDUCE_SMALL}};

std::unordered_map<std::string, int> RuntimeVar::mark_table_ = {{"unknown", 999},
                                                                {"reduce-thread-last", 10},
                                                                {"reduce-thread", 11},
                                                                {"parallel-thread-last", 12},
                                                                {"parallel-thread", 13},
                                                                {"reduce-y-seq", 20},
                                                                {"reduce-x-seq", 21},
                                                                {"parallel-seq", 22},
                                                                {"1", 30},
                                                                {"product", 40}};

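// A note on the heuristic below (inferred from the implementation): for an x-direction reduction
// of length `redSize`, pick the block count first (only more than one block when atomics are
// enabled), then grow the thread count in powers of two while each thread still covers at least
// `acc_num` elements, and derive the per-thread sequential length from whatever remains.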
inline void GetProperReduceXConfig(int redSize, const bool &useAtomicFlag, int *block_num, int *thread_num,
                                   int *seq_num) {
  const int acc_num = 4;
  *block_num = useAtomicFlag ? (redSize - 1) / MAX_THREAD_NUM + 1 : 1;
  redSize = useAtomicFlag ? (redSize - 1) / (*block_num) + 1 : redSize;
  (*thread_num) = redSize < kNum32 ? redSize : kNum32;
  while ((*thread_num) * kNum2 * acc_num <= redSize && (*thread_num) <= MAX_THREAD_NUM) {
    (*thread_num) *= kNum2;
  }
  (*seq_num) = (redSize - 1) / ((*thread_num) * (*block_num)) + 1;
}

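// Heuristic for y-direction reductions (inferred from the code): without atomics the whole
// reduction stays sequential within one block; with atomics the sequential length is capped at
// 16/32/64 depending on `redSize`, and the remainder is spread across blocks.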
inline void GetProperReduceYConfig(const int &redSize, const bool &useAtomicFlag, int *block_num, int *seq_num) {
  *block_num = 1;
  *seq_num = redSize;
  if (useAtomicFlag) {
    if (redSize < kNum256) {
      *seq_num = kNum16;
    } else if (redSize < kNum1024) {
      *seq_num = kNum32;
    } else {
      *seq_num = kNum64;
    }
    *block_num = (redSize - 1) / (*seq_num) + 1;
  }
}

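// Maps a runtime-var mark to a tile size: thread-level marks take the proposed thread count,
// sequential marks take the proposed sequential length (both clamped to `upper_bound`),
// the "1" mark stays untiled, and any other mark is left at -1 for the caller to resolve.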
inline void GetReduceTileSize(const int &mark, const int &upper_bound, const int &proper_thread, const int &proper_seq,
                              int *tile_size) {
  if (AKG_DYN_MARK_THREAD_LOWER_BOUND <= mark && mark < AKG_DYN_MARK_THREAD_UPPER_BOUND) {
    *tile_size = std::min(upper_bound, proper_thread);
  } else if (AKG_DYN_MARK_SEQ_LOWER_BOUND <= mark && mark < AKG_DYN_MARK_SEQ_UPPER_BOUND) {
    *tile_size = std::min(upper_bound, proper_seq);
  } else if (mark == AKG_DYN_MARK_ONE) {
    *tile_size = 1;
  } else {
    *tile_size = -1;
  }
}

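// Proposes a (grid, block) pair for `problemSize` by stepping through occupancy tiers of the
// target GPU: a single warp, one warp per SM, the two per-SM thread/block coefficients, and
// finally the maximum block size with an oversubscribed grid for extremely large shapes.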
static std::pair<int, int> GetProposalParallelSize(int problemSize, const std::string &device_target) {
  GpuInfo &gpu_info = GpuInfo::GetInstance(device_target);
  int proposedGrid = 1;
  int proposedBlock = 1;
  auto numSm = gpu_info.GetNumSm();
  auto threadCoef = gpu_info.GetThreadCoef();
  auto warpSizes = gpu_info.GetWarpSizes();
  auto activeBlocksPerSm = gpu_info.GetActiveBlocksPerSm();
  auto totalBlocks = gpu_info.GetTotalAvailableBlocks();
  if (problemSize <= warpSizes) {
    proposedBlock = warpSizes;
  } else if (problemSize <= warpSizes * numSm) {
    proposedBlock = warpSizes;
    proposedGrid = numSm;
  } else if (problemSize <= warpSizes * threadCoef.first * numSm * activeBlocksPerSm.first) {
    proposedBlock = warpSizes * threadCoef.first;
    proposedGrid = numSm * activeBlocksPerSm.first;
  } else if (problemSize <= warpSizes * threadCoef.second * numSm * activeBlocksPerSm.second) {
    proposedBlock = warpSizes * threadCoef.second;
    proposedGrid = numSm * activeBlocksPerSm.second;
  } else if (problemSize <= warpSizes * threadCoef.second * numSm * activeBlocksPerSm.second * numSm) {
    proposedBlock = totalBlocks;
    proposedGrid = numSm * activeBlocksPerSm.second;
  } else {
    // extremely large shape
    proposedBlock = totalBlocks;
    proposedGrid = numSm * activeBlocksPerSm.second * kNum2;
  }
  return std::make_pair(proposedGrid, proposedBlock);
}

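// Returns the remaining mapping budget for dimension `id`: grid dimensions use the hardware grid
// limit directly, while block dimensions are additionally capped by the threads still available
// after earlier block-dimension allocations (total budget / already-allocated block size).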
uint32_t MappingInfo::GetMapLimit(size_t id, const std::string &device_target) {
  if (id < AKG_KERNEL_MOD_BX_IDX || id > AKG_KERNEL_MOD_TZ_IDX) {
    MS_EXCEPTION(RuntimeError) << "Map id should be in range [" << AKG_KERNEL_MOD_BX_IDX << ", "
                               << AKG_KERNEL_MOD_TZ_IDX << "], but got " << id;
  }
  GpuInfo &gpu_info = GpuInfo::GetInstance(device_target);
  if (id >= AKG_KERNEL_MOD_TX_IDX) {
    auto max_block = gpu_info.GetMaxBlocks();
    return std::min<uint32_t>(max_block[id - AKG_KERNEL_MOD_TX_IDX],
                              gpu_info.GetTotalAvailableBlocks() / total_alloc_block);
  } else {
    auto max_grid = gpu_info.GetMaxGrids();
    return max_grid[id];
  }
}

void MappingInfo::UpdateCurrMapSize(size_t id, uint32_t map_size) {
  if (id < AKG_KERNEL_MOD_BX_IDX || id > AKG_KERNEL_MOD_TZ_IDX) {
    MS_EXCEPTION(RuntimeError) << "Map id should be in range [" << AKG_KERNEL_MOD_BX_IDX << ", "
                               << AKG_KERNEL_MOD_TZ_IDX << "], but got " << id;
  }
  if (id >= AKG_KERNEL_MOD_TX_IDX) {
    curr_block[id - AKG_KERNEL_MOD_TX_IDX] = map_size;
    total_alloc_block *= map_size;
  } else {
    curr_grid[id] = map_size;
    total_alloc_grid *= map_size;
  }
}

void GpuInfo::InitGpuMemoryLimit(const std::string &device_type) {
  auto CollectLimit = [this, &device_type](const std::string &scope, GpuMemScope mem) {
    if (device_type == kV100Device) {
      if (scope == kSharedMem) {
        gpuMemLimit[mem] = k48KB;
      } else if (scope == kRegMem) {
        gpuMemLimit[mem] = k64KB;
      }
    } else if (device_type == kA100Device) {
      if (scope == kSharedMem) {
        gpuMemLimit[mem] = k64KB;
      } else if (scope == kRegMem) {
        gpuMemLimit[mem] = k64KB;
      }
    }
  };
  CollectLimit(kSharedMem, MEM_SCOPE_SHARED);
  CollectLimit(kRegMem, MEM_SCOPE_LOCAL);
  gpuMemLimit[MEM_SCOPE_GM] = 0;
}

void GpuInfo::InitGpuComputeCapability(const std::string &device_type) {
  if (device_type == kV100Device) {
    numSm = kNum80;
    totalAvailableBlocks = kNum512;
  } else if (device_type == kA100Device) {
    numSm = kNum108;
    totalAvailableBlocks = kNum1024;
  }
}

AkgKernelImplInfo::AkgKernelImplInfo(const std::string &kernel_name, nlohmann::json json) {
  kernel_name_ = kernel_name;
  parsed_js_ = json;
  device_target_ = parsed_js_[kTargetInfo];
  if (device_target_.empty()) {
    device_target_ = kV100Device;
  }
}

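// Parses one device-shape expression into a list of host-shape locations. For example, assuming
// "s0" and "s1" are recorded in host_loc_map_, "s0*1024*s1" is split on '*' and becomes the host
// locations of s0 and s1 plus a <pure_num_flag, 1024> pair for the constant factor.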
LocVector AkgKernelImplInfo::GetHostLocationVec(std::string symbol_expr, const size_t pure_num_flag) {
  std::string delimiter = kMultiply;
  std::vector<std::string> symbol_vec;
  if (symbol_expr.find(delimiter) != std::string::npos) {
    // multiplication expr after folding like 's0*1024*s1'
    size_t pos_start = 0;
    size_t pos_end;
    size_t delim_len = 1;
    std::string symbol;
    // split each symbol or number and store in a list
    while ((pos_end = symbol_expr.find(delimiter, pos_start)) != std::string::npos) {
      symbol = symbol_expr.substr(pos_start, pos_end - pos_start);
      pos_start = pos_end + delim_len;
      (void)symbol_vec.emplace_back(symbol);
    }
    (void)symbol_vec.emplace_back(symbol_expr.substr(pos_start));
  } else {
    (void)symbol_vec.emplace_back(symbol_expr);
  }

  LocVector host_loc_vec;
  for (auto symbol : symbol_vec) {
    if (std::all_of(symbol.begin(), symbol.end(), [](char c) { return std::isdigit(c); })) {
      // for number '32', save a pair of <M, number>, where M = num of inputs + outputs
      // M must be greater than any host_loc index, so it can be a flag for pure numbers
      auto number_pair = std::make_pair(pure_num_flag, IntToSize(std::stoi(symbol)));
      (void)host_loc_vec.emplace_back(number_pair);
    } else if (host_loc_map_.find(symbol) != host_loc_map_.end()) {
      // for symbol 's0', save its location in host shape as a pair of <i, j>
      (void)host_loc_vec.emplace_back(host_loc_map_[symbol]);
    } else {
      MS_EXCEPTION(RuntimeError) << "For " << kernel_name_ << ", symbol '" << symbol
                                 << "' of device shape is not in host shape.";
    }
  }
  return host_loc_vec;
}

void AkgKernelImplInfo::InitJsonShapeInformation() {
  // Initialize device shape list using -1 for unknown dims
  // Record map <device_loc, symbol>
  unordered_map<std::pair<size_t, size_t>, string, PairHash> device_loc_map;
  for (size_t i = 0; i < parsed_js_[kDeviceShapes].size(); i++) {
    vector<int64_t> device_tensor_shape;
    auto device_tensor_rank = parsed_js_[kDeviceShapes][i].size();
    max_shape_rank_ = std::max<size_t>(max_shape_rank_, device_tensor_rank);

    for (size_t j = 0; j < device_tensor_rank; j++) {
      string device_shape_str = parsed_js_[kDeviceShapes][i][j];
      if (std::all_of(device_shape_str.begin(), device_shape_str.end(), [](char c) { return std::isdigit(c); })) {
        (void)device_tensor_shape.emplace_back(std::stoi(device_shape_str));
      } else {
        (void)device_tensor_shape.emplace_back(-1);
        auto device_loc = std::make_pair(i, j);
        device_loc_map[device_loc] = device_shape_str;
      }
    }

    (void)device_shape_list_.emplace_back(device_tensor_shape);
  }
  // Record map <symbol, host_loc>
  for (int i = static_cast<int>(parsed_js_[kHostShapes].size()) - 1; i >= 0; i--) {
    for (size_t j = 0; j < parsed_js_[kHostShapes][i].size(); j++) {
      string shape_str = parsed_js_[kHostShapes][i][j];
      if ((!std::all_of(shape_str.begin(), shape_str.end(), [](char c) { return std::isdigit(c); })) &&
          (host_loc_map_.find(shape_str) == host_loc_map_.end())) {
        host_loc_map_[shape_str] = std::make_pair(i, j);
      }
    }
  }
  // Get map <device_loc, host_loc>
  for (auto item : device_loc_map) {
    device_host_shape_loc_[item.first] = GetHostLocationVec(item.second, parsed_js_[kHostShapes].size());
  }
  MS_LOG(INFO) << "Done InitJsonShapeInformation for " << kernel_name_;
}

void AkgKernelImplInfo::InitJsonMappingInformation() {
  // Create map <prime, var> where var is the dynamic tile size and prime is its unique id
  runtime_vars_.clear();
  sorted_runtime_vars_.clear();
  for (size_t i = 0; i < parsed_js_[kRuntimeVars].size(); i++) {
    RuntimeVarPtr v = std::make_shared<RuntimeVar>();
    v->argIndex = parsed_js_[kRuntimeVars][i][v->ArgIndexKey()];
    v->expr = parsed_js_[kRuntimeVars][i][v->ExprKey()];
    v->mark = RuntimeVar::mark_table_[parsed_js_[kRuntimeVars][i].value(v->MarkKey(), "unknown")];
    v->mapDim = parsed_js_[kRuntimeVars][i][v->MapDimKey()];
    v->mapping = parsed_js_[kRuntimeVars][i][v->MappingKey()];
    v->prime = parsed_js_[kRuntimeVars][i][v->PrimeKey()];
    runtime_vars_[v->prime] = v;
    sorted_runtime_vars_.emplace_back(std::make_pair(v->prime, v));
  }

  // Sort the vars according to the arg index of each var
  std::sort(sorted_runtime_vars_.begin(), sorted_runtime_vars_.end(), RuntimeVarCompare());

  // Init mapping info with the static mapping sizes; dynamic mapping sizes will be calculated during resize
  init_mapping_info_ = MappingInfo();
  init_mapping_info_.solve_order_id = {AKG_KERNEL_MOD_TX_IDX, AKG_KERNEL_MOD_TY_IDX, AKG_KERNEL_MOD_TZ_IDX,
                                       AKG_KERNEL_MOD_BX_IDX, AKG_KERNEL_MOD_BY_IDX, AKG_KERNEL_MOD_BZ_IDX};

  // Initialize mapping info as a vector. Only store the dividend number for unknown mapping args
  // Record map <unknown map arg id, host_loc>
  init_mapping_.clear();
  map_arg_list_ = {kBlockIdxX, kBlockIdxY, kBlockIdxZ, kThreadIdxX, kThreadIdxY, kThreadIdxZ};
  for (size_t i = 0; i < map_arg_list_.size(); i++) {
    auto map_arg = map_arg_list_[i];
    if (parsed_js_[map_arg].is_number()) {
      uint32_t map_size = parsed_js_[map_arg];
      (void)init_mapping_.emplace_back(map_size);
      auto it = runtime_vars_.find(map_size);
      if (it == runtime_vars_.end()) {
        // update static mapping
        init_mapping_info_.UpdateCurrMapSize(i, map_size);
      } else {
        it->second->curr_map_id = i;
      }
    } else if (CheckJsonValueFormat(map_arg)) {
      string divisor_symbol = parsed_js_[map_arg][0];
      auto dividend = parsed_js_[map_arg][1];
      (void)init_mapping_.emplace_back(static_cast<uint32_t>(dividend));
      unknown_map_loc_[i] = GetHostLocationVec(divisor_symbol, parsed_js_[kHostShapes].size());
      unknown_map_symbol_[i] = divisor_symbol;
      product_var_[dividend] = divisor_symbol;
    } else {
      MS_EXCEPTION(RuntimeError) << "Mapping info format error.";
      return;
    }
  }

  if (runtime_vars_.empty()) {
    return;
  }
  for (auto it = parsed_js_.begin(); it != parsed_js_.end(); ++it) {
    std::string key = it.key();
    if (key.find(kLocalPrefix) != std::string::npos && CheckJsonValueFormat(key)) {
      string divisor_symbol = parsed_js_[key][0];
      auto prime = static_cast<int>(parsed_js_[key][1]);
      local_upper_bound_[prime] = host_loc_map_[divisor_symbol];
      local_upper_bound_symbol_[prime] = divisor_symbol;
      product_var_[prime] = divisor_symbol;
    }
  }

  // update relationship: product = prime0 * prime1
  for (const auto &kv : runtime_vars_) {
    int prime0 = kv.first;
    if (prime0 <= 1) continue;
    if (runtime_vars_.find(prime0) != runtime_vars_.end()) {
      for (const auto &kv2 : runtime_vars_) {
        int product = kv2.first;
        if (product > 0 && product != prime0 && product % prime0 == 0) {
          product_var_[prime0] = product_var_[product];
          product_var_[product / prime0] = product_var_[product];
          std::vector<int> vars({prime0, product / prime0});
          related_values_[product] = vars;
        }
      }
    }
  }
  MS_LOG(INFO) << "Done InitJsonMappingInformation for " << kernel_name_;
}

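// Pre-computes everything the reduce tiling solver needs from the kernel JSON: the static part of
// the reduce size, the thread-marked runtime vars, a tiling order sorted by mark, the
// prime-to-mapping-slot lookup tables, and the atomic/algorithm flags.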
void AkgKernelImplInfo::preprocessDynamicReduceTiling() {
  // keep static reduce length
  static_reduce_length_ = parsed_js_[kSupportInfo][kReduceSizeStatic];

  // collect thread-level idx
  runtime_threads_order_.clear();
  for (const auto &kv : runtime_vars_) {
    auto var = kv.second;
    // whether the runtime var has "thread" mark
    if (var->prime <= 0 ||
        (var->mark < AKG_DYN_MARK_THREAD_LOWER_BOUND || var->mark >= AKG_DYN_MARK_THREAD_UPPER_BOUND))
      continue;
    runtime_threads_order_.push_back(kv.first);
  }

  // sort tiling orders
  const std::vector<int> orders = {10, 11, 12, 13, 20, 21, 22, 30, 40};  // mark order
  template_tiling_order_.clear();
  for (auto order : orders) {
    for (auto kv : runtime_vars_) {
      auto var = kv.second;
      if (var->prime <= 0 || var->mark != order) continue;
      template_tiling_order_.push_back(std::make_pair(kv.first, var->mark));
    }
  }

  // build a map from prime number to mapping idx
  prime_to_mapping_idx_.clear();
  prime_to_mapping_dividend_.clear();
  for (const auto &p : template_tiling_order_) {
    int prime = p.first;
    prime_to_mapping_idx_[prime] = -1;
    prime_to_mapping_dividend_[prime] = -1;
    for (size_t i = 0; i < map_arg_list_.size(); i++) {
      auto map_arg = map_arg_list_[i];
      if (parsed_js_[map_arg].is_number() && parsed_js_[map_arg] == prime) {
        prime_to_mapping_idx_[prime] = i;
        break;
      } else if (CheckJsonValueFormat(map_arg) && parsed_js_[map_arg][1] == prime) {
        prime_to_mapping_dividend_[prime] = i;
        break;
      }
    }
  }

  enable_atomic_ = parsed_js_[kSupportInfo][kEnableAtomic];
  dyn_algorithm_ = AkgKernelImplInfo::algo_to_int_[parsed_js_[kSupportInfo][kDynAlgorithm]];
}

void AkgKernelImplInfo::InitBeforeMapping() {
  thread_info_ = init_mapping_;
  if (!runtime_vars_.empty()) {
    solved_map_loc_.clear();
    curr_mapping_info_ = init_mapping_info_;
    auto [g, b] = GetProposalParallelSize(problem_size_, device_target_);
    curr_mapping_info_.proposal_grid = g;
    curr_mapping_info_.proposal_block = b;
  }
}

void DynamicTileImpl::UpdateDynamicShapeTilingInfo() {
  if (runtime_vars_.empty()) {
    MS_LOG(DEBUG) << "Static Tile " << kernel_name_;
    return;
  }
  UpdateRuntimeVarUpperBound();
  if (parsed_js_[kSupportInfo]["OperatorType"] == "Reduce") {
    SolveDynamicReduction();
  } else {
    for (auto curr_id : curr_mapping_info_.solve_order_id) {
      SolveDynamicTiling(curr_id);
    }
  }

  for (auto it : sorted_runtime_vars_) {
    arg_size_vec_.push_back(it.second->runtime_size);
  }
  MS_LOG(INFO) << "Done UpdateDynamicShapeTilingInfo for " << kernel_name_;
}

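// Folds one device dimension back to a concrete value by multiplying the host dims (and pure
// numbers) referenced by `host_loc_vec`, with an overflow guard against INT64_MAX.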
int64_t AkgKernelImplInfo::GetFoldedShape(const LocVector &host_loc_vec) {
  int64_t folded_shape = 1;
  for (auto host_loc : host_loc_vec) {
    int64_t curr_shape = 1;
    if (host_loc.first == shape_list_.size()) {
      // pure number pair <M, number>
      curr_shape = SizeToInt(host_loc.second);
    } else if (shape_list_[host_loc.first].size() != 0) {
      curr_shape = shape_list_[host_loc.first][host_loc.second];
    }
    if (folded_shape > INT64_MAX / curr_shape) {
      MS_EXCEPTION(RuntimeError) << "For " << kernel_name_ << ", the product of shapes, " << folded_shape << " and "
                                 << curr_shape << ", exceeds INT64_MAX.";
    }
    folded_shape *= curr_shape;
  }
  return folded_shape;
}

void AkgKernelImplInfo::UpdateDynamicShapeMappingInfo() {
  for (auto item : unknown_map_loc_) {
    auto thread_info_id = item.first;
    if (solved_map_loc_.count(thread_info_id)) {
      continue;
    }
    if (thread_info_id >= thread_info_.size()) {
      MS_EXCEPTION(RuntimeError) << "Unknown thread arg index should not exceed thread_info length, "
                                 << "which is 6 (Grid.X/Y/Z, Block.X/Y/Z).";
    }
    auto host_loc_vec = item.second;
    auto dim_size = GetFoldedShape(host_loc_vec);
    auto tile_size = thread_info_[thread_info_id];
    auto dim_size_float = static_cast<float>(dim_size);
    auto tile_size_float = static_cast<float>(tile_size);
    thread_info_[thread_info_id] = static_cast<uint32_t>(std::ceil(dim_size_float / tile_size_float));
  }

  MS_LOG(DEBUG) << "For " << kernel_name_ << ", thread_info = " << thread_info_;
  MS_LOG(INFO) << "Done UpdateDynamicShapeMappingInfo for " << kernel_name_;
}

void AkgKernelImplInfo::GetDeviceArgSizeVec() {
  arg_size_vec_.clear();
  std::vector<std::vector<int64_t>> device_shape(device_shape_list_);
  // Update each unknown value in device_shape_list to get the real device shape
  for (auto item : device_host_shape_loc_) {
    auto device_loc = item.first;
    auto host_loc_vec = item.second;
    device_shape[device_loc.first][device_loc.second] = GetFoldedShape(host_loc_vec);
  }
  problem_size_ = 1;
  for (size_t i = 0; i < device_shape.size(); i++) {
    MS_LOG(DEBUG) << "For " << kernel_name_ << ", input[" << i << "]: host_shape = " << shape_list_[i]
                  << ", device_shape = " << device_shape[i];
    arg_size_vec_.push_back(kRemove);  // useless in memref
    arg_size_vec_.push_back(kKeep);    // data ptr
    arg_size_vec_.push_back(0);        // offset
    auto device_tensor_shape = device_shape[i];
    int64_t tensor_size = 1;
    for (auto item : device_tensor_shape) {
      if (item <= 0) {
        MS_EXCEPTION(RuntimeError) << "Shape still has a non-positive value for kernel " << kernel_name_
                                   << " with host shape[" << i << "] = " << shape_list_[i]
                                   << ", device_tensor_shape[" << i << "] = " << device_tensor_shape;
      }
      arg_size_vec_.push_back(item);
      tensor_size *= item;
    }
    problem_size_ = std::max<int64_t>(problem_size_, tensor_size);
    vector<int64_t> strides(device_tensor_shape.size(), 1);
    for (int j = SizeToInt(device_tensor_shape.size()) - 2; j >= 0; j--) {
      strides[j] = strides[j + 1] * device_tensor_shape[j + 1];
    }
    for (auto item : strides) {
      arg_size_vec_.push_back(item);
    }
  }
  MS_LOG(DEBUG) << "For " << kernel_name_ << ", arg_size_vec = " << arg_size_vec_;
  MS_LOG(INFO) << "Done GetDeviceArgSizeVec for " << kernel_name_;
}

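// Resolves each runtime var's upper bound from the folded host shapes: mapped (block/thread) vars
// get the folded size of their divisor symbol, sequential (Seq) vars get the corresponding host
// dim, and axis_length_left_ records how much of each symbol's axis is still untiled.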
void DynamicTileImpl::UpdateRuntimeVarUpperBound() {
  // Comes from `Block/Thread: [upper_bound, prime]`
  for (const auto &item : unknown_map_loc_) {
    auto thread_info_id = item.first;
    auto host_loc_vec = item.second;
    auto tile_size = thread_info_[thread_info_id];
    auto it = runtime_vars_.find(tile_size);
    if (it != runtime_vars_.end()) {
      auto dim_size = GetFoldedShape(host_loc_vec);
      it->second->upper_bound = dim_size;
      it->second->outer_map_id = thread_info_id;
      auto symbol = unknown_map_symbol_[thread_info_id];
      axis_length_left_[symbol] = dim_size;
    }
  }

  // Comes from `Seq: [upper_bound, prime]`
  for (const auto &it : local_upper_bound_) {
    auto prime = it.first;
    if (runtime_vars_.find(prime) != runtime_vars_.end()) {
      auto host_loc = it.second;
      auto dim_size = shape_list_[host_loc.first][host_loc.second];
      runtime_vars_[prime]->upper_bound = dim_size;
      auto symbol = local_upper_bound_symbol_[prime];
      axis_length_left_[symbol] = dim_size;
    }
  }
}

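// Writes a solved tile size back into both the launch mapping (thread_info_) and the runtime var
// identified by `prime`; a negated prime, if present, receives the negated size. Passing
// curr_id == -1 skips the mapping update (used for purely sequential tiles).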
void DynamicTileImpl::UpdateMapping(int curr_id, int64_t map_size, int64_t prime) {
  // skip when mark is seq.x
  if (curr_id != -1) {
    thread_info_[curr_id] = map_size;
    solved_map_loc_.insert(curr_id);
    curr_mapping_info_.UpdateCurrMapSize(curr_id, map_size);
  }
  if (runtime_vars_.find(prime) == runtime_vars_.end()) {
    return;
  }
  runtime_vars_[prime]->runtime_size = map_size;
  int64_t neg_prime = -prime;
  if (runtime_vars_.find(neg_prime) != runtime_vars_.end()) {
    runtime_vars_[neg_prime]->runtime_size = -map_size;
  }
}

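// Caps an elementwise tile size by the parallelism still unclaimed from the proposal: vars mapped
// to a grid dimension are limited by the remaining proposed grid, others by the remaining proposed
// block (falling back to 1 for threadIdx.y when the upper bound does not divide evenly).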
int64_t DynamicTileImpl::TileSizeOpt(const RuntimeVarPtr &var, int64_t dyn_tile_size) {
  // Currently we only optimize tile size for elementwise ops based on problem size.
  bool map_outer_grid = var->curr_map_id >= AKG_KERNEL_MOD_BX_IDX && var->curr_map_id <= AKG_KERNEL_MOD_BZ_IDX;
  if (map_outer_grid) {
    auto rest_grid = std::max<int64_t>(1, curr_mapping_info_.proposal_grid / curr_mapping_info_.total_alloc_grid);
    dyn_tile_size = std::min<int64_t>(dyn_tile_size, rest_grid);
  } else {
    auto rest_block = std::max<int64_t>(1, curr_mapping_info_.proposal_block / curr_mapping_info_.total_alloc_block);
    if (var->curr_map_id == AKG_KERNEL_MOD_TY_IDX && var->upper_bound % rest_block != 0) {
      rest_block = 1;
    }
    dyn_tile_size = std::min<int64_t>(dyn_tile_size, rest_block);
  }
  return dyn_tile_size;
}

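// Solves all runtime tile sizes of a dynamic reduction at once: the total reduce size picks a
// proper block/thread/sequential split for the chosen reduce template, then each var is tiled in
// mark order, its mapping slot updated, and the leftover length of its axis divided down.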
void DynamicTileImpl::SolveDynamicReduction() {
  int proper_block = -1, proper_thread = -1, proper_seq = -1;
  int total_red_size =
    std::accumulate(runtime_threads_order_.begin(), runtime_threads_order_.end(), 1,
                    [&](int total, int prime) { return total * axis_length_left_[product_var_[prime]]; });

  if (dyn_algorithm_ == AKG_DYN_ALGO_REDUCE_X) {
    GetProperReduceXConfig(total_red_size, enable_atomic_, &proper_block, &proper_thread, &proper_seq);
  } else if (dyn_algorithm_ == AKG_DYN_ALGO_REDUCE_Y) {
    GetProperReduceYConfig(total_red_size, enable_atomic_, &proper_block, &proper_seq);
  }

  if (dyn_algorithm_ != AKG_DYN_ALGO_REDUCE_X) {
    proper_thread = kNum32;
  }

  proper_thread = (proper_thread - 1) / curr_mapping_info_.total_alloc_block + 1;  // remove used
  for (const auto &p : template_tiling_order_) {
    int prime = p.first;
    int mark = p.second;
    int upper_bound = -1;
    auto symbol = product_var_[prime];
    auto current_length = axis_length_left_[symbol];
    if (AKG_DYN_MARK_THREAD_LOWER_BOUND <= mark && mark < AKG_DYN_MARK_THREAD_UPPER_BOUND) {
      upper_bound = std::min<int>(current_length, (MAX_THREAD_NUM / curr_mapping_info_.total_alloc_block));
    } else {
      upper_bound = current_length;
    }
    int tile_size = -1;
    if (mark == AKG_DYN_MARK_PRODUCT) {
      tile_size =
        runtime_vars_[related_values_[prime][0]]->runtime_size * runtime_vars_[related_values_[prime][1]]->runtime_size;
    } else {
      GetReduceTileSize(mark, upper_bound, proper_thread, proper_seq, &tile_size);
      if (AKG_DYN_MARK_SEQ_LOWER_BOUND <= mark && mark < AKG_DYN_MARK_SEQ_UPPER_BOUND) {
        proper_seq = proper_seq / tile_size;
      }
    }
    // update mapping
    if (prime_to_mapping_idx_[prime] != -1) {
      // scenario 1: BlockIdx.x = prime
      auto curr_idx = prime_to_mapping_idx_[prime];
      thread_info_[curr_idx] = tile_size;
      solved_map_loc_.insert(curr_idx);
    } else if (prime_to_mapping_dividend_[prime] != -1) {
      // scenario 2: BlockIdx.x = symbol / prime
      // NOTE: since we know thread_info_'s format here, we only use tile_size
      // to represent both the divisor and the dividend. Update the var name later.
      auto curr_idx = prime_to_mapping_dividend_[prime];
      thread_info_[curr_idx] = tile_size;
    }
    if (runtime_vars_.find(prime) != runtime_vars_.end()) {
      runtime_vars_[prime]->runtime_size = tile_size;
      int64_t neg_prime = -prime;
      if (runtime_vars_.find(neg_prime) != runtime_vars_.end()) {
        runtime_vars_[neg_prime]->runtime_size = -tile_size;
      }
    }
    axis_length_left_[symbol] = (current_length - 1) / tile_size + 1;
  }
}

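// Solves the tile size for one mapping dimension of a non-reduce kernel. If the slot holds a
// var's prime it picks a warp-friendly tile within the hardware limit and the var's upper bound;
// if the slot holds the derived outer dimension it is set to ceil(upper_bound / runtime_size).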
void DynamicTileImpl::SolveDynamicTiling(size_t curr_id) {
  auto prime = thread_info_[curr_id];
  auto it = runtime_vars_.find(prime);
  if (it == runtime_vars_.end()) {
    return;
  }
  auto var = it->second;
  if (var->curr_map_id != static_cast<int>(curr_id)) {
    if (var->outer_map_id != static_cast<int>(curr_id)) {
      MS_EXCEPTION(RuntimeError) << "Unknown var: " << var->ToString() << "; Cannot map currId: " << curr_id;
    } else {
      // In this branch, the dividend is stored in thread_info_ and equals the prime number,
      // meaning the mapping of `curr_id` is the upper bound ceil-divided by `runtime_size`.
      auto dividend = static_cast<uint32_t>((var->upper_bound - 1) / var->runtime_size) + 1;
      UpdateMapping(curr_id, dividend, 0);
    }
    return;
  }
  auto map_limit = curr_mapping_info_.GetMapLimit(curr_id, device_target_);
  auto upper_bound = std::min<uint32_t>(map_limit, var->upper_bound);
  if (upper_bound < 1) {
    MS_EXCEPTION(RuntimeError) << "Invalid upper_bound of runtime var: " << var->ToString();
  }

  // Init dynamic tile size to the upper bound
  int64_t dyn_tile_size = upper_bound;

  // Add dynamic tiling strategy here
  auto warp_num = std::max<int64_t>(1, var->upper_bound / WARP_SIZE);
  bool map_outer_block = var->outer_map_id >= AKG_KERNEL_MOD_BX_IDX && var->outer_map_id <= AKG_KERNEL_MOD_BZ_IDX;
  if (var->curr_map_id == AKG_KERNEL_MOD_TX_IDX && var->upper_bound % WARP_SIZE != 0) {
    dyn_tile_size = WARP_SIZE * warp_num;
  } else if (var->curr_map_id == AKG_KERNEL_MOD_TY_IDX && var->upper_bound % WARP_SIZE != 0) {
    if (map_outer_block && curr_mapping_info_.total_alloc_grid * WARP_SIZE < ELEM_BEST_GRID_SIZE) {
      dyn_tile_size = 1;
    } else {
      dyn_tile_size = (WARP_SIZE / WARP_ALLOC_GRAN) * warp_num;
    }
  }
  dyn_tile_size = std::min<int64_t>(dyn_tile_size, upper_bound);
  dyn_tile_size = TileSizeOpt(var, dyn_tile_size);
  UpdateMapping(curr_id, dyn_tile_size, prime);
}
}  // namespace kernel
}  // namespace mindspore