1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "backend/kernel_compiler/tbe/tbe_json/tbe_json_creator.h"
#include <algorithm>
#include <iterator>
#include <map>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/tbe/tbe_adapter.h"
#include "backend/kernel_compiler/tbe/tbe_convert_utils.h"
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
#include "utils/ms_context.h"
#include "runtime/dev.h"
#include "utils/ms_utils.h"
#include "utils/json_operation_utils.h"
#include "utils/convert_utils.h"
#include "backend/kernel_compiler/tbe/tbe_json/tbe_json_utils.h"
33
34 namespace mindspore::kernel {
35 namespace {
36 static std::unordered_map<std::string, ATTR_DTYPE> type_attr_dtype_map = {
37 {kVTypeInt, ATTR_DTYPE::ATTR_INT32},
38 {kVTypeInt64, ATTR_DTYPE::ATTR_INT64},
39 {kVTypeStr, ATTR_DTYPE::ATTR_STR},
40 {kVTypeBool, ATTR_DTYPE::ATTR_BOOL},
41 {kVTypeFloat, ATTR_DTYPE::ATTR_FLOAT32},
42 {kVTypeListInt, ATTR_DTYPE::ATTR_LIST_INT32},
43 {kVTypeListFloat, ATTR_DTYPE::ATTR_LIST_FLOAT32},
44 {kVTypeListUInt64, ATTR_DTYPE::ATTR_LIST_UINT64},
45 {kVTypeListListInt, ATTR_DTYPE::ATTR_LIST_LIST_INT64}};
46
47 static std::map<ATTR_DTYPE, std::string> tbe_attr_dtype_to_string_map = {
48 {ATTR_INT8, "int8"},
49 {ATTR_UINT8, "uint8"},
50 {ATTR_INT16, "int16"},
51 {ATTR_UINT16, "uint16"},
52 {ATTR_INT32, "int32"},
53 {ATTR_UINT32, "uint32"},
54 {ATTR_INT64, "int64"},
55 {ATTR_UINT64, "uint64"},
56 {ATTR_FLOAT32, "float32"},
57 {ATTR_DOUBLE, "double"},
58 {ATTR_BOOL, "bool"},
59 {ATTR_STR, "str"},
60 {ATTR_LIST_INT8, "list_int8"},
61 {ATTR_LIST_UINT8, "list_uint8"},
62 {ATTR_LIST_INT16, "list_int16"},
63 {ATTR_LIST_UINT16, "list_uint16"},
64 {ATTR_LIST_INT32, "list_int32"},
65 {ATTR_LIST_UINT32, "list_uint32"},
66 {ATTR_LIST_INT64, "list_int64"},
67 {ATTR_LIST_UINT64, "list_uint64"},
68 {ATTR_LIST_FLOAT32, "list_float32"},
69 {ATTR_LIST_DOUBLE, "list_double"},
70 {ATTR_LIST_BOOL, "list_bool"},
71 {ATTR_LIST_STR, "list_str"},
72 {ATTR_LIST_LIST_INT64, "list_list_int64"},
73 {ATTR_LIST_LIST_FLOAT, "list_list_float"},
74 };
75
ParseListIntValue(const mindspore::ValuePtr & value,std::vector<int64_t> * attr_value)76 bool ParseListIntValue(const mindspore::ValuePtr &value, std::vector<int64_t> *attr_value) {
77 auto value_type = value->type();
78 if (value_type == nullptr) {
79 MS_LOG(ERROR) << "Value's type is null.";
80 return false;
81 }
82 if (value_type->ToString() == kVTypeInt64) {
83 attr_value->push_back(GetValue<int64_t>(value));
84 } else {
85 auto vec = value->isa<ValueTuple>() ? value->cast<ValueTuplePtr>()->value() : value->cast<ValueListPtr>()->value();
86 if (!vec.empty()) {
87 if (vec[0]->isa<Int32Imm>()) {
88 std::vector<int32_t> attr_value_me = GetValue<std::vector<int32_t>>(value);
89 (void)std::transform(attr_value_me.begin(), attr_value_me.end(), std::back_inserter(*attr_value),
90 [](const int &value) { return static_cast<int64_t>(value); });
91 } else {
92 *attr_value = GetValue<std::vector<int64_t>>(value);
93 }
94 }
95 }
96 return true;
97 }
98
ParseAttrValue(const std::string & type,const mindspore::ValuePtr & value,nlohmann::json * attr_obj)99 bool ParseAttrValue(const std::string &type, const mindspore::ValuePtr &value, nlohmann::json *attr_obj) {
100 MS_EXCEPTION_IF_NULL(attr_obj);
101 if (value == nullptr) {
102 MS_LOG(ERROR) << "Node's attr value is null.";
103 return false;
104 }
105 auto result = type_attr_dtype_map.find(type);
106 if (result == type_attr_dtype_map.end()) {
107 MS_LOG(ERROR) << "Type: " << type << "not support";
108 return false;
109 }
110
111 auto dtype_string = tbe_attr_dtype_to_string_map.find(result->second);
112 if (dtype_string == tbe_attr_dtype_to_string_map.end()) {
113 MS_LOG(ERROR) << "Can't convert attr dtype " << result->second << " to string";
114 return false;
115 }
116 (*attr_obj)[kJDtype] = dtype_string->second;
117
118 switch (result->second) {
119 case ATTR_DTYPE::ATTR_INT32:
120 (*attr_obj)[kJValue] = value->isa<Int32Imm>() ? GetValue<int>(value) : GetValue<int64_t>(value);
121 break;
122 case ATTR_DTYPE::ATTR_INT64:
123 (*attr_obj)[kJValue] = GetValue<int64_t>(value);
124 break;
125 case ATTR_DTYPE::ATTR_STR: {
126 auto attr_value = GetValue<std::string>(value);
127 (*attr_obj)[kJValue] = attr_value == kOpFormat_FRAC_Z ? kJOpFormat_FRACTAL_Z : attr_value;
128 break;
129 }
130 case ATTR_DTYPE::ATTR_BOOL:
131 (*attr_obj)[kJValue] = GetValue<bool>(value);
132 break;
133 case ATTR_DTYPE::ATTR_FLOAT32:
134 (*attr_obj)[kJValue] = GetValue<float>(value);
135 break;
136 case ATTR_DTYPE::ATTR_LIST_INT32: {
137 std::vector<int64_t> attr_value;
138 if (!ParseListIntValue(value, &attr_value)) {
139 MS_LOG(ERROR) << "Parse list_value failed, maybe the input is a nullptr.";
140 return false;
141 }
142 (*attr_obj)[kJValue] = attr_value;
143 break;
144 }
145 case ATTR_DTYPE::ATTR_LIST_FLOAT32: {
146 auto value_type = value->type();
147 if (value_type == nullptr) {
148 MS_LOG(ERROR) << "Value's type is null.";
149 return false;
150 }
151 (*attr_obj)[kJValue] = value_type->ToString() == kVTypeFloat ? std::vector<float>{GetValue<float>(value)}
152 : GetValue<std::vector<float>>(value);
153 break;
154 }
155 case ATTR_DTYPE::ATTR_LIST_UINT64:
156 (*attr_obj)[kJValue] = GetValue<std::vector<size_t>>(value);
157 break;
158 case ATTR_DTYPE::ATTR_LIST_LIST_INT64:
159 (*attr_obj)[kJValue] = GetValue<std::vector<std::vector<int64_t>>>(value);
160 break;
161
162 default:
163 MS_LOG(ERROR) << "Type: " << type << "not support";
164 return false;
165 }
166 return true;
167 }
168
ParseAttrDefaultValue(const std::string & type,const std::string & value,nlohmann::json * attr_obj)169 bool ParseAttrDefaultValue(const std::string &type, const std::string &value, nlohmann::json *attr_obj) {
170 MS_EXCEPTION_IF_NULL(attr_obj);
171 auto result = type_attr_dtype_map.find(type);
172 if (result == type_attr_dtype_map.end()) {
173 MS_LOG(ERROR) << "Type: " << type << "not support";
174 return false;
175 }
176
177 auto dtype_string = tbe_attr_dtype_to_string_map.find(result->second);
178 if (dtype_string == tbe_attr_dtype_to_string_map.end()) {
179 MS_LOG(ERROR) << "Can't convert attr dtype " << result->second << " to string";
180 return false;
181 }
182 (*attr_obj)[kJDtype] = dtype_string->second;
183
184 switch (result->second) {
185 case ATTR_DTYPE::ATTR_INT32:
186 (*attr_obj)[kJValue] = std::stoi(value);
187 break;
188 case ATTR_DTYPE::ATTR_INT64:
189 (*attr_obj)[kJValue] = std::stoll(value);
190 break;
191 case ATTR_DTYPE::ATTR_STR:
192 (*attr_obj)[kJValue] = value;
193 break;
194 case ATTR_DTYPE::ATTR_BOOL: {
195 bool attr_value = false;
196 std::istringstream(value) >> std::boolalpha >> attr_value;
197 (*attr_obj)[kJValue] = attr_value;
198 break;
199 }
200 case ATTR_DTYPE::ATTR_FLOAT32:
201 (*attr_obj)[kJValue] = std::stof(value);
202 break;
203 case ATTR_DTYPE::ATTR_LIST_INT32: {
204 std::stringstream string_value(value);
205 std::string list_elem;
206 std::vector<int64_t> attrs_value;
207 while (std::getline(string_value, list_elem, ',')) {
208 attrs_value.push_back(std::stoi(list_elem));
209 }
210 (*attr_obj)[kJValue] = attrs_value;
211 break;
212 }
213 default:
214 MS_LOG(ERROR) << "Type: " << type << "not support";
215 return false;
216 }
217 return true;
218 }
219 } // namespace
220
GenComputeJson(const AnfNodePtr & anf_node,nlohmann::json * compute_json)221 bool TbeJsonCreator::GenComputeJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json) {
222 MS_EXCEPTION_IF_NULL(anf_node);
223 MS_EXCEPTION_IF_NULL(compute_json);
224 MS_LOG(DEBUG) << "Start.";
225
226 if (!GenInputsJson(anf_node, compute_json)) {
227 MS_LOG(ERROR) << "generate inputs json failed, node full name:" << anf_node->fullname_with_scope();
228 return false;
229 }
230 if (!GenOutputsJson(anf_node, compute_json)) {
231 MS_LOG(ERROR) << "generate outputs json failed, node full name:" << anf_node->fullname_with_scope();
232 return false;
233 }
234 GenOutputDataDescJson(anf_node, compute_json);
235 GenAttrsDescJson(anf_node, compute_json);
236 GenComputeCommonJson(anf_node, compute_json);
237 GenOtherJson(anf_node, compute_json);
238 MS_LOG(DEBUG) << "End.";
239 return true;
240 }
241
GenFusionOpName(nlohmann::json * kernel_json,std::string prefix)242 void TbeJsonCreator::GenFusionOpName(nlohmann::json *kernel_json, std::string prefix) {
243 json_name_.clear();
244 json_hash_ = GenJsonHash((*kernel_json));
245 auto context_ptr = MsContext::GetInstance();
246 MS_EXCEPTION_IF_NULL(context_ptr);
247 json_name_ = std::move(prefix);
248 auto device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
249 for (auto node_json : (*kernel_json)[kJOpList]) {
250 if (GetJsonValue<std::string>(node_json, kJType) != kJData) {
251 json_name_.append(node_json[kJFuncName]);
252 json_name_.append("_");
253 }
254 }
255 json_name_ = json_name_ + std::to_string(json_hash_) + "_" + std::to_string(device_id);
256 MS_LOG(DEBUG) << "Generate Json name: " << json_name_;
257 (*kernel_json)[kJFusionOpName] = json_name_;
258 }
259
DeleteDescName(nlohmann::json * desc_jsons)260 void TbeJsonCreator::DeleteDescName(nlohmann::json *desc_jsons) {
261 for (auto &desc_json : (*desc_jsons)) {
262 if (desc_json.is_array()) {
263 for (auto &desc_item : desc_json) {
264 desc_item.erase(kJName);
265 }
266 } else {
267 desc_json.erase(kJName);
268 }
269 }
270 }
271
GenJsonHash(nlohmann::json tbe_json)272 size_t TbeJsonCreator::GenJsonHash(nlohmann::json tbe_json) {
273 auto &op_lists = tbe_json.at(kJOpList);
274 for (auto &op : op_lists) {
275 op.erase(kJName);
276 op.erase(kJOriName);
277 op.erase(kJPattern);
278 DeleteDescName(&op.at(kJOutputDesc));
279 if (op[kJType] != kJData) {
280 DeleteDescName(&op.at(kJInputDesc));
281 }
282 }
283 return std::hash<std::string>()(op_lists.dump());
284 }
285
AddOpNameForComputeNode(nlohmann::json * kernel_json)286 void TbeJsonCreator::AddOpNameForComputeNode(nlohmann::json *kernel_json) {
287 auto op_name = GetJsonValue<std::string>((*kernel_json), kJFusionOpName);
288 for (auto &node_json : (*kernel_json).at(kJOpList)) {
289 // compute node
290 if (GetJsonValue<std::string>(node_json, kJType) != kJData) {
291 node_json[kJOpName] = op_name;
292 }
293 }
294 }
295
GenAttrsJson(const AnfNodePtr & anf_node,const OpInfoPtr & op_info,nlohmann::json * attrs_json)296 void TbeJsonCreator::GenAttrsJson(const AnfNodePtr &anf_node, const OpInfoPtr &op_info, nlohmann::json *attrs_json) {
297 MS_EXCEPTION_IF_NULL(anf_node);
298 MS_EXCEPTION_IF_NULL(op_info);
299 MS_EXCEPTION_IF_NULL(attrs_json);
300 auto attrs_ptr = op_info->attrs_ptr();
301 if (!AttrsJsonPreProcessing(anf_node, &attrs_ptr, attrs_json)) {
302 MS_LOG(EXCEPTION) << "PreProcessing node attr error, node: " << anf_node->fullname_with_scope();
303 }
304
305 std::string op_name = AnfAlgo::GetCNodeName(anf_node);
306 auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
307 MS_EXCEPTION_IF_NULL(primitive);
308 for (const auto &attr_ptr : attrs_ptr) {
309 std::string attr_name = attr_ptr->name();
310 nlohmann::json attr_obj;
311 attr_obj[kJName] = attr_name;
312 if (primitive->GetAttr(attr_name) != nullptr) {
313 if (!ParseAttrValue(attr_ptr->type(), primitive->GetAttr(attr_name), &attr_obj)) {
314 MS_LOG(EXCEPTION) << "op [ " << op_info->op_name() << " ]'s attr [ " << attr_name << " ] generates failed";
315 }
316 attr_obj[kJValid] = true;
317 } else {
318 auto default_value = attr_ptr->default_value();
319 if (!default_value.empty()) {
320 if (!ParseAttrDefaultValue(attr_ptr->type(), default_value, &attr_obj)) {
321 MS_LOG(EXCEPTION) << "op [ " << op_info->op_name() << " ]'s default attr [ " << attr_name
322 << " ] generates failed";
323 }
324 attr_obj[kJValid] = true;
325 } else {
326 MS_LOG(INFO) << "op " << op_name << "'s attr \"" << attr_name << "\" should have a default value.";
327 if (!op_info->impl_path().empty() && attr_ptr->param_type() == kJParamRequred) {
328 MS_LOG(EXCEPTION) << "Op name: " << op_info->op_name() << " attr: " << attr_name
329 << " is required, but not set.";
330 } else {
331 attr_obj[kJValid] = false;
332 }
333 }
334 }
335 (*attrs_json).push_back(attr_obj);
336 }
337
338 if (!AttrsJsonPostProcessing(anf_node, op_info, attrs_json)) {
339 MS_LOG(EXCEPTION) << "PostProcessing node attr error, node: " << anf_node->fullname_with_scope();
340 }
341 }
342
GenAttrsDescJson(const AnfNodePtr & anf_node,nlohmann::json * compute_json)343 void TbeJsonCreator::GenAttrsDescJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json) {
344 MS_EXCEPTION_IF_NULL(anf_node);
345 MS_EXCEPTION_IF_NULL(compute_json);
346 auto cnode = anf_node->cast<CNodePtr>();
347 MS_EXCEPTION_IF_NULL(cnode);
348 auto op_name = AnfAlgo::GetCNodeName(cnode);
349 auto op_info_ptr = tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode);
350 nlohmann::json attrs_json;
351 GenAttrsJson(cnode, op_info_ptr, &attrs_json);
352 if (!attrs_json.empty()) {
353 (*compute_json)[kJAttrs] = attrs_json;
354 }
355
356 nlohmann::json attrs_desc;
357 for (const auto &attr : attrs_json) {
358 if (GetJsonValue<std::string>(attr, kJName) != kJIsRef && GetJsonValue<bool>(attr, kJValid)) {
359 attrs_desc.push_back(attr.at(kJValue));
360 }
361 }
362 if (!attrs_desc.empty()) {
363 (*compute_json)[kJAttrDesc] = attrs_desc;
364 }
365 }
366
GenComputeCommonJson(const AnfNodePtr & anf_node,nlohmann::json * compute_json)367 void TbeJsonCreator::GenComputeCommonJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json) {
368 MS_EXCEPTION_IF_NULL(anf_node);
369 MS_EXCEPTION_IF_NULL(compute_json);
370 auto cnode = anf_node->cast<CNodePtr>();
371 MS_EXCEPTION_IF_NULL(cnode);
372 auto op_name = AnfAlgo::GetCNodeName(cnode);
373 auto op_info_ptr = tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode);
374 auto func_name = op_info_ptr->kernel_name();
375 (*compute_json)[kJFuncName] = func_name;
376 auto python_module_path = op_info_ptr->impl_path();
377 if (python_module_path.empty()) {
378 python_module_path = kPyPath;
379 }
380
381 auto iter = tbe::opTypeAdapter.find(op_name);
382 (*compute_json)[kJType] = (iter != tbe::opTypeAdapter.end()) ? iter->second : op_name;
383 (*compute_json)[kJPyModulePath] = python_module_path;
384 (*compute_json)[kJDynamicCompileStatic] = op_info_ptr->dynamic_compile_static();
385 (*compute_json)[kJInt64Mode] = false;
386 (*compute_json)[kJName] = cnode->fullname_with_scope();
387 (*compute_json)[kJPattern] = kernel::GetFusionNameByType(AnfAlgo::GetFusionType(cnode));
388 (*compute_json)[kJModuleName] = kJModuleNamePrefix + func_name;
389 }
390
391 // node_out_idx: node output index
392 // desc_output_idx: this index use to add json
GenDescJson(const AnfNodePtr & anf_node,size_t node_out_idx,size_t desc_output_idx,nlohmann::json * output_desc)393 void TbeJsonCreator::GenDescJson(const AnfNodePtr &anf_node, size_t node_out_idx, size_t desc_output_idx,
394 nlohmann::json *output_desc) {
395 MS_EXCEPTION_IF_NULL(anf_node);
396 GenDesJsonCommon(output_desc);
397 std::vector<int64_t> shape;
398 std::vector<int64_t> ori_shape;
399 shape = TbeJsonUtils::GetOutputDeviceShapeForTbeBuild(anf_node, node_out_idx);
400 ori_shape = TbeJsonUtils::GetOutputOriShapeForTbeBuild(anf_node, node_out_idx);
401 if (shape.empty()) {
402 shape.emplace_back(1);
403 }
404 if (ori_shape.empty()) {
405 ori_shape.emplace_back(1);
406 }
407
408 auto full_name = anf_node->fullname_with_scope();
409 auto output_desc_name = node_out_idx > 0 ? (full_name + "_" + std::to_string(node_out_idx)) : full_name;
410
411 // !! Note: format: only data node's output use it
412 auto format = AnfAlgo::GetOutputFormat(anf_node, node_out_idx);
413 format = tbe::TbeAdapter::FormatPass(format, ori_shape.size());
414 auto def_format = TbeJsonUtils::IsNeedChangeDefaultFormat(anf_node) ? kOpFormat_NCDHW : kOpFormat_NCHW;
415 format =
416 (def_format == kOpFormat_NCDHW && k3DFormatSet.find(format) == k3DFormatSet.end()) ? kOpFormat_NCDHW : format;
417
418 (*output_desc)[kJDataType] = tbe::TypeIdToString(AnfAlgo::GetOutputDeviceDataType(anf_node, node_out_idx));
419 (*output_desc)[kJDtype] = GetJsonValue<std::string>(*output_desc, kJDataType);
420 (*output_desc)[kJFormat] = format;
421 (*output_desc)[kJOriFormat] = def_format;
422 (*output_desc)[kJOriShape] = ori_shape;
423 (*output_desc)[kJShape] = shape;
424 (*output_desc)[kJName] = output_desc_name;
425 // !! Note: output_index, only node's output use it
426 (*output_desc)[kJOutputIndex] = desc_output_idx;
427 }
428
GenDesJsonCommon(nlohmann::json * output_desc)429 void TbeJsonCreator::GenDesJsonCommon(nlohmann::json *output_desc) {
430 MS_EXCEPTION_IF_NULL(output_desc);
431 (*output_desc)[kJL1AddrOffset] = 0;
432 (*output_desc)[kJL1FusionType] = -1;
433 (*output_desc)[kJL1WorkspaceSize] = -1;
434 (*output_desc)[kJAddrType] = 0;
435 (*output_desc)[kJSliceOffset] = nlohmann::json::array();
436 (*output_desc)[kJSplitIndex] = 0;
437 (*output_desc)[kJTotalShape] = nlohmann::json::array();
438 (*output_desc)[kJValidShape] = nlohmann::json::array();
439 }
440
ParseConstValue(const mindspore::ValuePtr & value,nlohmann::json * json_obj)441 void ParseConstValue(const mindspore::ValuePtr &value, nlohmann::json *json_obj) {
442 if (value->isa<tensor::Tensor>()) {
443 auto tensor = value->cast<tensor::TensorPtr>();
444 MS_EXCEPTION_IF_NULL(tensor);
445 TypePtr data_type = tensor->Dtype();
446 MS_EXCEPTION_IF_NULL(data_type);
447 TypeId type_id = data_type->type_id();
448 (*json_obj)[kJConstValueDtype] = tbe::TypeIdToString(type_id);
449 switch (type_id) {
450 case kNumberTypeInt8:
451 (*json_obj)[kJConstValue] = TensorValueToVector<int8_t>(tensor);
452 break;
453
454 case kNumberTypeUInt8:
455 (*json_obj)[kJConstValue] = TensorValueToVector<uint8_t>(tensor);
456 break;
457
458 case kNumberTypeInt16:
459 (*json_obj)[kJConstValue] = TensorValueToVector<int16_t>(tensor);
460 break;
461
462 case kNumberTypeUInt16:
463 (*json_obj)[kJConstValue] = TensorValueToVector<uint16_t>(tensor);
464 break;
465
466 case kNumberTypeInt32:
467 (*json_obj)[kJConstValue] = TensorValueToVector<int32_t>(tensor);
468 break;
469
470 case kNumberTypeUInt32:
471 (*json_obj)[kJConstValue] = TensorValueToVector<uint32_t>(tensor);
472 break;
473
474 case kNumberTypeInt64:
475 (*json_obj)[kJConstValue] = TensorValueToVector<int64_t>(tensor);
476 break;
477
478 case kNumberTypeUInt64:
479 (*json_obj)[kJConstValue] = TensorValueToVector<uint64_t>(tensor);
480 break;
481
482 case kNumberTypeFloat32:
483 (*json_obj)[kJConstValue] = TensorValueToVector<float>(tensor);
484 break;
485
486 case kNumberTypeFloat64:
487 (*json_obj)[kJConstValue] = TensorValueToVector<double>(tensor);
488 break;
489
490 default:
491 MS_LOG(EXCEPTION) << "When parse const input value, the value data type: " << data_type << " is not supported.";
492 }
493 } else {
494 MS_LOG(WARNING) << "Const value input is not a tensor.";
495 }
496 }
497
GenInputConstValue(const AnfNodePtr & anf_node,size_t real_input_index,nlohmann::json * input_desc)498 void TbeJsonCreator::GenInputConstValue(const AnfNodePtr &anf_node, size_t real_input_index,
499 nlohmann::json *input_desc) {
500 MS_EXCEPTION_IF_NULL(anf_node);
501 MS_EXCEPTION_IF_NULL(input_desc);
502 auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
503 MS_EXCEPTION_IF_NULL(kernel_info);
504 auto build_info = kernel_info->select_kernel_build_info();
505 MS_EXCEPTION_IF_NULL(build_info);
506 auto value_depend = build_info->GetInputValueDepend(real_input_index);
507 if (value_depend.empty() || value_depend == kIgnored) {
508 return;
509 }
510 auto cnode = anf_node->cast<CNodePtr>();
511 MS_EXCEPTION_IF_NULL(cnode);
512 auto input_node = cnode->inputs()[real_input_index + 1];
513 if (AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimDepend)) {
514 input_node = AnfAlgo::VisitKernel(input_node, 0).first;
515 }
516 MS_EXCEPTION_IF_NULL(input_node);
517 if (input_node->isa<ValueNode>()) {
518 MS_LOG(INFO) << "Const Input value node info : " << GetValueNode(input_node)->ToString();
519 auto value_node = input_node->cast<ValueNodePtr>();
520 MS_EXCEPTION_IF_NULL(value_node);
521 auto value = value_node->value();
522 MS_EXCEPTION_IF_NULL(value);
523 ParseConstValue(value, input_desc);
524 } else {
525 MS_LOG(ERROR) << "The operator " << anf_node->fullname_with_scope() << "'s input" << real_input_index
526 << "'s value depend is " << value_depend << ", but its input node is a " << input_node->type_name()
527 << ", not a value node.";
528 }
529 }
530
AttrsJsonPreProcessing(const AnfNodePtr & anf_node,std::vector<OpAttrPtr> * attrs_ptr,nlohmann::json * attrs_json)531 bool TbeJsonCreator::AttrsJsonPreProcessing(const AnfNodePtr &anf_node, std::vector<OpAttrPtr> *attrs_ptr,
532 nlohmann::json *attrs_json) {
533 tbe::TbeAdapter::CastAttrJsonPrePass(anf_node, attrs_ptr, attrs_json);
534 return true;
535 }
GenOutputDataDescJson(const AnfNodePtr & anf_node,nlohmann::json * compute_json)536 void TbeJsonCreator::GenOutputDataDescJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json) {
537 MS_EXCEPTION_IF_NULL(anf_node);
538 MS_EXCEPTION_IF_NULL(compute_json);
539 auto op_desc = AnfAlgo::GetOutputDataDesc(anf_node);
540 // get output_data_desc from prebuild
541 if (!op_desc.empty() && op_desc.at(0).find(kJListArgs) != op_desc.at(0).end()) {
542 (*compute_json)[kJOutputDataDesc] = GetJsonValue<nlohmann::json>(op_desc.at(0), kJListArgs);
543 } else {
544 auto outputs_desc = GetJsonValue<std::vector<nlohmann::json>>(*compute_json, kJOutputDesc);
545 std::vector<nlohmann::json> outputs_data_desc;
546 for (auto output_desc : outputs_desc) {
547 if (output_desc.find(kJOriShape) != output_desc.end()) {
548 output_desc.erase(kJName);
549 outputs_data_desc.push_back(output_desc);
550 }
551 }
552 (*compute_json)[kJOutputDataDesc] = outputs_data_desc;
553 }
554 }
555
AttrsJsonPostProcessing(const AnfNodePtr & anf_node,const OpInfoPtr & op_info_ptr,nlohmann::json * attrs_json)556 bool TbeJsonCreator::AttrsJsonPostProcessing(const AnfNodePtr &anf_node, const OpInfoPtr &op_info_ptr,
557 nlohmann::json *attrs_json) {
558 return true;
559 }
560 } // namespace mindspore::kernel
561