• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h"
18 
19 #include <map>
20 #include <memory>
21 #include <set>
22 #include <utility>
23 #include "backend/kernel_compiler/common_utils.h"
24 #include "backend/kernel_compiler/oplib/oplib.h"
25 #include "backend/kernel_compiler/tbe/tbe_convert_utils.h"
26 #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
27 #include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
28 #include "backend/kernel_compiler/tbe/ascend_kernel_compile.h"
29 #include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h"
30 #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h"
31 #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h"
32 #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.h"
33 #include "backend/optimizer/common/helper.h"
34 #include "backend/session/anf_runtime_algorithm.h"
35 #include "backend/session/kernel_build_client.h"
36 #include "nlohmann/json.hpp"
37 #include "utils/convert_utils_base.h"
38 #include "utils/json_operation_utils.h"
39 
40 namespace mindspore::kernel {
// Keys read from each input/output entry of the op-select-format JSON (see CreateNewOpInfo).
constexpr auto kName = "name";
constexpr auto kDtype = "dtype";
constexpr auto kFormat = "format";
// Substrings searched for in a JSON entry key to classify it as an input or an output.
constexpr auto kPrefixInput = "input";
constexpr auto kPrefixOutput = "output";
// Values of OpIOInfo::param_type() that GenBuilderItem recognises.
constexpr char kParamTypeDynamic[] = "dynamic";
constexpr char kParamTypeRequre[] = "required";
constexpr char kParamTypeOptional[] = "optional";
TbeMetadataInfo(const CNodePtr & kernel_node,std::vector<std::shared_ptr<KernelBuildInfo>> * kernel_info_list)49 void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
50   auto tbe_selecter = TbeKernelSelect(kernel_node, kernel_info_list);
51   tbe_selecter.TbeMetadataInfoEx();
52 }
53 
TbeKernelSelect(CNodePtr kernel_node,std::vector<std::shared_ptr<KernelBuildInfo>> * kernel_info_list)54 TbeKernelSelect::TbeKernelSelect(CNodePtr kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list)
55     : cnode_ptr_(std::move(kernel_node)), kernel_info_list_(kernel_info_list) {}
56 
TbeMetadataInfoEx()57 void TbeKernelSelect::TbeMetadataInfoEx() {
58   MS_EXCEPTION_IF_NULL(cnode_ptr_);
59   MS_EXCEPTION_IF_NULL(kernel_info_list_);
60   node_name_ = AnfAlgo::GetCNodeName(cnode_ptr_);
61   full_name_ = cnode_ptr_->fullname_with_scope();
62 
63   auto op_info_ptr = tbe::TbeDynamicShapeUtil::FindOp(node_name_, cnode_ptr_);
64   if (!op_info_ptr) {
65     return;
66   }
67   if (!TbePropertyChecker::CheckTbeProperties(cnode_ptr_)) {
68     MS_LOG(INFO) << "Warning: node(" << full_name_ << ") is not supported by tbe ai_core.";
69     return;
70   }
71 
72   if (op_info_ptr->is_dynamic_format()) {
73     GetDynamicFormatPatternKernelInfo(*op_info_ptr);
74   } else {
75     OpPattern pattern = op_info_ptr->op_pattern();
76     if (pattern == kCommonPattern) {
77       GetCommonPatternKernelInfo(*op_info_ptr);
78     } else if (pattern == kFormatAgnosticPattern) {
79       GetAgnosticPatternKernelInfo(*op_info_ptr);
80     } else if (pattern == kBroadcastPattern) {
81       GetBroadcastPatternKernelInfo(*op_info_ptr);
82     } else if (pattern == kReducePattern) {
83       GetReducePatternKernelInfo(*op_info_ptr);
84     } else {
85       MS_LOG(INFO) << "Warning: op pattern is invailed.";
86     }
87   }
88   // check support
89   FilterInVaildKernelInfo(*op_info_ptr);
90 }
91 
GetCommonPatternKernelInfo(const OpInfo & op_info)92 void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
93   auto dyn_input_sizes = GetNodeDynamicInputs();
94   // get real input/output num
95   size_t real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode_ptr_);
96   const auto inputs_info = op_info.inputs_ptr();
97   size_t real_output_tensor_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
98   const auto outputs_info = op_info.outputs_ptr();
99   if (inputs_info.empty() && outputs_info.empty()) {
100     MS_LOG(EXCEPTION) << AnfAlgo::GetCNodeName(cnode_ptr_) << "'s op info input & output is null, please check.";
101   }
102   // create kernel build info from opinfo
103   size_t kernel_build_info_num =
104     inputs_info.empty() ? outputs_info[0]->dtypes().size() : inputs_info[0]->dtypes().size();
105   for (size_t kernel_build_info_index = 0; kernel_build_info_index < kernel_build_info_num; ++kernel_build_info_index) {
106     auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
107     SetTbeBuildCommonInfo(op_info, &builder);
108     std::vector<std::string> inputs_format;
109     std::vector<TypeId> inputs_device_type;
110     std::vector<std::string> inputs_reshape_type;
111     std::vector<std::string> inputs_value_depend;
112     // input
113     if (!GenBuilderItem(true, kernel_build_info_index, real_input_tensor_num, inputs_info, dyn_input_sizes,
114                         &inputs_format, &inputs_device_type, &inputs_reshape_type, &inputs_value_depend)) {
115       break;
116     }
117     builder.SetInputsDeviceType(inputs_device_type);
118     builder.SetInputsFormat(inputs_format);
119     builder.SetInputsReshapeType(inputs_reshape_type);
120     builder.SetInputsValueDepend(inputs_value_depend);
121     // output
122     std::vector<std::string> outputs_format;
123     std::vector<TypeId> outputs_device_type;
124     std::vector<std::string> outputs_reshape_type;
125     std::vector<std::string> outputs_value_depend;
126     if (!GenBuilderItem(false, kernel_build_info_index, real_output_tensor_num, outputs_info, dyn_input_sizes,
127                         &outputs_format, &outputs_device_type, &outputs_reshape_type, &outputs_value_depend)) {
128       break;
129     }
130     builder.SetOutputsDeviceType(outputs_device_type);
131     builder.SetOutputsFormat(outputs_format);
132     builder.SetOutputsReshapeType(outputs_reshape_type);
133     kernel_info_list_->emplace_back(builder.Build());
134   }
135 }
136 
GetDynamicFormatPatternKernelInfo(const OpInfo & op_info)137 void TbeKernelSelect::GetDynamicFormatPatternKernelInfo(const OpInfo &op_info) {
138   OpInfo op_info_new;
139   CreateNewOpInfo(op_info, &op_info_new);
140   GetCommonPatternKernelInfo(op_info_new);
141 }
142 
GetAgnosticPatternKernelInfo(const OpInfo & op_info)143 void TbeKernelSelect::GetAgnosticPatternKernelInfo(const OpInfo &op_info) {
144   if (op_info.inputs_ptr().size() != 1) {
145     MS_LOG(EXCEPTION) << "AgnosticPattern only support one input.";
146   }
147   auto format = AnfAlgo::GetPrevNodeOutputFormat(cnode_ptr_, 0);
148   if (kOpFormatList.find(format) == kOpFormatList.end()) {
149     MS_LOG(INFO) << "Got the unknown format " << format;
150     format = kOpFormat_DEFAULT;
151   }
152   SupportFormat support_format;
153   SupportFormatItem input_item;
154   SupportFormatItem output_item;
155   input_item.assign(op_info.inputs_ptr().size(), format);
156   output_item.assign(op_info.outputs_ptr().size(), format);
157   support_format.input_format.emplace_back(input_item);
158   support_format.output_format.emplace_back(output_item);
159   OpInfo op_info_new;
160   CreateNewOpInfo(op_info, support_format, &op_info_new);
161   GetCommonPatternKernelInfo(op_info_new);
162 }
163 
GetBroadcastPatternKernelInfo(const OpInfo & op_info)164 void TbeKernelSelect::GetBroadcastPatternKernelInfo(const OpInfo &op_info) {
165   auto broadcast_selecter = TbeKernelBroadCastSelecter(cnode_ptr_);
166   SupportFormat support_format;
167   broadcast_selecter.GetShapeInfo(&support_format);
168   (void)broadcast_selecter.IsBroadCastSupport5HD(&support_format);
169   (void)broadcast_selecter.IsBroadCastSupportFracZ(&support_format);
170   (void)broadcast_selecter.IsBroadCastSupportC1HWNCoC0(&support_format);
171   (void)broadcast_selecter.IsBroadCastSupportFracNZ(&support_format);
172   (void)broadcast_selecter.IsBroadCastSupportNDC1HWC0(&support_format);
173   OpInfo op_info_new;
174   CreateNewOpInfo(op_info, support_format, &op_info_new);
175   GetCommonPatternKernelInfo(op_info_new);
176 }
177 
GetReducePatternKernelInfo(const OpInfo & op_info)178 void TbeKernelSelect::GetReducePatternKernelInfo(const OpInfo &op_info) {
179   auto reduce_selecter = TbeKernelReduceSelecter(cnode_ptr_);
180   SupportFormat support_format;
181   reduce_selecter.GetShapeInfo(&support_format);
182   (void)reduce_selecter.IsReduceSupport5HD(&support_format);
183   (void)reduce_selecter.IsReduceSupportFracZ(&support_format);
184   (void)reduce_selecter.IsReduceSupportC1HWNCoC0(&support_format);
185   (void)reduce_selecter.IsReduceSupportFracNZ(&support_format);
186   OpInfo op_info_new;
187   CreateNewOpInfo(op_info, support_format, &op_info_new);
188   GetCommonPatternKernelInfo(op_info_new);
189 }
190 
FilterInVaildKernelInfo(const OpInfo & op_info)191 void TbeKernelSelect::FilterInVaildKernelInfo(const OpInfo &op_info) {
192   if (kernel_info_list_->empty()) {
193     MS_LOG(INFO) << "Warning: get kernel build info failed. Op name: " << full_name_;
194     return;
195   }
196   std::vector<std::shared_ptr<KernelBuildInfo>> kernel_info_list;
197   auto dynamic_inputs = GetNodeDynamicInputs();
198   for (auto iter = kernel_info_list_->begin(); iter != kernel_info_list_->end(); ++iter) {
199     if (!FilterInVaildShape(iter, !dynamic_inputs.empty())) {
200       continue;
201     }
202     if (op_info.need_check_supported()) {
203       if (!TbeCheckSupported(iter)) {
204         continue;
205       }
206     }
207     kernel_info_list.emplace_back(*iter);
208   }
209   if (kernel_info_list.empty()) {
210     MS_LOG(WARNING) << "Tbe kernel info list is empty, all valid kernel info was filtered out. "
211                        "Check the input shape, attrs or other value of node : "
212                     << full_name_;
213   }
214   (*kernel_info_list_) = kernel_info_list;
215 }
216 
FilterInVaildShape(const KernelBuildInfoIter & kernel_build_info_iter,bool is_dynamic_input)217 bool TbeKernelSelect::FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter, bool is_dynamic_input) {
218   MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
219   const auto &kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats();
220   // dynamic input just need to check first input, because other inputs copy from 1th input;
221   auto iter_num =
222     is_dynamic_input && !kernel_build_info_inputs_format.empty() ? 1 : kernel_build_info_inputs_format.size();
223   for (size_t i = 0; i < iter_num; ++i) {
224     auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
225     const auto &format = kernel_build_info_inputs_format.at(i);
226     if (!IsShapeMatchFormat(shape, format)) {
227       return false;
228     }
229   }
230   const auto &kernel_build_info_outputs_format = (*kernel_build_info_iter)->GetAllOutputFormats();
231   for (size_t j = 0; j < kernel_build_info_outputs_format.size(); ++j) {
232     auto shape = AnfAlgo::GetOutputInferShape(cnode_ptr_, j);
233     const auto &format = kernel_build_info_outputs_format[j];
234     if (!IsShapeMatchFormat(shape, format)) {
235       return false;
236     }
237   }
238   return true;
239 }
240 
IsShapeMatchFormat(const std::vector<size_t> & shape,const std::string & format)241 bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
242   if (format == kOpFormat_DEFAULT) {
243     return true;
244   }
245   static const std::set<std::string> kServerNotSupportFormat = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
246   // if format is default, it remarkes support all format
247   if (kOpFormatList.find(format) == kOpFormatList.end()) {
248     MS_LOG(EXCEPTION) << "Got the unknown format " << format;
249   }
250   // server not support format with C04 suffix
251   if (std::find(kServerNotSupportFormat.begin(), kServerNotSupportFormat.end(), format) !=
252       kServerNotSupportFormat.end()) {
253     MS_LOG(INFO) << "Warning: Server not support format with C04 suffix.";
254     return false;
255   }
256   if (format == kOpFormat_FRAC_NZ && shape.size() > kShape2dDims) {
257     return true;
258   }
259   // not support format:
260   // 1 3d formats with shape size > 5
261   if (k3DFormatSet.find(format) != k3DFormatSet.end() && shape.size() > kShape5dDims) {
262     return false;
263   }
264   return true;
265 }
266 
TbeCheckSupported(const KernelBuildInfoIter & kernel_build_info_iter)267 bool TbeKernelSelect::TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter) {
268   MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
269   // replace kernel_info with current kernel info
270   auto kernel_build_info_tmp = AnfAlgo::GetSelectKernelBuildInfo(cnode_ptr_);
271   AnfAlgo::SetSelectKernelBuildInfo(*kernel_build_info_iter, cnode_ptr_.get());
272   std::string old_build = common::GetEnv("MS_OLD_BUILD_PROCESS");
273   bool ret = true;
274   if (!old_build.empty()) {
275     nlohmann::json kernel_json;
276     TbeKernelJsonCreator creator(CHECK_SUPPORTED);
277     ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json);
278     if (!ret) {
279       MS_LOG(EXCEPTION) << "Gen tbe single kernel json for check support failed.";
280     }
281     ret = AscendKernelBuildClient::Instance().CheckSupported(kernel_json.dump());
282   } else {
283     auto &build_manager = kernel::ascend::AscendKernelCompileManager::GetInstance();
284     if (!build_manager.AscendOpCheckSupported(cnode_ptr_)) {
285       ret = false;
286     }
287   }
288   AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_tmp, cnode_ptr_.get());
289   return ret;
290 }
291 
SetTbeBuildCommonInfo(const mindspore::kernel::OpInfo & op_info,mindspore::kernel::KernelBuildInfo::KernelBuildInfoBuilder * builder)292 void TbeKernelSelect::SetTbeBuildCommonInfo(const mindspore::kernel::OpInfo &op_info,
293                                             mindspore::kernel::KernelBuildInfo::KernelBuildInfoBuilder *builder) {
294   MS_EXCEPTION_IF_NULL(builder);
295   builder->SetProcessor(AICORE);
296   std::string fusion_name = op_info.fusion_type();
297   auto fusion_type = kernel::GetFusionTypeByName(fusion_name);
298   if (fusion_type != UNKNOWN_FUSION_TYPE) {
299     builder->SetFusionType(fusion_type);
300   }
301   builder->SetOpPattern(op_info.op_pattern());
302   builder->SetKernelType(TBE_KERNEL);
303 }
304 
GetNodeDynamicInputs()305 std::vector<int64_t> TbeKernelSelect::GetNodeDynamicInputs() {
306   // get dynamic inputs
307   auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_);
308   MS_EXCEPTION_IF_NULL(primitive);
309   std::vector<int64_t> dyn_input_sizes;
310   if (primitive->HasAttr(kAttrDynInputSizes)) {
311     dyn_input_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAttrDynInputSizes));
312   }
313   return dyn_input_sizes;
314 }
315 
GenBuilderItem(bool is_input,size_t kernel_build_info_index,size_t real_io_tensor_num,const std::vector<std::shared_ptr<OpIOInfo>> & ios_info,const std::vector<int64_t> & dyn_input_sizes,std::vector<std::string> * formats,std::vector<TypeId> * device_types,std::vector<std::string> * reshape_types,std::vector<std::string> * value_depends)316 bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num,
317                                      const std::vector<std::shared_ptr<OpIOInfo>> &ios_info,
318                                      const std::vector<int64_t> &dyn_input_sizes, std::vector<std::string> *formats,
319                                      std::vector<TypeId> *device_types, std::vector<std::string> *reshape_types,
320                                      std::vector<std::string> *value_depends) {
321   MS_EXCEPTION_IF_NULL(formats);
322   MS_EXCEPTION_IF_NULL(device_types);
323   MS_EXCEPTION_IF_NULL(reshape_types);
324   MS_EXCEPTION_IF_NULL(value_depends);
325   size_t dynamic_input_index = 0;
326   size_t real_io_tensor_index = 0;
327   size_t io_info_index = 0;
328   size_t io_info_num = ios_info.size();
329   for (; io_info_index < io_info_num && real_io_tensor_index < real_io_tensor_num; io_info_index++) {
330     std::shared_ptr<OpIOInfo> io_info_item = ios_info[io_info_index];
331     const auto &kernel_build_info_dtype = io_info_item->dtypes()[kernel_build_info_index];
332     std::string kernel_build_info_format;
333     if (!io_info_item->formats().empty()) {
334       kernel_build_info_format = io_info_item->formats()[kernel_build_info_index];
335     }
336     const std::string &io_param_type = io_info_item->param_type();
337     auto reshape_type = io_info_item->reshape_type();
338     auto value_depend = io_info_item->value_depend();
339     if (io_param_type == kParamTypeDynamic) {
340       // dynamic io
341       if (is_input) {
342         if (dynamic_input_index >= dyn_input_sizes.size()) {
343           MS_LOG(EXCEPTION) << "dyn_input_sizes attr set error, dynamic_input_index: " << dynamic_input_index
344                             << ", dyn_input_sizes size: " << dyn_input_sizes.size();
345         }
346         int64_t dynamic_input_size = dyn_input_sizes[dynamic_input_index];
347         for (int64_t i = 0; i < dynamic_input_size; ++i) {
348           device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
349           formats->emplace_back(kernel_build_info_format);
350           reshape_types->emplace_back(reshape_type);
351           value_depends->emplace_back(value_depend);
352         }
353         dynamic_input_index++;
354         real_io_tensor_index = SizetAddWithOverflowCheck(real_io_tensor_index, LongToSize(dynamic_input_size));
355       } else {
356         if (ios_info.size() != 1) {
357           MS_LOG(EXCEPTION) << "if output is dynamic, so output must has one output.";
358         }
359         for (size_t i = 0; i < real_io_tensor_num; ++i) {
360           device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
361           formats->emplace_back(kernel_build_info_format);
362           reshape_types->emplace_back(reshape_type);
363           value_depends->emplace_back(value_depend);
364         }
365         real_io_tensor_index = SizetAddWithOverflowCheck(real_io_tensor_index, real_io_tensor_num);
366       }
367     } else if (io_param_type == kParamTypeRequre || io_param_type == kParamTypeOptional) {
368       // require or optional io
369       device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
370       formats->emplace_back(kernel_build_info_format);
371       reshape_types->emplace_back(reshape_type);
372       value_depends->emplace_back(value_depend);
373       real_io_tensor_index++;
374     } else {
375       MS_LOG(EXCEPTION) << "op info's param type is not match: " << io_param_type;
376     }
377   }
378 
379   if (real_io_tensor_index != real_io_tensor_num) {
380     std::string io_type = is_input ? "inputs " : "outputs";
381     MS_LOG(INFO) << node_name_ << "'s " << io_type << "op io info num: " << io_info_num
382                  << ", real io tensor num:" << real_io_tensor_num << "real_io_tensor_index(" << real_io_tensor_index
383                  << ") != real_io_tensor_num(" << real_io_tensor_num << ")";
384     return false;
385   }
386   return true;
387 }
388 
CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo & op_io_info,const std::vector<std::vector<std::string>> & support_format_item,size_t index,mindspore::kernel::OpIOInfo * op_io_info_new)389 void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info,
390                                         const std::vector<std::vector<std::string>> &support_format_item, size_t index,
391                                         mindspore::kernel::OpIOInfo *op_io_info_new) {
392   MS_EXCEPTION_IF_NULL(op_io_info_new);
393   op_io_info_new->set_index(op_io_info.index());
394   op_io_info_new->set_name(op_io_info.name());
395   op_io_info_new->set_param_type(op_io_info.param_type());
396   op_io_info_new->set_need_compile(op_io_info.need_compile());
397   op_io_info_new->set_reshape_type(op_io_info.reshape_type());
398   op_io_info_new->set_shape(op_io_info.shape());
399   op_io_info_new->set_value_depend(op_io_info.value_depend());
400   // dtype
401   std::vector<std::string> dtype_new;
402   auto dtype = op_io_info.dtypes();
403   for (size_t i = 0; i < support_format_item.size(); ++i) {
404     dtype_new.insert(dtype_new.end(), dtype.begin(), dtype.end());
405   }
406   op_io_info_new->set_dtypes(dtype_new);
407   // format
408   std::vector<std::string> format_new;
409   for (const auto &formats : support_format_item) {
410     auto format = formats.at(index);
411     for (size_t j = 0; j < dtype.size(); ++j) {
412       format_new.emplace_back(format);
413     }
414   }
415   op_io_info_new->set_formats(format_new);
416 }
417 
SplitStrToVec(const std::string & op_select_json_item)418 std::vector<std::string> TbeKernelSelect::SplitStrToVec(const std::string &op_select_json_item) {
419   const std::map<std::string, std::string> kDynamicFormatMap = {
420     {"NCHW", "DefaultFormat"}, {"ND", "DefaultFormat"}, {"FRACTAL_Z", "FracZ"}, {"NCDHW", "DefaultFormat"}};
421   if (op_select_json_item.empty()) {
422     MS_LOG(EXCEPTION) << "Op select ret item is null.";
423   }
424   const char space = ' ';
425   const char sep = ',';
426   std::string op_select_tmp = op_select_json_item + ",";
427   std::vector<std::string> ret;
428   auto begin = op_select_tmp.find_first_not_of(space, 0);
429   auto sep_pos = op_select_tmp.find(sep);
430   if (begin >= sep_pos) {
431     MS_LOG(EXCEPTION) << "Select ret json is error.";
432   }
433   while (sep_pos != std::string::npos) {
434     auto obj = op_select_tmp.substr(begin, sep_pos - begin);
435     if (kDynamicFormatMap.find(obj) != kDynamicFormatMap.end()) {
436       obj = kDynamicFormatMap.at(obj);
437     }
438     ret.emplace_back(obj);
439     begin = op_select_tmp.find_first_not_of(space, sep_pos + 1);
440     sep_pos = op_select_tmp.find(sep, begin);
441   }
442   return ret;
443 }
444 
OpSelectFormat()445 std::string TbeKernelSelect::OpSelectFormat() {
446   std::string res_json_str;
447   std::string old_build = common::GetEnv("MS_OLD_BUILD_PROCESS");
448   if (!old_build.empty()) {
449     nlohmann::json kernel_json;
450     TbeKernelJsonCreator creator(OP_SELECT_FORMAT);
451     bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json);
452     if (!ret) {
453       MS_LOG(EXCEPTION) << "GenTbeSingleKernelJson failed.";
454     }
455     res_json_str = AscendKernelBuildClient::Instance().SelectFormat(kernel_json.dump());
456     if (res_json_str.empty()) {
457       MS_LOG(EXCEPTION) << "Op select format error, input args: " << kernel_json.dump();
458     }
459     if (res_json_str.find("TBEException") != std::string::npos) {
460       MS_LOG(EXCEPTION) << "Dynamic op select failed: " << res_json_str << ", input args: " << kernel_json.dump();
461     }
462   } else {
463     MS_LOG(INFO) << "Format select for node:[" << AnfAlgo::GetCNodeName(cnode_ptr_) << ", "
464                  << cnode_ptr_->fullname_with_scope() << "].";
465     auto &build_manager = kernel::ascend::AscendKernelCompileManager::GetInstance();
466     res_json_str = build_manager.AscendOpSelectFormat(cnode_ptr_);
467   }
468   return res_json_str;
469 }
470 
CreateNewOpInfo(const mindspore::kernel::OpInfo & op_info,const SupportFormat & support_format,mindspore::kernel::OpInfo * op_info_new)471 void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, const SupportFormat &support_format,
472                                       mindspore::kernel::OpInfo *op_info_new) {
473   MS_EXCEPTION_IF_NULL(op_info_new);
474   if (support_format.input_format.empty() || support_format.output_format.empty()) {
475     MS_LOG(EXCEPTION) << "Support input format and output format size can not be empty, but the input format size is: "
476                       << support_format.input_format.size()
477                       << ", output format size is: " << support_format.output_format.size();
478   }
479   if (op_info.inputs_ptr().size() != support_format.input_format[0].size() ||
480       op_info.outputs_ptr().size() != support_format.output_format[0].size()) {
481     MS_LOG(EXCEPTION) << "BroadCast input/output size not match, op info input size:" << op_info.inputs_ptr().size()
482                       << ", input support size: " << support_format.input_format[0].size()
483                       << ", op info output size: " << op_info.outputs_ptr().size()
484                       << ", output support size: " << support_format.output_format[0].size();
485   }
486   *op_info_new = op_info;
487   op_info_new->ClearInputs();
488   op_info_new->ClearOutputs();
489   for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) {
490     auto input = op_info.inputs_ptr().at(i);
491     auto input_new = std::make_shared<OpIOInfo>();
492     CreateNewOpIOInfo(*input, support_format.input_format, i, input_new.get());
493     op_info_new->add_inputs_ptr(input_new);
494   }
495   for (size_t j = 0; j < op_info.outputs_ptr().size(); ++j) {
496     auto output = op_info.outputs_ptr().at(j);
497     auto output_new = std::make_shared<OpIOInfo>();
498     CreateNewOpIOInfo(*output, support_format.output_format, j, output_new.get());
499     op_info_new->add_outputs_ptr(output_new);
500   }
501 }
502 
// One input or output entry parsed from the op-select-format JSON result.
struct SelectOpIOInfo {
  std::string name;                  // value of the entry's "name" key
  std::vector<std::string> dtypes;   // "dtype" string split by SplitStrToVec
  std::vector<std::string> formats;  // "format" string split by SplitStrToVec
};
508 
CreateNewOpInfo(const mindspore::kernel::OpInfo & op_info,mindspore::kernel::OpInfo * op_info_new)509 void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info,
510                                       mindspore::kernel::OpInfo *op_info_new) {
511   MS_EXCEPTION_IF_NULL(op_info_new);
512   auto op_seclect_json = OpSelectFormat();
513   if (!op_seclect_json.empty()) {
514     nlohmann::json json_obj;
515     if (!ParseJson(op_seclect_json, &json_obj)) {
516       MS_LOG(EXCEPTION) << "Parse op_select_json error.";
517     }
518     if (!json_obj.is_object()) {
519       MS_LOG(EXCEPTION) << "JsonStr is not an object, the jsonStr is:" << op_seclect_json;
520     }
521     std::vector<SelectOpIOInfo> inputs;
522     std::vector<SelectOpIOInfo> outputs;
523     for (const auto &item : json_obj.items()) {
524       const std::string &item_name = item.key();
525       bool is_input = (item_name.find(kPrefixInput) != std::string::npos);
526       bool is_output = (item_name.find(kPrefixOutput) != std::string::npos);
527       if (!is_input && !is_output) {
528         MS_LOG(EXCEPTION) << "op select ret json is error.";
529       }
530       if (is_input) {
531         SelectOpIOInfo select_input;
532         select_input.name = item.value().at(kName);
533         std::string input_dtype_item = item.value().at(kDtype);
534         select_input.dtypes = SplitStrToVec(input_dtype_item);
535         std::string input_format_item = item.value().at(kFormat);
536         select_input.formats = SplitStrToVec(input_format_item);
537         inputs.emplace_back(select_input);
538       } else {
539         SelectOpIOInfo select_output;
540         select_output.name = item.value().at(kName);
541         std::string input_dtype_item = item.value().at(kDtype);
542         select_output.dtypes = SplitStrToVec(input_dtype_item);
543         std::string input_format_item = item.value().at(kFormat);
544         select_output.formats = SplitStrToVec(input_format_item);
545         outputs.emplace_back(select_output);
546       }
547     }
548 
549     if (op_info.inputs_ptr().size() != inputs.size() || op_info.outputs_ptr().size() != outputs.size()) {
550       MS_LOG(EXCEPTION) << "select format input/output size not equal, please check register.";
551     }
552 
553     *op_info_new = op_info;
554     op_info_new->ClearInputs();
555     op_info_new->ClearOutputs();
556     for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) {
557       auto input_new = std::make_shared<OpIOInfo>();
558       CreateNewOpIOInfo(*op_info.inputs_ptr().at(i), inputs.at(i).dtypes, inputs.at(i).formats, input_new.get());
559       op_info_new->add_inputs_ptr(input_new);
560     }
561     for (size_t i = 0; i < op_info.outputs_ptr().size(); ++i) {
562       auto output_new = std::make_shared<OpIOInfo>();
563       CreateNewOpIOInfo(*op_info.outputs_ptr().at(i), outputs.at(i).dtypes, outputs.at(i).formats, output_new.get());
564       op_info_new->add_outputs_ptr(output_new);
565     }
566   }
567 }
568 
CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo & op_io_info,const std::vector<std::string> & support_dtype,const std::vector<std::string> & support_format,mindspore::kernel::OpIOInfo * op_io_info_new)569 void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info,
570                                         const std::vector<std::string> &support_dtype,
571                                         const std::vector<std::string> &support_format,
572                                         mindspore::kernel::OpIOInfo *op_io_info_new) {
573   MS_EXCEPTION_IF_NULL(op_io_info_new);
574   op_io_info_new->set_index(op_io_info.index());
575   op_io_info_new->set_name(op_io_info.name());
576   op_io_info_new->set_param_type(op_io_info.param_type());
577   op_io_info_new->set_need_compile(op_io_info.need_compile());
578   op_io_info_new->set_reshape_type(op_io_info.reshape_type());
579   op_io_info_new->set_shape(op_io_info.shape());
580   op_io_info_new->set_value_depend(op_io_info.value_depend());
581   // dtype  && format
582   op_io_info_new->set_dtypes(support_dtype);
583   op_io_info_new->set_formats(support_format);
584 }
585 
PrintSupportedFormat(const SupportFormat & support_format)586 void TbeKernelSelect::PrintSupportedFormat(const SupportFormat &support_format) {
587   if (support_format.input_format.size() != support_format.output_format.size()) {
588     MS_LOG(EXCEPTION) << "Input(" << support_format.input_format.size() << ")Output("
589                       << support_format.output_format.size() << ") size not match.";
590   }
591   for (size_t i = 0; i < support_format.input_format.size(); ++i) {
592     auto input_items = support_format.input_format.at(i);
593     auto output_items = support_format.output_format.at(i);
594     std::string print_str = "[";
595     for (const auto &input : input_items) {
596       print_str.append(input);
597       print_str.append(", ");
598     }
599     print_str.append("] -->");
600     for (const auto &output : output_items) {
601       print_str.append(output);
602       print_str.append(", ");
603     }
604     MS_LOG(INFO) << "Support format: " << print_str;
605   }
606 }
607 }  // namespace mindspore::kernel
608