/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/dynamic_akg/dynamic_utils.h"
#include <fstream>
#include <algorithm>
#include <cmath>
#include <map>
#include <memory>
#include <numeric>
#include <utility>
#include "nlohmann/json.hpp"
#include "utils/ms_utils.h"
#include "kernel/framework_utils.h"
#include "mindspore/ccsrc/include/common/debug/common.h"
#include "plugin/device/gpu/hal/device/gpu_common.h"

namespace mindspore {
namespace kernel {
using std::fstream;
using std::string;
using std::unordered_map;
using std::vector;
namespace {
constexpr auto kV100Device = "v100";
constexpr auto kA100Device = "a100";
constexpr auto kSharedMem = "shared_mem";
constexpr auto kRegMem = "reg_mem";
constexpr auto kLocalPrefix = "Seq";
constexpr auto kHostShapes = "hostShapes";
constexpr auto kDeviceShapes = "deviceShapes";
constexpr auto kRuntimeVars = "runtimeVars";
constexpr auto kTargetInfo = "targetInfo";
constexpr auto kSupportInfo = "SupportInfo";
constexpr auto kReduceSizeStatic = "ReduceSizeStatic";
constexpr auto kDynAlgorithm = "DynAlgorithm";
constexpr auto kEnableAtomic = "EnableAtomic";
constexpr auto kMultiply = "*";

constexpr auto kBlockIdxX = "blockIdx.x";
constexpr auto kBlockIdxY = "blockIdx.y";
constexpr auto kBlockIdxZ = "blockIdx.z";
constexpr auto kThreadIdxX = "threadIdx.x";
constexpr auto kThreadIdxY = "threadIdx.y";
constexpr auto kThreadIdxZ = "threadIdx.z";
constexpr int64_t kRemove = -100000;
constexpr int64_t kKeep = -99999;
const int AKG_KERNEL_MOD_BX_IDX = 0;
const int AKG_KERNEL_MOD_BY_IDX = 1;
const int AKG_KERNEL_MOD_BZ_IDX = 2;
const int AKG_KERNEL_MOD_TX_IDX = 3;
const int AKG_KERNEL_MOD_TY_IDX = 4;
const int AKG_KERNEL_MOD_TZ_IDX = 5;
const int WARP_SIZE = 32;
const int WARP_ALLOC_GRAN = 4;
const int ELEM_BEST_GRID_SIZE = 512;
const int MAX_THREAD_NUM = 1024;
const int k48KB = 49152;
const int k64KB = 65536;
const int kNum80 = 80;
const int kNum108 = 108;
const int kNum512 = 512;
const int kNum1024 = 1024;
const int kNum2 = 2;
const int kNum16 = 16;
const int kNum32 = 32;
const int kNum64 = 64;
const int kNum256 = 256;

const int AKG_DYN_ALGO_REDUCE_X = 10;
const int AKG_DYN_ALGO_REDUCE_Y = 11;
const int AKG_DYN_ALGO_REDUCE_SMALL = 12;
const int AKG_DYN_MARK_THREAD_LOWER_BOUND = 10;
const int AKG_DYN_MARK_THREAD_UPPER_BOUND = 20;
const int AKG_DYN_MARK_SEQ_LOWER_BOUND = 20;
const int AKG_DYN_MARK_SEQ_UPPER_BOUND = 30;
const int AKG_DYN_MARK_ONE = 30;
const int AKG_DYN_MARK_PRODUCT = 40;
const int AKG_DYN_MARK_UNKNOWN = 999;
}  // namespace
// Orders runtime vars by the index of the kernel argument they belong to.
struct RuntimeVarCompare {
  bool operator()(const std::pair<uint32_t, RuntimeVarPtr> &a, const std::pair<uint32_t, RuntimeVarPtr> &b) const {
    return a.second->argIndex < b.second->argIndex;
  }
};

std::unordered_map<std::string, int> AkgKernelImplInfo::algo_to_int_ = {{"reduce-x", AKG_DYN_ALGO_REDUCE_X},
                                                                        {"reduce-y", AKG_DYN_ALGO_REDUCE_Y},
                                                                        {"reduce-small", AKG_DYN_ALGO_REDUCE_SMALL}};

std::unordered_map<std::string, int> RuntimeVar::mark_table_ = {{"unknown", 999},
                                                                {"reduce-thread-last", 10},
                                                                {"reduce-thread", 11},
                                                                {"parallel-thread-last", 12},
                                                                {"parallel-thread", 13},
                                                                {"reduce-y-seq", 20},
                                                                {"reduce-x-seq", 21},
                                                                {"parallel-seq", 22},
                                                                {"1", 30},
                                                                {"product", 40}};

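// Pick a block/thread/sequential split for a reduce-x (innermost-axis) reduction.
// With atomics enabled, the reduction is first split across blocks of at most MAX_THREAD_NUM
// elements; the thread count then starts at 32 (or less) and doubles while each thread still
// keeps roughly `acc_num` elements to accumulate, and the leftover is covered sequentially.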
inline void GetProperReduceXConfig(int redSize, const bool &useAtomicFlag, int *block_num, int *thread_num,
                                   int *seq_num) {
  const int acc_num = 4;
  *block_num = useAtomicFlag ? (redSize - 1) / MAX_THREAD_NUM + 1 : 1;
  redSize = useAtomicFlag ? (redSize - 1) / (*block_num) + 1 : redSize;
  (*thread_num) = redSize < kNum32 ? redSize : kNum32;
  while ((*thread_num) * kNum2 * acc_num <= redSize && (*thread_num) <= MAX_THREAD_NUM) {
    (*thread_num) *= kNum2;
  }
  (*seq_num) = (redSize - 1) / ((*thread_num) * (*block_num)) + 1;
}

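// Pick a block/sequential split for a reduce-y (outer-axis) reduction. Without atomics the
// whole reduction length is handled sequentially by one block; with atomics the sequential
// length is capped at 16/32/64 depending on size and the rest is spread across blocks.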
inline void GetProperReduceYConfig(const int &redSize, const bool &useAtomicFlag, int *block_num, int *seq_num) {
  *block_num = 1;
  *seq_num = redSize;
  if (useAtomicFlag) {
    if (redSize < kNum256) {
      *seq_num = kNum16;
    } else if (redSize < kNum1024) {
      *seq_num = kNum32;
    } else {
      *seq_num = kNum64;
    }
    *block_num = (redSize - 1) / (*seq_num) + 1;
  }
}

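// Translate a runtime-var mark into a concrete tile size: thread-level marks take the proper
// thread count, sequential marks take the proper sequential count (both clamped to the axis
// upper bound), the "1" mark is untiled, and any other mark is left to the caller (-1).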
inline void GetReduceTileSize(const int &mark, const int &upper_bound, const int &proper_thread, const int &proper_seq,
                              int *tile_size) {
  if (AKG_DYN_MARK_THREAD_LOWER_BOUND <= mark && mark < AKG_DYN_MARK_THREAD_UPPER_BOUND) {
    *tile_size = std::min(upper_bound, proper_thread);
  } else if (AKG_DYN_MARK_SEQ_LOWER_BOUND <= mark && mark < AKG_DYN_MARK_SEQ_UPPER_BOUND) {
    *tile_size = std::min(upper_bound, proper_seq);
  } else if (mark == AKG_DYN_MARK_ONE) {
    *tile_size = 1;
  } else {
    *tile_size = -1;
  }
}

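// Propose a (grid, block) pair for a given problem size by stepping through occupancy tiers
// of the target GPU (warp size, SM count, thread coefficients, active blocks per SM), falling
// back to the maximum available block size for very large shapes.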
static std::pair<int, int> GetProposalParallelSize(int problemSize, const std::string &device_target) {
  GpuInfo &gpu_info = GpuInfo::GetInstance(device_target);
  int proposedGrid = 1;
  int proposedBlock = 1;
  auto numSm = gpu_info.GetNumSm();
  auto threadCoef = gpu_info.GetThreadCoef();
  auto warpSizes = gpu_info.GetWarpSizes();
  auto activeBlocksPerSm = gpu_info.GetActiveBlocksPerSm();
  auto totalBlocks = gpu_info.GetTotalAvailableBlocks();
  if (problemSize <= warpSizes) {
    proposedBlock = warpSizes;
  } else if (problemSize <= warpSizes * numSm) {
    proposedBlock = warpSizes;
    proposedGrid = numSm;
  } else if (problemSize <= warpSizes * threadCoef.first * numSm * activeBlocksPerSm.first) {
    proposedBlock = warpSizes * threadCoef.first;
    proposedGrid = numSm * activeBlocksPerSm.first;
  } else if (problemSize <= warpSizes * threadCoef.second * numSm * activeBlocksPerSm.second) {
    proposedBlock = warpSizes * threadCoef.second;
    proposedGrid = numSm * activeBlocksPerSm.second;
  } else if (problemSize <= warpSizes * threadCoef.second * numSm * activeBlocksPerSm.second * numSm) {
    proposedBlock = totalBlocks;
    proposedGrid = numSm * activeBlocksPerSm.second;
  } else {
    // extremely large shape
    proposedBlock = totalBlocks;
    proposedGrid = numSm * activeBlocksPerSm.second * kNum2;
  }
  return std::make_pair(proposedGrid, proposedBlock);
}

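// Upper limit for the mapping size of a given grid/block dimension: thread dimensions are
// additionally capped by the threads still available after already-allocated block dimensions,
// grid dimensions only by the hardware maximum.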
uint32_t MappingInfo::GetMapLimit(size_t id, const std::string &device_target) {
  if (id < AKG_KERNEL_MOD_BX_IDX || id > AKG_KERNEL_MOD_TZ_IDX) {
    MS_EXCEPTION(RuntimeError) << "Map id should be in range [" << AKG_KERNEL_MOD_BX_IDX << ", "
                               << AKG_KERNEL_MOD_TZ_IDX << "], but got " << id;
  }
  GpuInfo &gpu_info = GpuInfo::GetInstance(device_target);
  if (id >= AKG_KERNEL_MOD_TX_IDX) {
    auto max_block = gpu_info.GetMaxBlocks();
    return std::min<uint32_t>(max_block[id - AKG_KERNEL_MOD_TX_IDX],
                              gpu_info.GetTotalAvailableBlocks() / total_alloc_block);
  } else {
    auto max_grid = gpu_info.GetMaxGrids();
    return max_grid[id];
  }
}

void MappingInfo::UpdateCurrMapSize(size_t id, uint32_t map_size) {
  if (id < AKG_KERNEL_MOD_BX_IDX || id > AKG_KERNEL_MOD_TZ_IDX) {
    MS_EXCEPTION(RuntimeError) << "Map id should be in range [" << AKG_KERNEL_MOD_BX_IDX << ", "
                               << AKG_KERNEL_MOD_TZ_IDX << "], but got " << id;
  }
  if (id >= AKG_KERNEL_MOD_TX_IDX) {
    curr_block[id - AKG_KERNEL_MOD_TX_IDX] = map_size;
    total_alloc_block *= map_size;
  } else {
    curr_grid[id] = map_size;
    total_alloc_grid *= map_size;
  }
}

void GpuInfo::InitGpuMemoryLimit(const std::string &device_type) {
  auto CollectLimit = [this, &device_type](const std::string &scope, GpuMemScope mem) {
    if (device_type == kV100Device) {
      if (scope == kSharedMem) {
        gpuMemLimit[mem] = k48KB;
      } else if (scope == kRegMem) {
        gpuMemLimit[mem] = k64KB;
      }
    } else if (device_type == kA100Device) {
      if (scope == kSharedMem) {
        gpuMemLimit[mem] = k64KB;
      } else if (scope == kRegMem) {
        gpuMemLimit[mem] = k64KB;
      }
    }
  };
  CollectLimit(kSharedMem, MEM_SCOPE_SHARED);
  CollectLimit(kRegMem, MEM_SCOPE_LOCAL);
  gpuMemLimit[MEM_SCOPE_GM] = 0;
}

void GpuInfo::InitGpuComputeCapability(const std::string &device_type) {
  if (device_type == kV100Device) {
    numSm = kNum80;
    totalAvailableBlocks = kNum512;
  } else if (device_type == kA100Device) {
    numSm = kNum108;
    totalAvailableBlocks = kNum1024;
  }
}

AkgKernelImplInfo::AkgKernelImplInfo(const std::string &kernel_name, nlohmann::json json) {
  kernel_name_ = kernel_name;
  parsed_js_ = json;
  device_target_ = parsed_js_[kTargetInfo];
  if (device_target_.empty()) {
    device_target_ = kV100Device;
  }
}

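// Resolve a device-shape symbol expression (e.g. "s0*1024*s1") into a list of host-shape
// locations: each named symbol maps to its <tensor, dim> position recorded in host_loc_map_,
// while literal factors are stored with the out-of-range index `pure_num_flag` so that
// GetFoldedShape can tell them apart.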
LocVector AkgKernelImplInfo::GetHostLocationVec(std::string symbol_expr, const size_t pure_num_flag) {
  std::string delimiter = kMultiply;
  std::vector<std::string> symbol_vec;
  if (symbol_expr.find(delimiter) != std::string::npos) {
    // multiplication expr after folding like 's0*1024*s1'
    size_t pos_start = 0;
    size_t pos_end;
    size_t delim_len = 1;
    std::string symbol;
    // split each symbol or number and store in a list
    while ((pos_end = symbol_expr.find(delimiter, pos_start)) != std::string::npos) {
      symbol = symbol_expr.substr(pos_start, pos_end - pos_start);
      pos_start = pos_end + delim_len;
      (void)symbol_vec.emplace_back(symbol);
    }
    (void)symbol_vec.emplace_back(symbol_expr.substr(pos_start));
  } else {
    (void)symbol_vec.emplace_back(symbol_expr);
  }

  LocVector host_loc_vec;
  for (auto symbol : symbol_vec) {
    if (std::all_of(symbol.begin(), symbol.end(), [](char c) { return std::isdigit(c); })) {
      // for number '32', save a pair of <M, number>, where M = num of inputs + outputs
      // M must be greater than any host_loc index, so it can be a flag for pure numbers
      auto number_pair = std::make_pair(pure_num_flag, IntToSize(std::stoi(symbol)));
      (void)host_loc_vec.emplace_back(number_pair);
    } else if (host_loc_map_.find(symbol) != host_loc_map_.end()) {
      // for symbol 's0', save its location in host shape as a pair of <i, j>
      (void)host_loc_vec.emplace_back(host_loc_map_[symbol]);
    } else {
      MS_EXCEPTION(RuntimeError) << "For " << kernel_name_ << ", symbol '" << symbol
                                 << "' of device shape is not in host shape.";
    }
  }
  return host_loc_vec;
}

void AkgKernelImplInfo::InitJsonShapeInformation() {
  // Initialize device shape list using -1 for unknown dims
  // Record map <device_loc, symbol>
  unordered_map<std::pair<size_t, size_t>, string, PairHash> device_loc_map;
  for (size_t i = 0; i < parsed_js_[kDeviceShapes].size(); i++) {
    vector<int64_t> device_tensor_shape;
    auto device_tensor_rank = parsed_js_[kDeviceShapes][i].size();
    max_shape_rank_ = std::max<size_t>(max_shape_rank_, device_tensor_rank);

    for (size_t j = 0; j < device_tensor_rank; j++) {
      string device_shape_str = parsed_js_[kDeviceShapes][i][j];
      if (std::all_of(device_shape_str.begin(), device_shape_str.end(), [](char c) { return std::isdigit(c); })) {
        (void)device_tensor_shape.emplace_back(std::stoi(device_shape_str));
      } else {
        (void)device_tensor_shape.emplace_back(-1);
        auto device_loc = std::make_pair(i, j);
        device_loc_map[device_loc] = device_shape_str;
      }
    }

    (void)device_shape_list_.emplace_back(device_tensor_shape);
  }
  // Record map <symbol, host_loc>
  for (int i = static_cast<int>(parsed_js_[kHostShapes].size()) - 1; i >= 0; i--) {
    for (size_t j = 0; j < parsed_js_[kHostShapes][i].size(); j++) {
      string shape_str = parsed_js_[kHostShapes][i][j];
      if ((!std::all_of(shape_str.begin(), shape_str.end(), [](char c) { return std::isdigit(c); })) &&
          (host_loc_map_.find(shape_str) == host_loc_map_.end())) {
        host_loc_map_[shape_str] = std::make_pair(i, j);
      }
    }
  }
  // Get map <device_loc, host_loc>
  for (auto item : device_loc_map) {
    device_host_shape_loc_[item.first] = GetHostLocationVec(item.second, parsed_js_[kHostShapes].size());
  }
  MS_LOG(INFO) << "Done InitJsonShapeInformation for " << kernel_name_;
}

void AkgKernelImplInfo::InitJsonMappingInformation() {
  // Create map <prime, var> where var is the dynamic tile size and prime is its unique id
  runtime_vars_.clear();
  sorted_runtime_vars_.clear();
  for (size_t i = 0; i < parsed_js_[kRuntimeVars].size(); i++) {
    RuntimeVarPtr v = std::make_shared<RuntimeVar>();
    v->argIndex = parsed_js_[kRuntimeVars][i][v->ArgIndexKey()];
    v->expr = parsed_js_[kRuntimeVars][i][v->ExprKey()];
    v->mark = RuntimeVar::mark_table_[parsed_js_[kRuntimeVars][i].value(v->MarkKey(), "unknown")];
    v->mapDim = parsed_js_[kRuntimeVars][i][v->MapDimKey()];
    v->mapping = parsed_js_[kRuntimeVars][i][v->MappingKey()];
    v->prime = parsed_js_[kRuntimeVars][i][v->PrimeKey()];
    runtime_vars_[v->prime] = v;
    sorted_runtime_vars_.emplace_back(std::make_pair(v->prime, v));
  }

  // Sort the runtime vars according to their kernel argument index
  std::sort(sorted_runtime_vars_.begin(), sorted_runtime_vars_.end(), RuntimeVarCompare());

  // Init mapping info with the static mapping sizes; dynamic mapping sizes will be calculated during resize
  init_mapping_info_ = MappingInfo();
  init_mapping_info_.solve_order_id = {AKG_KERNEL_MOD_TX_IDX, AKG_KERNEL_MOD_TY_IDX, AKG_KERNEL_MOD_TZ_IDX,
                                       AKG_KERNEL_MOD_BX_IDX, AKG_KERNEL_MOD_BY_IDX, AKG_KERNEL_MOD_BZ_IDX};

  // Initialize mapping info as a vector. Only store dividend number for unknown mapping args
  // Record map <unknown map arg id, host_loc>
  init_mapping_.clear();
  map_arg_list_ = {kBlockIdxX, kBlockIdxY, kBlockIdxZ, kThreadIdxX, kThreadIdxY, kThreadIdxZ};
  for (size_t i = 0; i < map_arg_list_.size(); i++) {
    auto map_arg = map_arg_list_[i];
    if (parsed_js_[map_arg].is_number()) {
      uint32_t map_size = parsed_js_[map_arg];
      (void)init_mapping_.emplace_back(map_size);
      auto it = runtime_vars_.find(map_size);
      if (it == runtime_vars_.end()) {
        // update static mapping
        init_mapping_info_.UpdateCurrMapSize(i, map_size);
      } else {
        it->second->curr_map_id = i;
      }
    } else if (CheckJsonValueFormat(map_arg)) {
      string divisor_symbol = parsed_js_[map_arg][0];
      auto dividend = parsed_js_[map_arg][1];
      (void)init_mapping_.emplace_back(static_cast<uint32_t>(dividend));
      unknown_map_loc_[i] = GetHostLocationVec(divisor_symbol, parsed_js_[kHostShapes].size());
      unknown_map_symbol_[i] = divisor_symbol;
      product_var_[dividend] = divisor_symbol;
    } else {
      MS_EXCEPTION(RuntimeError) << "Mapping info format error.";
      return;
    }
  }

  if (runtime_vars_.empty()) {
    return;
  }
  for (auto it = parsed_js_.begin(); it != parsed_js_.end(); ++it) {
    std::string key = it.key();
    if (key.find(kLocalPrefix) != std::string::npos && CheckJsonValueFormat(key)) {
      string divisor_symbol = parsed_js_[key][0];
      auto prime = static_cast<int>(parsed_js_[key][1]);
      local_upper_bound_[prime] = host_loc_map_[divisor_symbol];
      local_upper_bound_symbol_[prime] = divisor_symbol;
      product_var_[prime] = divisor_symbol;
    }
  }

  // update relationship: product = prime0 * prime1
  for (const auto &kv : runtime_vars_) {
    int prime0 = kv.first;
    if (prime0 <= 1) continue;
    if (runtime_vars_.find(prime0) != runtime_vars_.end()) {
      for (const auto &kv2 : runtime_vars_) {
        int product = kv2.first;
        if (product > 0 && product != prime0 && product % prime0 == 0) {
          product_var_[prime0] = product_var_[product];
          product_var_[product / prime0] = product_var_[product];
          std::vector<int> vars({prime0, product / prime0});
          related_values_[product] = vars;
        }
      }
    }
  }
  MS_LOG(INFO) << "Done InitJsonMappingInformation for " << kernel_name_;
}

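// Pre-compute everything that only depends on the kernel JSON for reduce kernels: the static
// reduce length, the set of thread-level runtime vars, a tiling order sorted by mark, and the
// prime -> mapping-index tables used when the reduction is solved at resize time.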
void AkgKernelImplInfo::preprocessDynamicReduceTiling() {
  // keep static reduce length
  static_reduce_length_ = parsed_js_[kSupportInfo][kReduceSizeStatic];

  // collect thread-level idx
  runtime_threads_order_.clear();
  for (const auto &kv : runtime_vars_) {
    auto var = kv.second;
    // whether the runtime var has "thread" mark
    if (var->prime <= 0 ||
        (var->mark < AKG_DYN_MARK_THREAD_LOWER_BOUND || var->mark >= AKG_DYN_MARK_THREAD_UPPER_BOUND))
      continue;
    runtime_threads_order_.push_back(kv.first);
  }

  // sort tiling orders
  const std::vector<int> orders = {10, 11, 12, 13, 20, 21, 22, 30, 40};  // mark order
  template_tiling_order_.clear();
  for (auto order : orders) {
    for (auto kv : runtime_vars_) {
      auto var = kv.second;
      if (var->prime <= 0 || var->mark != order) continue;
      template_tiling_order_.push_back(std::make_pair(kv.first, var->mark));
    }
  }

  // build a map from prime number to mapping idx
  prime_to_mapping_idx_.clear();
  prime_to_mapping_dividend_.clear();
  for (const auto &p : template_tiling_order_) {
    int prime = p.first;
    prime_to_mapping_idx_[prime] = -1;
    prime_to_mapping_dividend_[prime] = -1;
    for (size_t i = 0; i < map_arg_list_.size(); i++) {
      auto map_arg = map_arg_list_[i];
      if (parsed_js_[map_arg].is_number() && parsed_js_[map_arg] == prime) {
        prime_to_mapping_idx_[prime] = i;
        break;
      } else if (CheckJsonValueFormat(map_arg) && parsed_js_[map_arg][1] == prime) {
        prime_to_mapping_dividend_[prime] = i;
        break;
      }
    }
  }

  enable_atomic_ = parsed_js_[kSupportInfo][kEnableAtomic];
  dyn_algorithm_ = AkgKernelImplInfo::algo_to_int_[parsed_js_[kSupportInfo][kDynAlgorithm]];
}

void AkgKernelImplInfo::InitBeforeMapping() {
  thread_info_ = init_mapping_;
  if (!runtime_vars_.empty()) {
    solved_map_loc_.clear();
    curr_mapping_info_ = init_mapping_info_;
    auto [g, b] = GetProposalParallelSize(problem_size_, device_target_);
    curr_mapping_info_.proposal_grid = g;
    curr_mapping_info_.proposal_block = b;
  }
}

void DynamicTileImpl::UpdateDynamicShapeTilingInfo() {
  if (runtime_vars_.empty()) {
    MS_LOG(DEBUG) << "Static Tile " << kernel_name_;
    return;
  }
  UpdateRuntimeVarUpperBound();
  if (parsed_js_[kSupportInfo]["OperatorType"] == "Reduce") {
    SolveDynamicReduction();
  } else {
    for (auto curr_id : curr_mapping_info_.solve_order_id) {
      SolveDynamicTiling(curr_id);
    }
  }

  for (auto it : sorted_runtime_vars_) {
    arg_size_vec_.push_back(it.second->runtime_size);
  }
  MS_LOG(INFO) << "Done UpdateDynamicShapeTilingInfo for " << kernel_name_;
}

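// Fold a resolved location vector back into a concrete dimension size by multiplying the
// referenced host-shape entries and literal factors, with an overflow guard against INT64_MAX.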
int64_t AkgKernelImplInfo::GetFoldedShape(const LocVector &host_loc_vec) {
  int64_t folded_shape = 1;
  for (auto host_loc : host_loc_vec) {
    int64_t curr_shape = 1;
    if (host_loc.first == shape_list_.size()) {
      // pure number pair <M, number>
      curr_shape = SizeToInt(host_loc.second);
    } else if (shape_list_[host_loc.first].size() != 0) {
      curr_shape = shape_list_[host_loc.first][host_loc.second];
    }
    if (folded_shape > INT64_MAX / curr_shape) {
      MS_EXCEPTION(RuntimeError) << "For " << kernel_name_ << ", the product of shapes, " << folded_shape << " and "
                                 << curr_shape << ", exceeds INT64_MAX.";
    }
    folded_shape *= curr_shape;
  }
  return folded_shape;
}

void AkgKernelImplInfo::UpdateDynamicShapeMappingInfo() {
  for (auto item : unknown_map_loc_) {
    auto thread_info_id = item.first;
    if (solved_map_loc_.count(thread_info_id)) {
      continue;
    }
    if (thread_info_id >= thread_info_.size()) {
      MS_EXCEPTION(RuntimeError) << "Unknown thread arg index should not exceed thread_info length, "
                                 << "which is 6 (Grid.X/Y/Z, Block.X/Y/Z).";
    }
    auto host_loc_vec = item.second;
    auto dim_size = GetFoldedShape(host_loc_vec);
    auto tile_size = thread_info_[thread_info_id];
    auto dim_size_float = static_cast<float>(dim_size);
    auto tile_size_float = static_cast<float>(tile_size);
    thread_info_[thread_info_id] = static_cast<uint32_t>(std::ceil(dim_size_float / tile_size_float));
  }

  MS_LOG(DEBUG) << "For " << kernel_name_ << ", thread_info = " << thread_info_;
  MS_LOG(INFO) << "Done UpdateDynamicShapeMappingInfo for " << kernel_name_;
}

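// Pack the launch argument descriptor for each tensor argument: a placeholder slot (kRemove),
// the data pointer slot (kKeep), a zero offset, the resolved device shape, and the row-major
// strides, while tracking the largest tensor size as the problem size.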
void AkgKernelImplInfo::GetDeviceArgSizeVec() {
  arg_size_vec_.clear();
  std::vector<std::vector<int64_t>> device_shape(device_shape_list_);
  // Update each unknown value in device_shape_list to get real device shape
  for (auto item : device_host_shape_loc_) {
    auto device_loc = item.first;
    auto host_loc_vec = item.second;
    device_shape[device_loc.first][device_loc.second] = GetFoldedShape(host_loc_vec);
  }
  problem_size_ = 1;
  for (size_t i = 0; i < device_shape.size(); i++) {
    MS_LOG(DEBUG) << "For " << kernel_name_ << ", input[" << i << "]: host_shape = " << shape_list_[i]
                  << ", device_shape = " << device_shape[i];
    arg_size_vec_.push_back(kRemove);  // useless in memref
    arg_size_vec_.push_back(kKeep);    // data ptr
    arg_size_vec_.push_back(0);        // offset
    auto device_tensor_shape = device_shape[i];
    int64_t tensor_size = 1;
    for (auto item : device_tensor_shape) {
      if (item <= 0) {
        MS_EXCEPTION(RuntimeError) << "Shape still has a negative value for kernel: " << kernel_name_
                                   << " with host shape[" << i << "] = " << shape_list_[i] << ", device_tensor_shape["
                                   << i << "] = " << device_tensor_shape;
      }
      arg_size_vec_.push_back(item);
      tensor_size *= item;
    }
    problem_size_ = std::max<int64_t>(problem_size_, tensor_size);
    vector<int64_t> strides(device_tensor_shape.size(), 1);
    for (int j = SizeToInt(device_tensor_shape.size()) - 2; j >= 0; j--) {
      strides[j] = strides[j + 1] * device_tensor_shape[j + 1];
    }
    for (auto item : strides) {
      arg_size_vec_.push_back(item);
    }
  }
  MS_LOG(DEBUG) << "For " << kernel_name_ << ", arg_size_vec = " << arg_size_vec_;
  MS_LOG(INFO) << "Done GetDeviceArgSizeVec for " << kernel_name_;
}

void DynamicTileImpl::UpdateRuntimeVarUpperBound() {
  // Comes from `Block/Thread: [upper_bound, prime]`
  for (const auto &item : unknown_map_loc_) {
    auto thread_info_id = item.first;
    auto host_loc_vec = item.second;
    auto tile_size = thread_info_[thread_info_id];
    auto it = runtime_vars_.find(tile_size);
    if (it != runtime_vars_.end()) {
      auto dim_size = GetFoldedShape(host_loc_vec);
      it->second->upper_bound = dim_size;
      it->second->outer_map_id = thread_info_id;
      auto symbol = unknown_map_symbol_[thread_info_id];
      axis_length_left_[symbol] = dim_size;
    }
  }

  // Comes from `Seq: [upper_bound, prime]`
  for (const auto &it : local_upper_bound_) {
    auto prime = it.first;
    if (runtime_vars_.find(prime) != runtime_vars_.end()) {
      auto host_loc = it.second;
      auto dim_size = shape_list_[host_loc.first][host_loc.second];
      runtime_vars_[prime]->upper_bound = dim_size;
      auto symbol = local_upper_bound_symbol_[prime];
      axis_length_left_[symbol] = dim_size;
    }
  }
}

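// Commit a solved tile size: write it into thread_info_ and the current mapping info (skipped
// when curr_id is -1, i.e. a purely sequential axis), record the runtime size for the prime,
// and mirror it with a negative sign when the negated prime is also registered.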
void DynamicTileImpl::UpdateMapping(int curr_id, int64_t map_size, int64_t prime) {
  // skip when mark is seq.x
  if (curr_id != -1) {
    thread_info_[curr_id] = map_size;
    solved_map_loc_.insert(curr_id);
    curr_mapping_info_.UpdateCurrMapSize(curr_id, map_size);
  }
  if (runtime_vars_.find(prime) == runtime_vars_.end()) {
    return;
  }
  runtime_vars_[prime]->runtime_size = map_size;
  int64_t neg_prime = -prime;
  if (runtime_vars_.find(neg_prime) != runtime_vars_.end()) {
    runtime_vars_[neg_prime]->runtime_size = -map_size;
  }
}

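// Clamp a candidate tile size so the total allocation stays close to the proposal computed
// from the problem size: grid dimensions are bounded by the remaining proposed grid, block
// dimensions by the remaining proposed block (forced to 1 for threadIdx.y when it does not
// divide the upper bound).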
int64_t DynamicTileImpl::TileSizeOpt(const RuntimeVarPtr &var, int64_t dyn_tile_size) {
  // Currently we only optimize tile size for elementwise ops based on problem size.
  bool map_outer_grid = var->curr_map_id >= AKG_KERNEL_MOD_BX_IDX && var->curr_map_id <= AKG_KERNEL_MOD_BZ_IDX;
  if (map_outer_grid) {
    auto rest_grid = std::max<int64_t>(1, curr_mapping_info_.proposal_grid / curr_mapping_info_.total_alloc_grid);
    dyn_tile_size = std::min<int64_t>(dyn_tile_size, rest_grid);
  } else {
    auto rest_block = std::max<int64_t>(1, curr_mapping_info_.proposal_block / curr_mapping_info_.total_alloc_block);
    if (var->curr_map_id == AKG_KERNEL_MOD_TY_IDX && var->upper_bound % rest_block != 0) {
      rest_block = 1;
    }
    dyn_tile_size = std::min<int64_t>(dyn_tile_size, rest_block);
  }
  return dyn_tile_size;
}

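// Solve all runtime vars of a reduce kernel at once: compute the total dynamic reduction
// length, derive a proper block/thread/sequential split for the chosen algorithm (reduce-x or
// reduce-y), then walk the mark-sorted tiling order assigning each var a tile size and
// shrinking the remaining length of its axis accordingly.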
void DynamicTileImpl::SolveDynamicReduction() {
  int proper_block = -1, proper_thread = -1, proper_seq = -1;
  int total_red_size =
    std::accumulate(runtime_threads_order_.begin(), runtime_threads_order_.end(), 1,
                    [&](int total, int prime) { return total * axis_length_left_[product_var_[prime]]; });

  if (dyn_algorithm_ == AKG_DYN_ALGO_REDUCE_X) {
    GetProperReduceXConfig(total_red_size, enable_atomic_, &proper_block, &proper_thread, &proper_seq);
  } else if (dyn_algorithm_ == AKG_DYN_ALGO_REDUCE_Y) {
    GetProperReduceYConfig(total_red_size, enable_atomic_, &proper_block, &proper_seq);
  }

  if (dyn_algorithm_ != AKG_DYN_ALGO_REDUCE_X) {
    proper_thread = kNum32;
  }

  proper_thread = (proper_thread - 1) / curr_mapping_info_.total_alloc_block + 1;  // remove used
  for (const auto &p : template_tiling_order_) {
    int prime = p.first;
    int mark = p.second;
    int upper_bound = -1;
    auto symbol = product_var_[prime];
    auto current_length = axis_length_left_[symbol];
    if (AKG_DYN_MARK_THREAD_LOWER_BOUND <= mark && mark < AKG_DYN_MARK_THREAD_UPPER_BOUND) {
      upper_bound = std::min<int>(current_length, (MAX_THREAD_NUM / curr_mapping_info_.total_alloc_block));
    } else {
      upper_bound = current_length;
    }
    int tile_size = -1;
    if (mark == AKG_DYN_MARK_PRODUCT) {
      tile_size =
        runtime_vars_[related_values_[prime][0]]->runtime_size * runtime_vars_[related_values_[prime][1]]->runtime_size;
    } else {
      GetReduceTileSize(mark, upper_bound, proper_thread, proper_seq, &tile_size);
      if (AKG_DYN_MARK_SEQ_LOWER_BOUND <= mark && mark < AKG_DYN_MARK_SEQ_UPPER_BOUND) {
        proper_seq = proper_seq / tile_size;
      }
    }
    // update mapping
    if (prime_to_mapping_idx_[prime] != -1) {
      // scenario 1: BlockIdx.x = prime
      auto curr_idx = prime_to_mapping_idx_[prime];
      thread_info_[curr_idx] = tile_size;
      solved_map_loc_.insert(curr_idx);
    } else if (prime_to_mapping_dividend_[prime] != -1) {
      // scenario 2: BlockIdx.x = symbol / prime
      // NOTE: since we know thread_info_'s format here, we only use tile_size
      // to represent both divisor and dividend. update var name later.
      auto curr_idx = prime_to_mapping_dividend_[prime];
      thread_info_[curr_idx] = tile_size;
    }
    if (runtime_vars_.find(prime) != runtime_vars_.end()) {
      runtime_vars_[prime]->runtime_size = tile_size;
      int64_t neg_prime = -prime;
      if (runtime_vars_.find(neg_prime) != runtime_vars_.end()) {
        runtime_vars_[neg_prime]->runtime_size = -tile_size;
      }
    }
    axis_length_left_[symbol] = (current_length - 1) / tile_size + 1;
  }
}

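// Solve one grid/block dimension for non-reduce kernels. If thread_info_ at this dimension
// holds a prime tied to a runtime var mapped here, pick a tile size bounded by the hardware
// map limit and the axis upper bound, adjust it with warp-based heuristics for threadIdx.x/y,
// and clamp it via TileSizeOpt; if the var is only the outer (divided) part of this dimension,
// the mapping becomes the ceildiv of its upper bound by the already-solved runtime size.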
void DynamicTileImpl::SolveDynamicTiling(size_t curr_id) {
  auto prime = thread_info_[curr_id];
  auto it = runtime_vars_.find(prime);
  if (it == runtime_vars_.end()) {
    return;
  }
  auto var = it->second;
  if (var->curr_map_id != static_cast<int>(curr_id)) {
    if (var->outer_map_id != static_cast<int>(curr_id)) {
      MS_EXCEPTION(RuntimeError) << "Unknown var: " << var->ToString() << "; Cannot map currId: " << curr_id;
    } else {
      // In this branch, thread_info_ holds the dividend (the prime number), which means the
      // mapping of `curr_id` is the upper bound ceildiv the already-solved `runtime_size`.
      auto dividend = static_cast<uint32_t>((var->upper_bound - 1) / var->runtime_size) + 1;
      UpdateMapping(curr_id, dividend, 0);
    }
    return;
  }
  auto map_limit = curr_mapping_info_.GetMapLimit(curr_id, device_target_);
  auto upper_bound = std::min<uint32_t>(map_limit, var->upper_bound);
  if (upper_bound < 1) {
    MS_EXCEPTION(RuntimeError) << "Invalid upper_bound of runtime var: " << var->ToString();
  }

  // Init dynamic tile size to the upper bound
  int64_t dyn_tile_size = upper_bound;

  // Add dynamic tiling strategy here
  auto warp_num = std::max<int64_t>(1, var->upper_bound / WARP_SIZE);
  bool map_outer_block = var->outer_map_id >= AKG_KERNEL_MOD_BX_IDX && var->outer_map_id <= AKG_KERNEL_MOD_BZ_IDX;
  if (var->curr_map_id == AKG_KERNEL_MOD_TX_IDX && var->upper_bound % WARP_SIZE != 0) {
    dyn_tile_size = WARP_SIZE * warp_num;
  } else if (var->curr_map_id == AKG_KERNEL_MOD_TY_IDX && var->upper_bound % WARP_SIZE != 0) {
    if (map_outer_block && curr_mapping_info_.total_alloc_grid * WARP_SIZE < ELEM_BEST_GRID_SIZE) {
      dyn_tile_size = 1;
    } else {
      dyn_tile_size = (WARP_SIZE / WARP_ALLOC_GRAN) * warp_num;
    }
  }
  dyn_tile_size = std::min<int64_t>(dyn_tile_size, upper_bound);
  dyn_tile_size = TileSizeOpt(var, dyn_tile_size);
  UpdateMapping(curr_id, dyn_tile_size, prime);
}
}  // namespace kernel
}  // namespace mindspore