/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_GPU_DYNAMIC_AKG_GPU_UTILS_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_GPU_DYNAMIC_AKG_GPU_UTILS_H_
#include <cuda.h>
#include <string>
#include <vector>
#include <unordered_set>
#include <unordered_map>
#include <map>
#include <memory>
#include <utility>
#include "kernel/kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_mod.h"
#include "kernel/common_utils.h"

namespace mindspore {
namespace kernel {
struct PairHash {
  template <typename T>
  size_t operator()(const std::pair<T, T> &p) const {
    auto h1 = std::hash<T>{}(p.first);
    auto h2 = std::hash<T>{}(p.second);
    return h1 ^ h2;
  }
};
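// PairHash allows std::pair keys to be used with std::unordered_map (see
// device_host_shape_loc_ below). Note that XOR-combining the two hashes is
// symmetric, so (a, b) and (b, a) hash to the same value.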
enum GpuMemScope {
  // global memory
  MEM_SCOPE_GM = 0,
  // on-device scopes
  MEM_SCOPE_SHARED,
  MEM_SCOPE_LOCAL,
  // sentinel: number of scopes (also sizes gpuMemLimit below)
  MEM_SCOPE_BULK,
};
class GpuInfo {
 public:
  GpuInfo(const GpuInfo &) = delete;
  GpuInfo &operator=(const GpuInfo &) = delete;
  ~GpuInfo() {}
  static GpuInfo &GetInstance(const std::string &device_type) {
    static GpuInfo hardware_info(device_type);
    return hardware_info;
  }

  int64_t GetMemoryLimitInScope(int scope_idx) {
    if (scope_idx >= MEM_SCOPE_BULK) {
      MS_EXCEPTION(RuntimeError) << "scope_idx should be less than " << MEM_SCOPE_BULK << ", but got " << scope_idx
                                 << "\n";
      return 0;
    }
    return gpuMemLimit[scope_idx];
  }

  int GetWarpSizes() { return warpSize; }
  int GetNumSm() { return numSm; }
  std::pair<int, int> GetActiveBlocksPerSm() { return activeBlocksPerSm; }
  std::pair<int, int> GetThreadCoef() { return threadCoef; }
  int GetMinElemForIoBound() { return minElemForIoBound; }
  int GetMaxElemForIoBound() { return maxElemForIoBound; }
  int GetTotalAvailableBlocks() { return totalAvailableBlocks; }
  std::vector<int64_t> GetMaxGrids() { return {maxGridX, maxGridYZ, maxGridYZ}; }
  std::vector<int64_t> GetMaxBlocks() { return {maxBlockXY, maxBlockXY, maxBlockZ}; }

 private:
  explicit GpuInfo(const std::string &device_type) {
    InitGpuMemoryLimit(device_type);
    InitGpuComputeCapability(device_type);
  }
  int64_t gpuMemLimit[MEM_SCOPE_BULK]{0};
  int numSm{80};
  int warpSize{32};
  int minElemForIoBound{2};
  int maxElemForIoBound{32};
  int totalAvailableBlocks{1024};
  std::pair<int, int> threadCoef{8, 16};
  std::pair<int, int> activeBlocksPerSm{5, 6};
  int64_t maxGridX = 2147483647;
  int64_t maxGridYZ = 65535;
  int64_t maxBlockXY = 1024;
  int64_t maxBlockZ = 64;

  void InitGpuMemoryLimit(const std::string &device_type);
  void InitGpuComputeCapability(const std::string &device_type);
};
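// Usage sketch (illustrative only; "cuda" is a placeholder device type). The Meyers
// singleton is constructed with the device_type of the first call; later calls return
// the same instance regardless of the argument.
//   auto &info = GpuInfo::GetInstance("cuda");
//   int64_t shared_limit = info.GetMemoryLimitInScope(MEM_SCOPE_SHARED);
//   std::vector<int64_t> max_grids = info.GetMaxGrids();  // {x, y, z} grid limits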

struct MappingInfo {
  uint32_t total_alloc_grid{1};
  uint32_t total_alloc_block{1};
  uint32_t curr_grid[3]{1, 1, 1};
  uint32_t curr_block[3]{1, 1, 1};
  int64_t proposal_grid{1};
  int64_t proposal_block{1};
  std::vector<size_t> solve_order_id;
  uint32_t GetMapLimit(size_t id, const std::string &device_target);
  void UpdateCurrMapSize(size_t id, uint32_t map_size);
  std::string ToString() {
    std::string res;
    res += "total_alloc_grid: [";
    for (auto g : curr_grid) {
      res += std::to_string(g) + ", ";
    }
    res += "] = " + std::to_string(total_alloc_grid) + "; ";
    res += "total_alloc_block: [";
    for (auto b : curr_block) {
      res += std::to_string(b) + ", ";
    }
    res += "] = " + std::to_string(total_alloc_block) + "\n";
    return res;
  }
};
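// For a default-constructed MappingInfo, ToString() yields:
//   "total_alloc_grid: [1, 1, 1, ] = 1; total_alloc_block: [1, 1, 1, ] = 1\n"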

struct RuntimeVar {
 public:
  // Init from json
  int64_t prime;                   // prime acts as a unique id for this var to speed up lowering in the pipeline
  int argIndex{-1};                // index in the func argument
  std::string mapping{"Default"};  // used for GPU mapping, can be chosen from [Grid, Block, Seq]
  std::string mapDim{""};          // used for GPU mapping, can be chosen from [x, y, z]
  std::string expr{""};            // used to solve dynamic tiling
  int mark{999};                   // used to solve dynamic tiling for specific algorithms, default is unknown{999}

  // Init in resize
  int64_t upper_bound{-1};
  int outer_map_id{-1};
  int curr_map_id{-1};
  int64_t runtime_size{-1};

  std::string ArgIndexKey() { return "argIndex"; }
  std::string ExprKey() { return "expr"; }
  std::string MarkKey() { return "mark"; }
  std::string MapDimKey() { return "mapDim"; }
  std::string MappingKey() { return "mapping"; }
  std::string PrimeKey() { return "prime"; }
  std::string ToString() {
    std::string res = "[RuntimeVar " + std::to_string(prime) + "]";
    res += "  -> " + mapping + "." + mapDim + " at " + std::to_string(argIndex) + " input\n";
    res += "  -> expr: " + expr + "\n";
    res += "  -> mark: " + std::to_string(mark) + "\n";
    res += "  -> upper bound " + std::to_string(upper_bound) + "; curr_map_id " + std::to_string(curr_map_id) +
           "; outer_map_id " + std::to_string(outer_map_id) + "\n";
    res += "  -> runtime_size " + std::to_string(runtime_size) + "\n";
    return res;
  }

  static std::unordered_map<std::string, int> mark_table_;
};
using RuntimeVarPtr = std::shared_ptr<RuntimeVar>;
using RuntimeVarsMap = std::map<int, RuntimeVarPtr>;
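// The *Key() accessors above return the JSON field names used to populate the
// "Init from json" members (argIndex, expr, mark, mapDim, mapping, prime).
// RuntimeVarsMap keeps the variables ordered by an integer key; the exact key
// convention (e.g. argument index) is defined by the parsing code in the .cc file.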

enum AkgKernelImplType {
  DEFAULT = 0,
  STATIC_TILE,
  DYNYAMIC_TILE,
};

using LocVector = std::vector<std::pair<size_t, size_t>>;
class AkgKernelImplInfo {
 public:
  AkgKernelImplInfo(const std::string &kernel_name, nlohmann::json json);
  virtual ~AkgKernelImplInfo() = default;

  virtual void Init() {}
  virtual void Resize() {}

  // update each time
  std::vector<uint32_t> thread_info_;
  std::unordered_map<size_t, LocVector> unknown_map_loc_;
  std::unordered_set<size_t> solved_map_loc_;
  std::vector<int64_t> arg_size_vec_;
  MappingInfo curr_mapping_info_;
  std::vector<std::vector<int64_t>> shape_list_;
  std::vector<std::vector<int64_t>> device_shape_list_;
  int64_t problem_size_ = 1;

  // no change
  std::string kernel_name_;
  nlohmann::json parsed_js_;
  AkgKernelImplType kernel_type_{AkgKernelImplType::DEFAULT};
  std::vector<uint32_t> init_mapping_;
  MappingInfo init_mapping_info_;
  RuntimeVarsMap runtime_vars_;
  std::vector<std::pair<uint32_t, RuntimeVarPtr>> sorted_runtime_vars_;
  std::unordered_map<uint32_t, std::pair<size_t, size_t>> local_upper_bound_;
  size_t max_shape_rank_ = 0;
  std::string device_target_;
  std::unordered_map<std::pair<size_t, size_t>, LocVector, PairHash> device_host_shape_loc_;
  std::unordered_map<std::string, std::pair<size_t, size_t>> host_loc_map_;
  std::unordered_map<int64_t, std::string> product_var_;   // map: prime -> symbol like `s0`
  std::unordered_map<std::string, int> axis_length_left_;  // update map: symbol axis -> total length of one axis
  std::unordered_map<int, std::vector<int>> related_values_;
  std::unordered_map<size_t, std::string> unknown_map_symbol_;
  std::unordered_map<uint32_t, std::string> local_upper_bound_symbol_;
  std::vector<std::string> map_arg_list_;

  int static_reduce_length_{1};
  std::vector<int> runtime_threads_order_;
  std::vector<std::pair<int, int>> template_tiling_order_;  // pair includes prime & mark
  std::unordered_map<int, int> prime_to_mapping_idx_;
  std::unordered_map<int, int> prime_to_mapping_dividend_;
  bool enable_atomic_{false};
  int dyn_algorithm_{0};
  static std::unordered_map<std::string, int> algo_to_int_;

  void preprocessDynamicReduceTiling();
  LocVector GetHostLocationVec(std::string symbol_expr, const size_t pure_num_flag);
  void InitJsonShapeInformation();
  void InitJsonMappingInformation();
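  // Expects parsed_js_[key] to be a two-element array of the form [string, number],
  // e.g. a value shaped like ["s0", 1024] (illustrative) passes the check.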
  bool CheckJsonValueFormat(const std::string &key) {
    constexpr size_t valid_value_size = 2;
    auto value = parsed_js_[key];
    return (value.is_array() && value.size() == valid_value_size && value[0].is_string() && value[1].is_number());
  }
  void GetDeviceArgSizeVec();
  void InitBeforeMapping();
  int64_t GetFoldedShape(const LocVector &host_loc_vec);
  void UpdateDynamicShapeMappingInfo();
};
using AkgKernelImplInfoPtr = std::shared_ptr<AkgKernelImplInfo>;

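// Lifecycle sketch (illustrative; kernel_name and parsed_json are placeholders): an
// implementation object is built from the kernel name and its parsed JSON, Init() is
// called once up front, and Resize() re-solves the tiling/mapping whenever the input
// shapes change:
//   AkgKernelImplInfoPtr impl = std::make_shared<DynamicTileImpl>(kernel_name, parsed_json);
//   impl->Init();
//   impl->Resize();  // on every shape update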
class DynamicTileImpl : public AkgKernelImplInfo {
 public:
  DynamicTileImpl(const std::string &kernel_name, nlohmann::json json) : AkgKernelImplInfo(kernel_name, json) {
    this->kernel_type_ = AkgKernelImplType::DYNYAMIC_TILE;
  }
  virtual ~DynamicTileImpl() = default;
  void Init() override {
    this->InitBeforeMapping();
    this->UpdateRuntimeVarUpperBound();
  }
  void Resize() override {
    this->InitBeforeMapping();
    this->GetDeviceArgSizeVec();
    this->UpdateDynamicShapeTilingInfo();
    this->UpdateDynamicShapeMappingInfo();
  }

 private:
  void UpdateDynamicShapeTilingInfo();
  void UpdateRuntimeVarUpperBound();
  void SolveDynamicReduction();
  void SolveDynamicTiling(size_t curr_id);
  int64_t TileSizeOpt(const RuntimeVarPtr &var, int64_t dyn_tile_size);
  void UpdateMapping(int curr_id, int64_t map_size, int64_t prime);
};

class StaticTileImpl : public AkgKernelImplInfo {
 public:
  StaticTileImpl(const std::string &kernel_name, nlohmann::json json) : AkgKernelImplInfo(kernel_name, json) {
    this->kernel_name_ = kernel_name;
    this->parsed_js_ = json;
    this->kernel_type_ = AkgKernelImplType::STATIC_TILE;
  }
  virtual ~StaticTileImpl() = default;
  void Init() override {
    this->InitBeforeMapping();
    this->UpdateDynamicShapeMappingInfo();
  }
  void Resize() override {
    this->InitBeforeMapping();
    this->GetDeviceArgSizeVec();
    this->UpdateDynamicShapeMappingInfo();
  }
};
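// Unlike DynamicTileImpl, StaticTileImpl skips the runtime tiling solve: its Init()
// and Resize() overrides only call InitBeforeMapping(), GetDeviceArgSizeVec() (on
// Resize), and UpdateDynamicShapeMappingInfo().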
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_GPU_DYNAMIC_AKG_GPU_UTILS_H_