• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/kernel_compiler/oplib/oplib.h"
18 #include <memory>
19 #include <map>
20 #include <fstream>
21 #include "utils/log_adapter.h"
22 #include "utils/overload.h"
23 #include "utils/ms_context.h"
24 
25 namespace mindspore {
26 namespace kernel {
27 constexpr auto kImplyType = "imply_type";
28 constexpr auto kOpName = "op_name";
29 constexpr auto kFusionType = "fusion_type";
30 constexpr auto kAsyncFlag = "async_flag";
31 constexpr auto kBinfileName = "binfile_name";
32 constexpr auto kComputeCost = "compute_cost";
33 constexpr auto kKernelName = "kernel_name";
34 constexpr auto kPartialFlag = "partial_flag";
35 constexpr auto kReshapeType = "reshape_type";
36 constexpr auto kValueDepend = "value_depend";
37 constexpr auto kOpPattern = "op_pattern";
38 constexpr auto kIsDynamicFormat = "is_dynamic_format";
39 constexpr auto kDynamicFormat = "dynamicFormat";
40 constexpr auto kFormatAgnostic = "formatAgnostic";
41 constexpr auto kNeedCheckSupported = "need_check_supported";
42 constexpr auto kBroadcast = "broadcast";
43 constexpr auto kReduce = "reduce";
44 constexpr auto kDynamicShape = "dynamic_shape";
45 constexpr auto kDynamicCompileStatic = "dynamic_compile_static";
46 constexpr auto kDtypeFormat = "dtype_format";
47 constexpr auto kAttr = "attr";
48 constexpr auto kIputs = "inputs";
49 constexpr auto kOutputs = "outputs";
50 constexpr auto kAiCPU = "AiCPU";
51 constexpr auto kAiCore = "AiCore";
52 constexpr auto kCUDA = "CUDA";
53 constexpr auto kTbe = "TBE";
54 constexpr auto kAkg = "AKG";
55 constexpr auto kCpu = "CPU";
56 constexpr auto kName = "name";
57 constexpr auto kParamType = "param_type";
58 constexpr auto kDtype = "dtype";
59 constexpr auto kType = "type";
60 constexpr auto kValue = "value";
61 constexpr auto kDefaultValue = "default_value";
62 constexpr auto kIndex = "index";
63 constexpr auto kFormat = "format";
64 constexpr auto kNeedCompile = "need_compile";
65 constexpr auto kShape = "shape";
66 constexpr auto kProcessor = "processor";
67 std::multimap<std::string, std::shared_ptr<OpInfo>> OpLib::op_info_;
68 
ImplTypeToStr(OpImplyType impl_type)69 static std::string ImplTypeToStr(OpImplyType impl_type) {
70   switch (impl_type) {
71     case kTBE:
72       return kTbe;
73     case kAKG:
74       return kAkg;
75     case kAICPU:
76       return kAiCPU;
77     default:
78       return "unknown";
79   }
80 }
RegOp(const std::string & json_string,const std::string & impl_path)81 bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path) {
82   bool ret = false;
83   try {
84     auto op_json = nlohmann::json::parse(json_string);
85     std::string imply_type_string = op_json.at(kImplyType);
86     std::string op_name = op_json.at(kOpName);
87     if (imply_type_string == kTbe) {
88       OpImplyType imply_type = kTBE;
89       ret = DecodeOpInfo(op_json, imply_type, impl_path);
90     } else if (imply_type_string == kAkg) {
91       OpImplyType imply_type = kAKG;
92       ret = DecodeOpInfo(op_json, imply_type, impl_path);
93     } else if (imply_type_string == kAiCPU) {
94       OpImplyType imply_type = kAICPU;
95       ret = DecodeOpInfo(op_json, imply_type, impl_path);
96     } else if (imply_type_string == kCpu) {
97       OpImplyType imply_type = kCPU;
98       ret = DecodeOpInfo(op_json, imply_type, impl_path);
99     } else {
100       MS_LOG(ERROR) << "Not support imply_type";
101     }
102     if (!ret) {
103       MS_LOG(ERROR) << "RegOp failed: op_name: " << op_name << " imply_type " << imply_type_string;
104     }
105   } catch (const std::exception &e) {
106     MS_LOG(ERROR) << "get op json elements failed: " << e.what();
107   }
108   return ret;
109 }
110 
DecodeTBESpecificInfo(const nlohmann::json & obj,const std::shared_ptr<OpInfo> & op_info)111 void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) {
112   const std::map<std::string, kernel::OpPattern> kOpPatternMap = {
113     {kFormatAgnostic, kFormatAgnosticPattern}, {kBroadcast, kBroadcastPattern}, {kReduce, kReducePattern}};
114   MS_EXCEPTION_IF_NULL(op_info);
115   op_info->set_async_flag(obj.at(kAsyncFlag));
116   op_info->set_binfile_name(obj.at(kBinfileName));
117   op_info->set_compute_cost(obj.at(kComputeCost));
118   op_info->set_kernel_name(obj.at(kKernelName));
119   op_info->set_partial_flag(obj.at(kPartialFlag));
120   op_info->set_need_check_supported(obj.at(kNeedCheckSupported));
121 
122   if (obj.find(kDynamicShape) != obj.end()) {
123     op_info->set_dynamic_shape(obj.at(kDynamicShape));
124   }
125 
126   if (obj.find(kDynamicCompileStatic) != obj.end()) {
127     op_info->set_dynamic_compile_static_(obj.at(kDynamicCompileStatic));
128   }
129 
130   if (obj.find(kIsDynamicFormat) != obj.end()) {
131     op_info->set_is_dynamic_format(obj.at(kIsDynamicFormat));
132   }
133 
134   if (obj.find(kOpPattern) != obj.end()) {
135     std::string op_pattern = obj.at(kOpPattern);
136     auto find_iter = kOpPatternMap.find(op_pattern);
137     if (find_iter == kOpPatternMap.end()) {
138       if (!op_pattern.empty()) {
139         MS_LOG(WARNING) << "Op pattern set value error: " << op_pattern;
140       }
141       op_info->set_op_pattern(kCommonPattern);
142     } else {
143       op_info->set_op_pattern(find_iter->second);
144     }
145   }
146 }
147 
DecodeAKGSpecificInfo(const nlohmann::json & obj,const std::shared_ptr<OpInfo> & op_info)148 void OpLib::DecodeAKGSpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) {
149   MS_EXCEPTION_IF_NULL(op_info);
150   op_info->set_processor(obj.at(kProcessor));
151 }
152 
RegOpFromLocalInfo()153 bool OpLib::RegOpFromLocalInfo() {
154   static bool has_load = false;
155   if (has_load) {
156     return true;
157   }
158   MS_LOG(INFO) << "Start";
159   has_load = true;
160   std::string dir = common::GetEnv("MINDSPORE_OP_INFO_PATH");
161   if (dir.empty()) {
162     MS_LOG(INFO) << "MindSpore op info path does not been set. use op info from python pass.";
163     return true;
164   }
165   char real_path[PATH_MAX] = {0};
166   if (dir.size() >= PATH_MAX) {
167     MS_LOG(ERROR) << "Op info path is invalid: " << dir;
168     return false;
169   }
170 #if defined(_WIN32) || defined(_WIN64)
171   if (_fullpath(real_path, common::SafeCStr(dir), PATH_MAX) == nullptr) {
172     MS_LOG(ERROR) << "Op info path is invalid: " << dir;
173     return false;
174   }
175 #else
176   if (realpath(common::SafeCStr(dir), real_path) == nullptr) {
177     MS_LOG(ERROR) << "Op info path is invalid: " << dir;
178     return false;
179   }
180   if (strlen(real_path) >= PATH_MAX) {
181     MS_LOG(ERROR) << "Op info path is invalid, the absolute path length is greater than PATH_MAX";
182     return false;
183   }
184 #endif
185   MS_LOG(INFO) << "Start to read op info from local file.";
186   std::ifstream file(real_path);
187   if (!file.is_open()) {
188     MS_LOG(ERROR) << "Find op info file failed.";
189     return false;
190   }
191   std::string line;
192   while (getline(file, line)) {
193     if (!line.empty()) {
194       (void)OpLib::RegOp(line, "");
195     }
196   }
197   file.close();
198   MS_LOG(INFO) << "End";
199   return true;
200 }
201 
DecodeOpInfo(const nlohmann::json & obj,const mindspore::kernel::OpImplyType imply_type,const std::string & impl_path)202 bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpImplyType imply_type,
203                          const std::string &impl_path) {
204   std::shared_ptr<OpInfo> op_info = std::make_shared<OpInfo>();
205   MS_EXCEPTION_IF_NULL(op_info);
206   op_info->set_op_name(obj.at(kOpName));
207   op_info->set_impl_path(impl_path);
208   op_info->set_imply_type(imply_type);
209   op_info->set_fusion_type(obj.at(kFusionType));
210   if (imply_type == kTBE) {
211     DecodeTBESpecificInfo(obj, op_info);
212   } else if (imply_type == kAKG) {
213     DecodeAKGSpecificInfo(obj, op_info);
214   }
215   auto attrs = obj.at(kAttr);
216   for (const auto &attr : attrs) {
217     if (!DecodeAttr(attr, imply_type, op_info)) {
218       MS_LOG(ERROR) << "DecodeAttr Failed";
219       return false;
220     }
221   }
222   nlohmann::json dtype_format;
223   if (obj.find(kDtypeFormat) != obj.end()) {
224     dtype_format = obj.at(kDtypeFormat);
225   }
226   auto inputs = obj.at(kIputs);
227   for (const auto &input : inputs) {
228     if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) {
229       MS_LOG(ERROR) << "DecodeInputOutput Failed";
230       return false;
231     }
232   }
233   auto outputs = obj.at(kOutputs);
234   for (const auto &output : outputs) {
235     if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) {
236       MS_LOG(ERROR) << "DecodeInputOutput Failed";
237       return false;
238     }
239   }
240   if (CheckRepetition(op_info)) {
241     MS_LOG(WARNING) << "This op info has been already registered. op name: " << op_info->op_name()
242                     << ", impl type: " << ImplTypeToStr(op_info->imply_type())
243                     << ", impl path: " << op_info->impl_path();
244     return true;
245   }
246   if (!GetRefInfo(op_info)) {
247     MS_LOG(ERROR) << "GetRefInfo Failed";
248     return false;
249   }
250   op_info_.emplace(op_info->op_name(), op_info);
251   return true;
252 }
253 
DecodeAttr(const nlohmann::json & obj,const OpImplyType imply_type,const std::shared_ptr<OpInfo> & op_info)254 bool OpLib::DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type,
255                        const std::shared_ptr<OpInfo> &op_info) {
256   MS_EXCEPTION_IF_NULL(op_info);
257   bool ret = true;
258   try {
259     std::shared_ptr<OpAttr> op_attr = std::make_shared<OpAttr>();
260     MS_EXCEPTION_IF_NULL(op_attr);
261     op_attr->set_name(obj.at(kName));
262     if (imply_type != kAICPU) {
263       op_attr->set_param_type(obj.at(kParamType));
264     }
265     op_attr->set_type(obj.at(kType));
266     if (imply_type == kTBE) {
267       op_attr->set_value(obj.at(kValue));
268     }
269     if (obj.find(kDefaultValue) != obj.end()) {
270       op_attr->set_default_value(obj.at(kDefaultValue));
271     }
272     op_info->add_attrs_ptr(op_attr);
273   } catch (const std::exception &e) {
274     MS_LOG(ERROR) << "DecodeAttr failed:" << e.what();
275     ret = false;
276   }
277   return ret;
278 }
279 
DecodeDtypeFormat(const nlohmann::json & dtype_format,const std::shared_ptr<OpIOInfo> & op_io,size_t index)280 bool OpLib::DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr<OpIOInfo> &op_io,
281                               size_t index) {
282   MS_EXCEPTION_IF_NULL(op_io);
283   bool ret = true;
284   try {
285     std::vector<std::string> dtype;
286     std::vector<std::string> format;
287     for (const auto &it : dtype_format) {
288       dtype.emplace_back(it[index][0]);
289       format.emplace_back(it[index][1]);
290     }
291     op_io->set_dtypes(dtype);
292     op_io->set_formats(format);
293   } catch (const std::exception &e) {
294     MS_LOG(ERROR) << "DecodeDtypeFormat failed" << e.what();
295     ret = false;
296   }
297   return ret;
298 }
299 
DecodeInputOutput(const nlohmann::json & obj,const OpImplyType imply_type,const OpIOType io_type,const std::shared_ptr<OpInfo> & op_info,const nlohmann::json & dtype_format)300 bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type,
301                               const std::shared_ptr<OpInfo> &op_info, const nlohmann::json &dtype_format) {
302   MS_EXCEPTION_IF_NULL(op_info);
303   bool ret = true;
304   try {
305     std::shared_ptr<OpIOInfo> op_io = std::make_shared<OpIOInfo>();
306     MS_EXCEPTION_IF_NULL(op_io);
307     op_io->set_index(obj.at(kIndex));
308     op_io->set_name(obj.at(kName));
309     if (!dtype_format.empty()) {
310       if (!DecodeDtypeFormat(dtype_format, op_io, op_info->inputs_ptr().size() + op_info->outputs_ptr().size())) {
311         MS_LOG(ERROR) << "Decode dtype format failed";
312         return false;
313       }
314     } else {
315       op_io->set_dtypes(obj.at(kDtype));
316       op_io->set_formats(obj.at(kFormat));
317     }
318     if (op_io->dtypes().size() != op_io->formats().size()) {
319       MS_LOG(ERROR) << "op " << op_io->name() << " dtype size: " << op_io->dtypes()
320                     << " is not equal to format size: " << op_io->formats();
321       return false;
322     }
323     if (obj.find(kParamType) != obj.end()) {
324       op_io->set_param_type(obj.at(kParamType));
325     }
326     if (imply_type == kTBE) {
327       if (obj.find(kNeedCompile) != obj.end()) {
328         op_io->set_need_compile(obj.at(kNeedCompile));
329       }
330       if (obj.find(kShape) != obj.end()) {
331         op_io->set_shape(obj.at(kShape));
332       }
333       if (obj.find(kReshapeType) != obj.end()) {
334         op_io->set_reshape_type(obj.at(kReshapeType));
335       }
336       if (obj.find(kValueDepend) != obj.end()) {
337         op_io->set_value_depend(obj.at(kValueDepend));
338       }
339     }
340 
341     if (io_type == kInput) {
342       op_info->add_inputs_ptr(op_io);
343     } else if (io_type == kOutput) {
344       op_info->add_outputs_ptr(op_io);
345     }
346   } catch (const std::exception &e) {
347     MS_LOG(ERROR) << "DecodeInputOutput failed" << e.what();
348     ret = false;
349   }
350   return ret;
351 }
352 
FindOp(const std::string & op_name,OpImplyType imply_type,bool is_dynamic_shape)353 std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType imply_type, bool is_dynamic_shape) {
354   if (!OpLib::RegOpFromLocalInfo()) {
355     MS_LOG(INFO) << "Warning reg local op info failed.";
356   }
357   auto context = MsContext::GetInstance();
358   MS_EXCEPTION_IF_NULL(context);
359   bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
360   if (is_gpu && (imply_type == kTBE || imply_type == kAICPU)) {
361     MS_LOG(ERROR) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type)
362                   << ", current op num: " << op_info_.size();
363     return nullptr;
364   }
365   std::string target_processor = is_gpu ? kCUDA : kAiCore;
366   for (auto [iter, end] = op_info_.equal_range(op_name); iter != end; ++iter) {
367     auto &op_info = (*iter).second;
368     MS_EXCEPTION_IF_NULL(op_info);
369     if (op_info->imply_type() != imply_type) {
370       continue;
371     }
372     if (imply_type == kAKG && op_info->processor() != target_processor) {
373       continue;
374     }
375     if (is_dynamic_shape && !op_info->dynamic_shape()) {
376       continue;
377     }
378     return op_info;
379   }
380   MS_LOG(INFO) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type)
381                << ", current op num: " << op_info_.size() << " is_dynamic_shape:" << is_dynamic_shape;
382   return nullptr;
383 }
384 
GetRefInfo(const std::shared_ptr<OpInfo> & op_info)385 bool OpLib::GetRefInfo(const std::shared_ptr<OpInfo> &op_info) {
386   MS_EXCEPTION_IF_NULL(op_info);
387   const auto &output_infos = op_info->outputs_ptr();
388   const auto &input_infos = op_info->inputs_ptr();
389   for (size_t out_index = 0; out_index < output_infos.size(); out_index++) {
390     MS_EXCEPTION_IF_NULL(output_infos[out_index]);
391     const auto &out_name = output_infos[out_index]->name();
392     for (size_t in_index = 0; in_index < input_infos.size(); in_index++) {
393       MS_EXCEPTION_IF_NULL(input_infos[in_index]);
394       const auto &in_name = input_infos[in_index]->name();
395       if (out_name == in_name) {
396         if (op_info->has_ref_index(out_index)) {
397           MS_LOG(ERROR) << "The out_index " << out_index << " is already in ref_info";
398           return false;
399         }
400         op_info->add_ref_pair(out_index, in_index);
401       }
402     }
403   }
404   return true;
405 }
406 
CheckRepetition(const std::shared_ptr<OpInfo> & op_info)407 bool OpLib::CheckRepetition(const std::shared_ptr<OpInfo> &op_info) {
408   MS_EXCEPTION_IF_NULL(op_info);
409   for (auto [iter, end] = op_info_.equal_range(op_info->op_name()); iter != end; ++iter) {
410     auto &exist_op_info = (*iter).second;
411     MS_EXCEPTION_IF_NULL(exist_op_info);
412     if (exist_op_info->equals_to(op_info)) {
413       return true;
414     }
415   }
416   return false;
417 }
418 }  // namespace kernel
419 }  // namespace mindspore
420