1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h"
18
19 #include <map>
20 #include <memory>
21 #include <set>
22 #include <utility>
23 #include "backend/kernel_compiler/common_utils.h"
24 #include "backend/kernel_compiler/oplib/oplib.h"
25 #include "backend/kernel_compiler/tbe/tbe_convert_utils.h"
26 #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
27 #include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
28 #include "backend/kernel_compiler/tbe/ascend_kernel_compile.h"
29 #include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h"
30 #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h"
31 #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h"
32 #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.h"
33 #include "backend/optimizer/common/helper.h"
34 #include "backend/session/anf_runtime_algorithm.h"
35 #include "backend/session/kernel_build_client.h"
36 #include "nlohmann/json.hpp"
37 #include "utils/convert_utils_base.h"
38 #include "utils/json_operation_utils.h"
39
40 namespace mindspore::kernel {
// Json keys/prefixes used when parsing the op-select-format query result.
constexpr auto kName = "name";
constexpr auto kDtype = "dtype";
constexpr auto kFormat = "format";
constexpr auto kPrefixInput = "input";
constexpr auto kPrefixOutput = "output";
// Param types as written in the op registration. (The "Requre" identifier
// spelling is kept to match existing usage elsewhere.)
constexpr char kParamTypeDynamic[] = "dynamic";
constexpr char kParamTypeRequre[] = "required";
constexpr char kParamTypeOptional[] = "optional";
TbeMetadataInfo(const CNodePtr & kernel_node,std::vector<std::shared_ptr<KernelBuildInfo>> * kernel_info_list)49 void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
50 auto tbe_selecter = TbeKernelSelect(kernel_node, kernel_info_list);
51 tbe_selecter.TbeMetadataInfoEx();
52 }
53
// Stores the node to select kernels for and a non-owning pointer to the
// output list that the selection routines will populate.
TbeKernelSelect::TbeKernelSelect(CNodePtr kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list)
    : cnode_ptr_(std::move(kernel_node)), kernel_info_list_(kernel_info_list) {}
56
TbeMetadataInfoEx()57 void TbeKernelSelect::TbeMetadataInfoEx() {
58 MS_EXCEPTION_IF_NULL(cnode_ptr_);
59 MS_EXCEPTION_IF_NULL(kernel_info_list_);
60 node_name_ = AnfAlgo::GetCNodeName(cnode_ptr_);
61 full_name_ = cnode_ptr_->fullname_with_scope();
62
63 auto op_info_ptr = tbe::TbeDynamicShapeUtil::FindOp(node_name_, cnode_ptr_);
64 if (!op_info_ptr) {
65 return;
66 }
67 if (!TbePropertyChecker::CheckTbeProperties(cnode_ptr_)) {
68 MS_LOG(INFO) << "Warning: node(" << full_name_ << ") is not supported by tbe ai_core.";
69 return;
70 }
71
72 if (op_info_ptr->is_dynamic_format()) {
73 GetDynamicFormatPatternKernelInfo(*op_info_ptr);
74 } else {
75 OpPattern pattern = op_info_ptr->op_pattern();
76 if (pattern == kCommonPattern) {
77 GetCommonPatternKernelInfo(*op_info_ptr);
78 } else if (pattern == kFormatAgnosticPattern) {
79 GetAgnosticPatternKernelInfo(*op_info_ptr);
80 } else if (pattern == kBroadcastPattern) {
81 GetBroadcastPatternKernelInfo(*op_info_ptr);
82 } else if (pattern == kReducePattern) {
83 GetReducePatternKernelInfo(*op_info_ptr);
84 } else {
85 MS_LOG(INFO) << "Warning: op pattern is invailed.";
86 }
87 }
88 // check support
89 FilterInVaildKernelInfo(*op_info_ptr);
90 }
91
// Builds one KernelBuildInfo per dtype/format column registered in op_info.
// Each OpIOInfo carries parallel dtype/format lists; column k taken across all
// ios forms kernel candidate k. A candidate that cannot be mapped onto the
// node's real io tensors (GenBuilderItem returns false) aborts the loop, since
// later columns would fail the same way.
void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
  auto dyn_input_sizes = GetNodeDynamicInputs();
  // get real input/output num
  size_t real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode_ptr_);
  const auto inputs_info = op_info.inputs_ptr();
  size_t real_output_tensor_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
  const auto outputs_info = op_info.outputs_ptr();
  if (inputs_info.empty() && outputs_info.empty()) {
    MS_LOG(EXCEPTION) << AnfAlgo::GetCNodeName(cnode_ptr_) << "'s op info input & output is null, please check.";
  }
  // create kernel build info from opinfo
  // Candidate count equals the dtype-list length of the first io; all ios are
  // expected to carry lists of the same length.
  size_t kernel_build_info_num =
    inputs_info.empty() ? outputs_info[0]->dtypes().size() : inputs_info[0]->dtypes().size();
  for (size_t kernel_build_info_index = 0; kernel_build_info_index < kernel_build_info_num; ++kernel_build_info_index) {
    auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
    SetTbeBuildCommonInfo(op_info, &builder);
    std::vector<std::string> inputs_format;
    std::vector<TypeId> inputs_device_type;
    std::vector<std::string> inputs_reshape_type;
    std::vector<std::string> inputs_value_depend;
    // input
    if (!GenBuilderItem(true, kernel_build_info_index, real_input_tensor_num, inputs_info, dyn_input_sizes,
                        &inputs_format, &inputs_device_type, &inputs_reshape_type, &inputs_value_depend)) {
      break;
    }
    builder.SetInputsDeviceType(inputs_device_type);
    builder.SetInputsFormat(inputs_format);
    builder.SetInputsReshapeType(inputs_reshape_type);
    builder.SetInputsValueDepend(inputs_value_depend);
    // output
    std::vector<std::string> outputs_format;
    std::vector<TypeId> outputs_device_type;
    std::vector<std::string> outputs_reshape_type;
    std::vector<std::string> outputs_value_depend;
    if (!GenBuilderItem(false, kernel_build_info_index, real_output_tensor_num, outputs_info, dyn_input_sizes,
                        &outputs_format, &outputs_device_type, &outputs_reshape_type, &outputs_value_depend)) {
      break;
    }
    builder.SetOutputsDeviceType(outputs_device_type);
    builder.SetOutputsFormat(outputs_format);
    builder.SetOutputsReshapeType(outputs_reshape_type);
    kernel_info_list_->emplace_back(builder.Build());
  }
}
136
GetDynamicFormatPatternKernelInfo(const OpInfo & op_info)137 void TbeKernelSelect::GetDynamicFormatPatternKernelInfo(const OpInfo &op_info) {
138 OpInfo op_info_new;
139 CreateNewOpInfo(op_info, &op_info_new);
140 GetCommonPatternKernelInfo(op_info_new);
141 }
142
GetAgnosticPatternKernelInfo(const OpInfo & op_info)143 void TbeKernelSelect::GetAgnosticPatternKernelInfo(const OpInfo &op_info) {
144 if (op_info.inputs_ptr().size() != 1) {
145 MS_LOG(EXCEPTION) << "AgnosticPattern only support one input.";
146 }
147 auto format = AnfAlgo::GetPrevNodeOutputFormat(cnode_ptr_, 0);
148 if (kOpFormatList.find(format) == kOpFormatList.end()) {
149 MS_LOG(INFO) << "Got the unknown format " << format;
150 format = kOpFormat_DEFAULT;
151 }
152 SupportFormat support_format;
153 SupportFormatItem input_item;
154 SupportFormatItem output_item;
155 input_item.assign(op_info.inputs_ptr().size(), format);
156 output_item.assign(op_info.outputs_ptr().size(), format);
157 support_format.input_format.emplace_back(input_item);
158 support_format.output_format.emplace_back(output_item);
159 OpInfo op_info_new;
160 CreateNewOpInfo(op_info, support_format, &op_info_new);
161 GetCommonPatternKernelInfo(op_info_new);
162 }
163
GetBroadcastPatternKernelInfo(const OpInfo & op_info)164 void TbeKernelSelect::GetBroadcastPatternKernelInfo(const OpInfo &op_info) {
165 auto broadcast_selecter = TbeKernelBroadCastSelecter(cnode_ptr_);
166 SupportFormat support_format;
167 broadcast_selecter.GetShapeInfo(&support_format);
168 (void)broadcast_selecter.IsBroadCastSupport5HD(&support_format);
169 (void)broadcast_selecter.IsBroadCastSupportFracZ(&support_format);
170 (void)broadcast_selecter.IsBroadCastSupportC1HWNCoC0(&support_format);
171 (void)broadcast_selecter.IsBroadCastSupportFracNZ(&support_format);
172 (void)broadcast_selecter.IsBroadCastSupportNDC1HWC0(&support_format);
173 OpInfo op_info_new;
174 CreateNewOpInfo(op_info, support_format, &op_info_new);
175 GetCommonPatternKernelInfo(op_info_new);
176 }
177
GetReducePatternKernelInfo(const OpInfo & op_info)178 void TbeKernelSelect::GetReducePatternKernelInfo(const OpInfo &op_info) {
179 auto reduce_selecter = TbeKernelReduceSelecter(cnode_ptr_);
180 SupportFormat support_format;
181 reduce_selecter.GetShapeInfo(&support_format);
182 (void)reduce_selecter.IsReduceSupport5HD(&support_format);
183 (void)reduce_selecter.IsReduceSupportFracZ(&support_format);
184 (void)reduce_selecter.IsReduceSupportC1HWNCoC0(&support_format);
185 (void)reduce_selecter.IsReduceSupportFracNZ(&support_format);
186 OpInfo op_info_new;
187 CreateNewOpInfo(op_info, support_format, &op_info_new);
188 GetCommonPatternKernelInfo(op_info_new);
189 }
190
FilterInVaildKernelInfo(const OpInfo & op_info)191 void TbeKernelSelect::FilterInVaildKernelInfo(const OpInfo &op_info) {
192 if (kernel_info_list_->empty()) {
193 MS_LOG(INFO) << "Warning: get kernel build info failed. Op name: " << full_name_;
194 return;
195 }
196 std::vector<std::shared_ptr<KernelBuildInfo>> kernel_info_list;
197 auto dynamic_inputs = GetNodeDynamicInputs();
198 for (auto iter = kernel_info_list_->begin(); iter != kernel_info_list_->end(); ++iter) {
199 if (!FilterInVaildShape(iter, !dynamic_inputs.empty())) {
200 continue;
201 }
202 if (op_info.need_check_supported()) {
203 if (!TbeCheckSupported(iter)) {
204 continue;
205 }
206 }
207 kernel_info_list.emplace_back(*iter);
208 }
209 if (kernel_info_list.empty()) {
210 MS_LOG(WARNING) << "Tbe kernel info list is empty, all valid kernel info was filtered out. "
211 "Check the input shape, attrs or other value of node : "
212 << full_name_;
213 }
214 (*kernel_info_list_) = kernel_info_list;
215 }
216
FilterInVaildShape(const KernelBuildInfoIter & kernel_build_info_iter,bool is_dynamic_input)217 bool TbeKernelSelect::FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter, bool is_dynamic_input) {
218 MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
219 const auto &kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats();
220 // dynamic input just need to check first input, because other inputs copy from 1th input;
221 auto iter_num =
222 is_dynamic_input && !kernel_build_info_inputs_format.empty() ? 1 : kernel_build_info_inputs_format.size();
223 for (size_t i = 0; i < iter_num; ++i) {
224 auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
225 const auto &format = kernel_build_info_inputs_format.at(i);
226 if (!IsShapeMatchFormat(shape, format)) {
227 return false;
228 }
229 }
230 const auto &kernel_build_info_outputs_format = (*kernel_build_info_iter)->GetAllOutputFormats();
231 for (size_t j = 0; j < kernel_build_info_outputs_format.size(); ++j) {
232 auto shape = AnfAlgo::GetOutputInferShape(cnode_ptr_, j);
233 const auto &format = kernel_build_info_outputs_format[j];
234 if (!IsShapeMatchFormat(shape, format)) {
235 return false;
236 }
237 }
238 return true;
239 }
240
IsShapeMatchFormat(const std::vector<size_t> & shape,const std::string & format)241 bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
242 if (format == kOpFormat_DEFAULT) {
243 return true;
244 }
245 static const std::set<std::string> kServerNotSupportFormat = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
246 // if format is default, it remarkes support all format
247 if (kOpFormatList.find(format) == kOpFormatList.end()) {
248 MS_LOG(EXCEPTION) << "Got the unknown format " << format;
249 }
250 // server not support format with C04 suffix
251 if (std::find(kServerNotSupportFormat.begin(), kServerNotSupportFormat.end(), format) !=
252 kServerNotSupportFormat.end()) {
253 MS_LOG(INFO) << "Warning: Server not support format with C04 suffix.";
254 return false;
255 }
256 if (format == kOpFormat_FRAC_NZ && shape.size() > kShape2dDims) {
257 return true;
258 }
259 // not support format:
260 // 1 3d formats with shape size > 5
261 if (k3DFormatSet.find(format) != k3DFormatSet.end() && shape.size() > kShape5dDims) {
262 return false;
263 }
264 return true;
265 }
266
// Asks the TBE compiler whether the candidate build info is actually
// supported. The node's currently-selected build info is temporarily swapped
// for the candidate so the check sees it, and restored before returning.
bool TbeKernelSelect::TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter) {
  MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
  // replace kernel_info with current kernel info
  auto kernel_build_info_tmp = AnfAlgo::GetSelectKernelBuildInfo(cnode_ptr_);
  AnfAlgo::SetSelectKernelBuildInfo(*kernel_build_info_iter, cnode_ptr_.get());
  // MS_OLD_BUILD_PROCESS selects the legacy kernel-build-client path;
  // otherwise the newer AscendKernelCompileManager is used.
  std::string old_build = common::GetEnv("MS_OLD_BUILD_PROCESS");
  bool ret = true;
  if (!old_build.empty()) {
    nlohmann::json kernel_json;
    TbeKernelJsonCreator creator(CHECK_SUPPORTED);
    ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json);
    if (!ret) {
      MS_LOG(EXCEPTION) << "Gen tbe single kernel json for check support failed.";
    }
    ret = AscendKernelBuildClient::Instance().CheckSupported(kernel_json.dump());
  } else {
    auto &build_manager = kernel::ascend::AscendKernelCompileManager::GetInstance();
    if (!build_manager.AscendOpCheckSupported(cnode_ptr_)) {
      ret = false;
    }
  }
  // Restore the original build info regardless of the check result.
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_tmp, cnode_ptr_.get());
  return ret;
}
291
SetTbeBuildCommonInfo(const mindspore::kernel::OpInfo & op_info,mindspore::kernel::KernelBuildInfo::KernelBuildInfoBuilder * builder)292 void TbeKernelSelect::SetTbeBuildCommonInfo(const mindspore::kernel::OpInfo &op_info,
293 mindspore::kernel::KernelBuildInfo::KernelBuildInfoBuilder *builder) {
294 MS_EXCEPTION_IF_NULL(builder);
295 builder->SetProcessor(AICORE);
296 std::string fusion_name = op_info.fusion_type();
297 auto fusion_type = kernel::GetFusionTypeByName(fusion_name);
298 if (fusion_type != UNKNOWN_FUSION_TYPE) {
299 builder->SetFusionType(fusion_type);
300 }
301 builder->SetOpPattern(op_info.op_pattern());
302 builder->SetKernelType(TBE_KERNEL);
303 }
304
GetNodeDynamicInputs()305 std::vector<int64_t> TbeKernelSelect::GetNodeDynamicInputs() {
306 // get dynamic inputs
307 auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_);
308 MS_EXCEPTION_IF_NULL(primitive);
309 std::vector<int64_t> dyn_input_sizes;
310 if (primitive->HasAttr(kAttrDynInputSizes)) {
311 dyn_input_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAttrDynInputSizes));
312 }
313 return dyn_input_sizes;
314 }
315
// Maps the registered io entries (column kernel_build_info_index of their
// dtype/format lists) onto the node's real io tensors, appending one
// format/dtype/reshape/value-depend entry per real tensor.
//
// "dynamic" ios expand to dyn_input_sizes[...] entries (inputs) or to all
// remaining real tensors (outputs, which must then be the only registered
// output). "required"/"optional" ios consume exactly one real tensor.
// Returns false when the registered ios cannot cover all real tensors, which
// callers treat as "this candidate column does not apply".
bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num,
                                     const std::vector<std::shared_ptr<OpIOInfo>> &ios_info,
                                     const std::vector<int64_t> &dyn_input_sizes, std::vector<std::string> *formats,
                                     std::vector<TypeId> *device_types, std::vector<std::string> *reshape_types,
                                     std::vector<std::string> *value_depends) {
  MS_EXCEPTION_IF_NULL(formats);
  MS_EXCEPTION_IF_NULL(device_types);
  MS_EXCEPTION_IF_NULL(reshape_types);
  MS_EXCEPTION_IF_NULL(value_depends);
  size_t dynamic_input_index = 0;
  size_t real_io_tensor_index = 0;
  size_t io_info_index = 0;
  size_t io_info_num = ios_info.size();
  // Walk registered ios and real tensors in lock step.
  for (; io_info_index < io_info_num && real_io_tensor_index < real_io_tensor_num; io_info_index++) {
    std::shared_ptr<OpIOInfo> io_info_item = ios_info[io_info_index];
    const auto &kernel_build_info_dtype = io_info_item->dtypes()[kernel_build_info_index];
    std::string kernel_build_info_format;
    if (!io_info_item->formats().empty()) {
      kernel_build_info_format = io_info_item->formats()[kernel_build_info_index];
    }
    const std::string &io_param_type = io_info_item->param_type();
    auto reshape_type = io_info_item->reshape_type();
    auto value_depend = io_info_item->value_depend();
    if (io_param_type == kParamTypeDynamic) {
      // dynamic io
      if (is_input) {
        if (dynamic_input_index >= dyn_input_sizes.size()) {
          MS_LOG(EXCEPTION) << "dyn_input_sizes attr set error, dynamic_input_index: " << dynamic_input_index
                            << ", dyn_input_sizes size: " << dyn_input_sizes.size();
        }
        // One dynamic input stands for dyn_input_sizes[k] real tensors; repeat
        // the same dtype/format for each of them.
        int64_t dynamic_input_size = dyn_input_sizes[dynamic_input_index];
        for (int64_t i = 0; i < dynamic_input_size; ++i) {
          device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
          formats->emplace_back(kernel_build_info_format);
          reshape_types->emplace_back(reshape_type);
          value_depends->emplace_back(value_depend);
        }
        dynamic_input_index++;
        real_io_tensor_index = SizetAddWithOverflowCheck(real_io_tensor_index, LongToSize(dynamic_input_size));
      } else {
        if (ios_info.size() != 1) {
          MS_LOG(EXCEPTION) << "if output is dynamic, so output must has one output.";
        }
        // A dynamic output absorbs all real output tensors.
        for (size_t i = 0; i < real_io_tensor_num; ++i) {
          device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
          formats->emplace_back(kernel_build_info_format);
          reshape_types->emplace_back(reshape_type);
          value_depends->emplace_back(value_depend);
        }
        real_io_tensor_index = SizetAddWithOverflowCheck(real_io_tensor_index, real_io_tensor_num);
      }
    } else if (io_param_type == kParamTypeRequre || io_param_type == kParamTypeOptional) {
      // require or optional io
      device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
      formats->emplace_back(kernel_build_info_format);
      reshape_types->emplace_back(reshape_type);
      value_depends->emplace_back(value_depend);
      real_io_tensor_index++;
    } else {
      MS_LOG(EXCEPTION) << "op info's param type is not match: " << io_param_type;
    }
  }

  // Not every real tensor was covered: candidate does not apply to this node.
  if (real_io_tensor_index != real_io_tensor_num) {
    std::string io_type = is_input ? "inputs " : "outputs";
    MS_LOG(INFO) << node_name_ << "'s " << io_type << "op io info num: " << io_info_num
                 << ", real io tensor num:" << real_io_tensor_num << "real_io_tensor_index(" << real_io_tensor_index
                 << ") != real_io_tensor_num(" << real_io_tensor_num << ")";
    return false;
  }
  return true;
}
388
CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo & op_io_info,const std::vector<std::vector<std::string>> & support_format_item,size_t index,mindspore::kernel::OpIOInfo * op_io_info_new)389 void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info,
390 const std::vector<std::vector<std::string>> &support_format_item, size_t index,
391 mindspore::kernel::OpIOInfo *op_io_info_new) {
392 MS_EXCEPTION_IF_NULL(op_io_info_new);
393 op_io_info_new->set_index(op_io_info.index());
394 op_io_info_new->set_name(op_io_info.name());
395 op_io_info_new->set_param_type(op_io_info.param_type());
396 op_io_info_new->set_need_compile(op_io_info.need_compile());
397 op_io_info_new->set_reshape_type(op_io_info.reshape_type());
398 op_io_info_new->set_shape(op_io_info.shape());
399 op_io_info_new->set_value_depend(op_io_info.value_depend());
400 // dtype
401 std::vector<std::string> dtype_new;
402 auto dtype = op_io_info.dtypes();
403 for (size_t i = 0; i < support_format_item.size(); ++i) {
404 dtype_new.insert(dtype_new.end(), dtype.begin(), dtype.end());
405 }
406 op_io_info_new->set_dtypes(dtype_new);
407 // format
408 std::vector<std::string> format_new;
409 for (const auto &formats : support_format_item) {
410 auto format = formats.at(index);
411 for (size_t j = 0; j < dtype.size(); ++j) {
412 format_new.emplace_back(format);
413 }
414 }
415 op_io_info_new->set_formats(format_new);
416 }
417
SplitStrToVec(const std::string & op_select_json_item)418 std::vector<std::string> TbeKernelSelect::SplitStrToVec(const std::string &op_select_json_item) {
419 const std::map<std::string, std::string> kDynamicFormatMap = {
420 {"NCHW", "DefaultFormat"}, {"ND", "DefaultFormat"}, {"FRACTAL_Z", "FracZ"}, {"NCDHW", "DefaultFormat"}};
421 if (op_select_json_item.empty()) {
422 MS_LOG(EXCEPTION) << "Op select ret item is null.";
423 }
424 const char space = ' ';
425 const char sep = ',';
426 std::string op_select_tmp = op_select_json_item + ",";
427 std::vector<std::string> ret;
428 auto begin = op_select_tmp.find_first_not_of(space, 0);
429 auto sep_pos = op_select_tmp.find(sep);
430 if (begin >= sep_pos) {
431 MS_LOG(EXCEPTION) << "Select ret json is error.";
432 }
433 while (sep_pos != std::string::npos) {
434 auto obj = op_select_tmp.substr(begin, sep_pos - begin);
435 if (kDynamicFormatMap.find(obj) != kDynamicFormatMap.end()) {
436 obj = kDynamicFormatMap.at(obj);
437 }
438 ret.emplace_back(obj);
439 begin = op_select_tmp.find_first_not_of(space, sep_pos + 1);
440 sep_pos = op_select_tmp.find(sep, begin);
441 }
442 return ret;
443 }
444
// Queries the TBE op-select-format interface for this node and returns the
// raw json string describing supported dtypes/formats per io.
// The MS_OLD_BUILD_PROCESS env var selects the legacy kernel-build-client
// path; otherwise the newer AscendKernelCompileManager is used.
std::string TbeKernelSelect::OpSelectFormat() {
  std::string res_json_str;
  std::string old_build = common::GetEnv("MS_OLD_BUILD_PROCESS");
  if (!old_build.empty()) {
    nlohmann::json kernel_json;
    TbeKernelJsonCreator creator(OP_SELECT_FORMAT);
    bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json);
    if (!ret) {
      MS_LOG(EXCEPTION) << "GenTbeSingleKernelJson failed.";
    }
    res_json_str = AscendKernelBuildClient::Instance().SelectFormat(kernel_json.dump());
    if (res_json_str.empty()) {
      MS_LOG(EXCEPTION) << "Op select format error, input args: " << kernel_json.dump();
    }
    // The client reports python-side failures inline in the returned string.
    if (res_json_str.find("TBEException") != std::string::npos) {
      MS_LOG(EXCEPTION) << "Dynamic op select failed: " << res_json_str << ", input args: " << kernel_json.dump();
    }
  } else {
    MS_LOG(INFO) << "Format select for node:[" << AnfAlgo::GetCNodeName(cnode_ptr_) << ", "
                 << cnode_ptr_->fullname_with_scope() << "].";
    auto &build_manager = kernel::ascend::AscendKernelCompileManager::GetInstance();
    res_json_str = build_manager.AscendOpSelectFormat(cnode_ptr_);
  }
  return res_json_str;
}
470
CreateNewOpInfo(const mindspore::kernel::OpInfo & op_info,const SupportFormat & support_format,mindspore::kernel::OpInfo * op_info_new)471 void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, const SupportFormat &support_format,
472 mindspore::kernel::OpInfo *op_info_new) {
473 MS_EXCEPTION_IF_NULL(op_info_new);
474 if (support_format.input_format.empty() || support_format.output_format.empty()) {
475 MS_LOG(EXCEPTION) << "Support input format and output format size can not be empty, but the input format size is: "
476 << support_format.input_format.size()
477 << ", output format size is: " << support_format.output_format.size();
478 }
479 if (op_info.inputs_ptr().size() != support_format.input_format[0].size() ||
480 op_info.outputs_ptr().size() != support_format.output_format[0].size()) {
481 MS_LOG(EXCEPTION) << "BroadCast input/output size not match, op info input size:" << op_info.inputs_ptr().size()
482 << ", input support size: " << support_format.input_format[0].size()
483 << ", op info output size: " << op_info.outputs_ptr().size()
484 << ", output support size: " << support_format.output_format[0].size();
485 }
486 *op_info_new = op_info;
487 op_info_new->ClearInputs();
488 op_info_new->ClearOutputs();
489 for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) {
490 auto input = op_info.inputs_ptr().at(i);
491 auto input_new = std::make_shared<OpIOInfo>();
492 CreateNewOpIOInfo(*input, support_format.input_format, i, input_new.get());
493 op_info_new->add_inputs_ptr(input_new);
494 }
495 for (size_t j = 0; j < op_info.outputs_ptr().size(); ++j) {
496 auto output = op_info.outputs_ptr().at(j);
497 auto output_new = std::make_shared<OpIOInfo>();
498 CreateNewOpIOInfo(*output, support_format.output_format, j, output_new.get());
499 op_info_new->add_outputs_ptr(output_new);
500 }
501 }
502
// Intermediate holder for one input/output entry parsed from the
// op-select-format json result.
struct SelectOpIOInfo {
  std::string name;
  std::vector<std::string> dtypes;
  std::vector<std::string> formats;
};
508
CreateNewOpInfo(const mindspore::kernel::OpInfo & op_info,mindspore::kernel::OpInfo * op_info_new)509 void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info,
510 mindspore::kernel::OpInfo *op_info_new) {
511 MS_EXCEPTION_IF_NULL(op_info_new);
512 auto op_seclect_json = OpSelectFormat();
513 if (!op_seclect_json.empty()) {
514 nlohmann::json json_obj;
515 if (!ParseJson(op_seclect_json, &json_obj)) {
516 MS_LOG(EXCEPTION) << "Parse op_select_json error.";
517 }
518 if (!json_obj.is_object()) {
519 MS_LOG(EXCEPTION) << "JsonStr is not an object, the jsonStr is:" << op_seclect_json;
520 }
521 std::vector<SelectOpIOInfo> inputs;
522 std::vector<SelectOpIOInfo> outputs;
523 for (const auto &item : json_obj.items()) {
524 const std::string &item_name = item.key();
525 bool is_input = (item_name.find(kPrefixInput) != std::string::npos);
526 bool is_output = (item_name.find(kPrefixOutput) != std::string::npos);
527 if (!is_input && !is_output) {
528 MS_LOG(EXCEPTION) << "op select ret json is error.";
529 }
530 if (is_input) {
531 SelectOpIOInfo select_input;
532 select_input.name = item.value().at(kName);
533 std::string input_dtype_item = item.value().at(kDtype);
534 select_input.dtypes = SplitStrToVec(input_dtype_item);
535 std::string input_format_item = item.value().at(kFormat);
536 select_input.formats = SplitStrToVec(input_format_item);
537 inputs.emplace_back(select_input);
538 } else {
539 SelectOpIOInfo select_output;
540 select_output.name = item.value().at(kName);
541 std::string input_dtype_item = item.value().at(kDtype);
542 select_output.dtypes = SplitStrToVec(input_dtype_item);
543 std::string input_format_item = item.value().at(kFormat);
544 select_output.formats = SplitStrToVec(input_format_item);
545 outputs.emplace_back(select_output);
546 }
547 }
548
549 if (op_info.inputs_ptr().size() != inputs.size() || op_info.outputs_ptr().size() != outputs.size()) {
550 MS_LOG(EXCEPTION) << "select format input/output size not equal, please check register.";
551 }
552
553 *op_info_new = op_info;
554 op_info_new->ClearInputs();
555 op_info_new->ClearOutputs();
556 for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) {
557 auto input_new = std::make_shared<OpIOInfo>();
558 CreateNewOpIOInfo(*op_info.inputs_ptr().at(i), inputs.at(i).dtypes, inputs.at(i).formats, input_new.get());
559 op_info_new->add_inputs_ptr(input_new);
560 }
561 for (size_t i = 0; i < op_info.outputs_ptr().size(); ++i) {
562 auto output_new = std::make_shared<OpIOInfo>();
563 CreateNewOpIOInfo(*op_info.outputs_ptr().at(i), outputs.at(i).dtypes, outputs.at(i).formats, output_new.get());
564 op_info_new->add_outputs_ptr(output_new);
565 }
566 }
567 }
568
CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo & op_io_info,const std::vector<std::string> & support_dtype,const std::vector<std::string> & support_format,mindspore::kernel::OpIOInfo * op_io_info_new)569 void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info,
570 const std::vector<std::string> &support_dtype,
571 const std::vector<std::string> &support_format,
572 mindspore::kernel::OpIOInfo *op_io_info_new) {
573 MS_EXCEPTION_IF_NULL(op_io_info_new);
574 op_io_info_new->set_index(op_io_info.index());
575 op_io_info_new->set_name(op_io_info.name());
576 op_io_info_new->set_param_type(op_io_info.param_type());
577 op_io_info_new->set_need_compile(op_io_info.need_compile());
578 op_io_info_new->set_reshape_type(op_io_info.reshape_type());
579 op_io_info_new->set_shape(op_io_info.shape());
580 op_io_info_new->set_value_depend(op_io_info.value_depend());
581 // dtype && format
582 op_io_info_new->set_dtypes(support_dtype);
583 op_io_info_new->set_formats(support_format);
584 }
585
PrintSupportedFormat(const SupportFormat & support_format)586 void TbeKernelSelect::PrintSupportedFormat(const SupportFormat &support_format) {
587 if (support_format.input_format.size() != support_format.output_format.size()) {
588 MS_LOG(EXCEPTION) << "Input(" << support_format.input_format.size() << ")Output("
589 << support_format.output_format.size() << ") size not match.";
590 }
591 for (size_t i = 0; i < support_format.input_format.size(); ++i) {
592 auto input_items = support_format.input_format.at(i);
593 auto output_items = support_format.output_format.at(i);
594 std::string print_str = "[";
595 for (const auto &input : input_items) {
596 print_str.append(input);
597 print_str.append(", ");
598 }
599 print_str.append("] -->");
600 for (const auto &output : output_items) {
601 print_str.append(output);
602 print_str.append(", ");
603 }
604 MS_LOG(INFO) << "Support format: " << print_str;
605 }
606 }
607 } // namespace mindspore::kernel
608