• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "backend/kernel_compiler/tbe/tbe_json/tbe_json_creator.h"
#include <algorithm>
#include <cctype>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/tbe/tbe_adapter.h"
#include "backend/kernel_compiler/tbe/tbe_convert_utils.h"
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
#include "utils/ms_context.h"
#include "runtime/dev.h"
#include "utils/ms_utils.h"
#include "utils/json_operation_utils.h"
#include "utils/convert_utils.h"
#include "backend/kernel_compiler/tbe/tbe_json/tbe_json_utils.h"
33 
34 namespace mindspore::kernel {
35 namespace {
36 static std::unordered_map<std::string, ATTR_DTYPE> type_attr_dtype_map = {
37   {kVTypeInt, ATTR_DTYPE::ATTR_INT32},
38   {kVTypeInt64, ATTR_DTYPE::ATTR_INT64},
39   {kVTypeStr, ATTR_DTYPE::ATTR_STR},
40   {kVTypeBool, ATTR_DTYPE::ATTR_BOOL},
41   {kVTypeFloat, ATTR_DTYPE::ATTR_FLOAT32},
42   {kVTypeListInt, ATTR_DTYPE::ATTR_LIST_INT32},
43   {kVTypeListFloat, ATTR_DTYPE::ATTR_LIST_FLOAT32},
44   {kVTypeListUInt64, ATTR_DTYPE::ATTR_LIST_UINT64},
45   {kVTypeListListInt, ATTR_DTYPE::ATTR_LIST_LIST_INT64}};
46 
47 static std::map<ATTR_DTYPE, std::string> tbe_attr_dtype_to_string_map = {
48   {ATTR_INT8, "int8"},
49   {ATTR_UINT8, "uint8"},
50   {ATTR_INT16, "int16"},
51   {ATTR_UINT16, "uint16"},
52   {ATTR_INT32, "int32"},
53   {ATTR_UINT32, "uint32"},
54   {ATTR_INT64, "int64"},
55   {ATTR_UINT64, "uint64"},
56   {ATTR_FLOAT32, "float32"},
57   {ATTR_DOUBLE, "double"},
58   {ATTR_BOOL, "bool"},
59   {ATTR_STR, "str"},
60   {ATTR_LIST_INT8, "list_int8"},
61   {ATTR_LIST_UINT8, "list_uint8"},
62   {ATTR_LIST_INT16, "list_int16"},
63   {ATTR_LIST_UINT16, "list_uint16"},
64   {ATTR_LIST_INT32, "list_int32"},
65   {ATTR_LIST_UINT32, "list_uint32"},
66   {ATTR_LIST_INT64, "list_int64"},
67   {ATTR_LIST_UINT64, "list_uint64"},
68   {ATTR_LIST_FLOAT32, "list_float32"},
69   {ATTR_LIST_DOUBLE, "list_double"},
70   {ATTR_LIST_BOOL, "list_bool"},
71   {ATTR_LIST_STR, "list_str"},
72   {ATTR_LIST_LIST_INT64, "list_list_int64"},
73   {ATTR_LIST_LIST_FLOAT, "list_list_float"},
74 };
75 
ParseListIntValue(const mindspore::ValuePtr & value,std::vector<int64_t> * attr_value)76 bool ParseListIntValue(const mindspore::ValuePtr &value, std::vector<int64_t> *attr_value) {
77   auto value_type = value->type();
78   if (value_type == nullptr) {
79     MS_LOG(ERROR) << "Value's type is null.";
80     return false;
81   }
82   if (value_type->ToString() == kVTypeInt64) {
83     attr_value->push_back(GetValue<int64_t>(value));
84   } else {
85     auto vec = value->isa<ValueTuple>() ? value->cast<ValueTuplePtr>()->value() : value->cast<ValueListPtr>()->value();
86     if (!vec.empty()) {
87       if (vec[0]->isa<Int32Imm>()) {
88         std::vector<int32_t> attr_value_me = GetValue<std::vector<int32_t>>(value);
89         (void)std::transform(attr_value_me.begin(), attr_value_me.end(), std::back_inserter(*attr_value),
90                              [](const int &value) { return static_cast<int64_t>(value); });
91       } else {
92         *attr_value = GetValue<std::vector<int64_t>>(value);
93       }
94     }
95   }
96   return true;
97 }
98 
ParseAttrValue(const std::string & type,const mindspore::ValuePtr & value,nlohmann::json * attr_obj)99 bool ParseAttrValue(const std::string &type, const mindspore::ValuePtr &value, nlohmann::json *attr_obj) {
100   MS_EXCEPTION_IF_NULL(attr_obj);
101   if (value == nullptr) {
102     MS_LOG(ERROR) << "Node's attr value is null.";
103     return false;
104   }
105   auto result = type_attr_dtype_map.find(type);
106   if (result == type_attr_dtype_map.end()) {
107     MS_LOG(ERROR) << "Type: " << type << "not support";
108     return false;
109   }
110 
111   auto dtype_string = tbe_attr_dtype_to_string_map.find(result->second);
112   if (dtype_string == tbe_attr_dtype_to_string_map.end()) {
113     MS_LOG(ERROR) << "Can't convert attr dtype " << result->second << " to string";
114     return false;
115   }
116   (*attr_obj)[kJDtype] = dtype_string->second;
117 
118   switch (result->second) {
119     case ATTR_DTYPE::ATTR_INT32:
120       (*attr_obj)[kJValue] = value->isa<Int32Imm>() ? GetValue<int>(value) : GetValue<int64_t>(value);
121       break;
122     case ATTR_DTYPE::ATTR_INT64:
123       (*attr_obj)[kJValue] = GetValue<int64_t>(value);
124       break;
125     case ATTR_DTYPE::ATTR_STR: {
126       auto attr_value = GetValue<std::string>(value);
127       (*attr_obj)[kJValue] = attr_value == kOpFormat_FRAC_Z ? kJOpFormat_FRACTAL_Z : attr_value;
128       break;
129     }
130     case ATTR_DTYPE::ATTR_BOOL:
131       (*attr_obj)[kJValue] = GetValue<bool>(value);
132       break;
133     case ATTR_DTYPE::ATTR_FLOAT32:
134       (*attr_obj)[kJValue] = GetValue<float>(value);
135       break;
136     case ATTR_DTYPE::ATTR_LIST_INT32: {
137       std::vector<int64_t> attr_value;
138       if (!ParseListIntValue(value, &attr_value)) {
139         MS_LOG(ERROR) << "Parse list_value failed, maybe the input is a nullptr.";
140         return false;
141       }
142       (*attr_obj)[kJValue] = attr_value;
143       break;
144     }
145     case ATTR_DTYPE::ATTR_LIST_FLOAT32: {
146       auto value_type = value->type();
147       if (value_type == nullptr) {
148         MS_LOG(ERROR) << "Value's type is null.";
149         return false;
150       }
151       (*attr_obj)[kJValue] = value_type->ToString() == kVTypeFloat ? std::vector<float>{GetValue<float>(value)}
152                                                                    : GetValue<std::vector<float>>(value);
153       break;
154     }
155     case ATTR_DTYPE::ATTR_LIST_UINT64:
156       (*attr_obj)[kJValue] = GetValue<std::vector<size_t>>(value);
157       break;
158     case ATTR_DTYPE::ATTR_LIST_LIST_INT64:
159       (*attr_obj)[kJValue] = GetValue<std::vector<std::vector<int64_t>>>(value);
160       break;
161 
162     default:
163       MS_LOG(ERROR) << "Type: " << type << "not support";
164       return false;
165   }
166   return true;
167 }
168 
ParseAttrDefaultValue(const std::string & type,const std::string & value,nlohmann::json * attr_obj)169 bool ParseAttrDefaultValue(const std::string &type, const std::string &value, nlohmann::json *attr_obj) {
170   MS_EXCEPTION_IF_NULL(attr_obj);
171   auto result = type_attr_dtype_map.find(type);
172   if (result == type_attr_dtype_map.end()) {
173     MS_LOG(ERROR) << "Type: " << type << "not support";
174     return false;
175   }
176 
177   auto dtype_string = tbe_attr_dtype_to_string_map.find(result->second);
178   if (dtype_string == tbe_attr_dtype_to_string_map.end()) {
179     MS_LOG(ERROR) << "Can't convert attr dtype " << result->second << " to string";
180     return false;
181   }
182   (*attr_obj)[kJDtype] = dtype_string->second;
183 
184   switch (result->second) {
185     case ATTR_DTYPE::ATTR_INT32:
186       (*attr_obj)[kJValue] = std::stoi(value);
187       break;
188     case ATTR_DTYPE::ATTR_INT64:
189       (*attr_obj)[kJValue] = std::stoll(value);
190       break;
191     case ATTR_DTYPE::ATTR_STR:
192       (*attr_obj)[kJValue] = value;
193       break;
194     case ATTR_DTYPE::ATTR_BOOL: {
195       bool attr_value = false;
196       std::istringstream(value) >> std::boolalpha >> attr_value;
197       (*attr_obj)[kJValue] = attr_value;
198       break;
199     }
200     case ATTR_DTYPE::ATTR_FLOAT32:
201       (*attr_obj)[kJValue] = std::stof(value);
202       break;
203     case ATTR_DTYPE::ATTR_LIST_INT32: {
204       std::stringstream string_value(value);
205       std::string list_elem;
206       std::vector<int64_t> attrs_value;
207       while (std::getline(string_value, list_elem, ',')) {
208         attrs_value.push_back(std::stoi(list_elem));
209       }
210       (*attr_obj)[kJValue] = attrs_value;
211       break;
212     }
213     default:
214       MS_LOG(ERROR) << "Type: " << type << "not support";
215       return false;
216   }
217   return true;
218 }
219 }  // namespace
220 
GenComputeJson(const AnfNodePtr & anf_node,nlohmann::json * compute_json)221 bool TbeJsonCreator::GenComputeJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json) {
222   MS_EXCEPTION_IF_NULL(anf_node);
223   MS_EXCEPTION_IF_NULL(compute_json);
224   MS_LOG(DEBUG) << "Start.";
225 
226   if (!GenInputsJson(anf_node, compute_json)) {
227     MS_LOG(ERROR) << "generate inputs json failed, node full name:" << anf_node->fullname_with_scope();
228     return false;
229   }
230   if (!GenOutputsJson(anf_node, compute_json)) {
231     MS_LOG(ERROR) << "generate outputs json failed, node full name:" << anf_node->fullname_with_scope();
232     return false;
233   }
234   GenOutputDataDescJson(anf_node, compute_json);
235   GenAttrsDescJson(anf_node, compute_json);
236   GenComputeCommonJson(anf_node, compute_json);
237   GenOtherJson(anf_node, compute_json);
238   MS_LOG(DEBUG) << "End.";
239   return true;
240 }
241 
GenFusionOpName(nlohmann::json * kernel_json,std::string prefix)242 void TbeJsonCreator::GenFusionOpName(nlohmann::json *kernel_json, std::string prefix) {
243   json_name_.clear();
244   json_hash_ = GenJsonHash((*kernel_json));
245   auto context_ptr = MsContext::GetInstance();
246   MS_EXCEPTION_IF_NULL(context_ptr);
247   json_name_ = std::move(prefix);
248   auto device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
249   for (auto node_json : (*kernel_json)[kJOpList]) {
250     if (GetJsonValue<std::string>(node_json, kJType) != kJData) {
251       json_name_.append(node_json[kJFuncName]);
252       json_name_.append("_");
253     }
254   }
255   json_name_ = json_name_ + std::to_string(json_hash_) + "_" + std::to_string(device_id);
256   MS_LOG(DEBUG) << "Generate Json name: " << json_name_;
257   (*kernel_json)[kJFusionOpName] = json_name_;
258 }
259 
DeleteDescName(nlohmann::json * desc_jsons)260 void TbeJsonCreator::DeleteDescName(nlohmann::json *desc_jsons) {
261   for (auto &desc_json : (*desc_jsons)) {
262     if (desc_json.is_array()) {
263       for (auto &desc_item : desc_json) {
264         desc_item.erase(kJName);
265       }
266     } else {
267       desc_json.erase(kJName);
268     }
269   }
270 }
271 
GenJsonHash(nlohmann::json tbe_json)272 size_t TbeJsonCreator::GenJsonHash(nlohmann::json tbe_json) {
273   auto &op_lists = tbe_json.at(kJOpList);
274   for (auto &op : op_lists) {
275     op.erase(kJName);
276     op.erase(kJOriName);
277     op.erase(kJPattern);
278     DeleteDescName(&op.at(kJOutputDesc));
279     if (op[kJType] != kJData) {
280       DeleteDescName(&op.at(kJInputDesc));
281     }
282   }
283   return std::hash<std::string>()(op_lists.dump());
284 }
285 
AddOpNameForComputeNode(nlohmann::json * kernel_json)286 void TbeJsonCreator::AddOpNameForComputeNode(nlohmann::json *kernel_json) {
287   auto op_name = GetJsonValue<std::string>((*kernel_json), kJFusionOpName);
288   for (auto &node_json : (*kernel_json).at(kJOpList)) {
289     // compute node
290     if (GetJsonValue<std::string>(node_json, kJType) != kJData) {
291       node_json[kJOpName] = op_name;
292     }
293   }
294 }
295 
GenAttrsJson(const AnfNodePtr & anf_node,const OpInfoPtr & op_info,nlohmann::json * attrs_json)296 void TbeJsonCreator::GenAttrsJson(const AnfNodePtr &anf_node, const OpInfoPtr &op_info, nlohmann::json *attrs_json) {
297   MS_EXCEPTION_IF_NULL(anf_node);
298   MS_EXCEPTION_IF_NULL(op_info);
299   MS_EXCEPTION_IF_NULL(attrs_json);
300   auto attrs_ptr = op_info->attrs_ptr();
301   if (!AttrsJsonPreProcessing(anf_node, &attrs_ptr, attrs_json)) {
302     MS_LOG(EXCEPTION) << "PreProcessing node attr error, node: " << anf_node->fullname_with_scope();
303   }
304 
305   std::string op_name = AnfAlgo::GetCNodeName(anf_node);
306   auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
307   MS_EXCEPTION_IF_NULL(primitive);
308   for (const auto &attr_ptr : attrs_ptr) {
309     std::string attr_name = attr_ptr->name();
310     nlohmann::json attr_obj;
311     attr_obj[kJName] = attr_name;
312     if (primitive->GetAttr(attr_name) != nullptr) {
313       if (!ParseAttrValue(attr_ptr->type(), primitive->GetAttr(attr_name), &attr_obj)) {
314         MS_LOG(EXCEPTION) << "op [ " << op_info->op_name() << " ]'s attr [ " << attr_name << " ] generates failed";
315       }
316       attr_obj[kJValid] = true;
317     } else {
318       auto default_value = attr_ptr->default_value();
319       if (!default_value.empty()) {
320         if (!ParseAttrDefaultValue(attr_ptr->type(), default_value, &attr_obj)) {
321           MS_LOG(EXCEPTION) << "op [ " << op_info->op_name() << " ]'s default attr [ " << attr_name
322                             << " ] generates failed";
323         }
324         attr_obj[kJValid] = true;
325       } else {
326         MS_LOG(INFO) << "op " << op_name << "'s attr \"" << attr_name << "\" should have a default value.";
327         if (!op_info->impl_path().empty() && attr_ptr->param_type() == kJParamRequred) {
328           MS_LOG(EXCEPTION) << "Op name: " << op_info->op_name() << " attr: " << attr_name
329                             << " is required, but not set.";
330         } else {
331           attr_obj[kJValid] = false;
332         }
333       }
334     }
335     (*attrs_json).push_back(attr_obj);
336   }
337 
338   if (!AttrsJsonPostProcessing(anf_node, op_info, attrs_json)) {
339     MS_LOG(EXCEPTION) << "PostProcessing node attr error, node: " << anf_node->fullname_with_scope();
340   }
341 }
342 
GenAttrsDescJson(const AnfNodePtr & anf_node,nlohmann::json * compute_json)343 void TbeJsonCreator::GenAttrsDescJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json) {
344   MS_EXCEPTION_IF_NULL(anf_node);
345   MS_EXCEPTION_IF_NULL(compute_json);
346   auto cnode = anf_node->cast<CNodePtr>();
347   MS_EXCEPTION_IF_NULL(cnode);
348   auto op_name = AnfAlgo::GetCNodeName(cnode);
349   auto op_info_ptr = tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode);
350   nlohmann::json attrs_json;
351   GenAttrsJson(cnode, op_info_ptr, &attrs_json);
352   if (!attrs_json.empty()) {
353     (*compute_json)[kJAttrs] = attrs_json;
354   }
355 
356   nlohmann::json attrs_desc;
357   for (const auto &attr : attrs_json) {
358     if (GetJsonValue<std::string>(attr, kJName) != kJIsRef && GetJsonValue<bool>(attr, kJValid)) {
359       attrs_desc.push_back(attr.at(kJValue));
360     }
361   }
362   if (!attrs_desc.empty()) {
363     (*compute_json)[kJAttrDesc] = attrs_desc;
364   }
365 }
366 
GenComputeCommonJson(const AnfNodePtr & anf_node,nlohmann::json * compute_json)367 void TbeJsonCreator::GenComputeCommonJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json) {
368   MS_EXCEPTION_IF_NULL(anf_node);
369   MS_EXCEPTION_IF_NULL(compute_json);
370   auto cnode = anf_node->cast<CNodePtr>();
371   MS_EXCEPTION_IF_NULL(cnode);
372   auto op_name = AnfAlgo::GetCNodeName(cnode);
373   auto op_info_ptr = tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode);
374   auto func_name = op_info_ptr->kernel_name();
375   (*compute_json)[kJFuncName] = func_name;
376   auto python_module_path = op_info_ptr->impl_path();
377   if (python_module_path.empty()) {
378     python_module_path = kPyPath;
379   }
380 
381   auto iter = tbe::opTypeAdapter.find(op_name);
382   (*compute_json)[kJType] = (iter != tbe::opTypeAdapter.end()) ? iter->second : op_name;
383   (*compute_json)[kJPyModulePath] = python_module_path;
384   (*compute_json)[kJDynamicCompileStatic] = op_info_ptr->dynamic_compile_static();
385   (*compute_json)[kJInt64Mode] = false;
386   (*compute_json)[kJName] = cnode->fullname_with_scope();
387   (*compute_json)[kJPattern] = kernel::GetFusionNameByType(AnfAlgo::GetFusionType(cnode));
388   (*compute_json)[kJModuleName] = kJModuleNamePrefix + func_name;
389 }
390 
391 // node_out_idx: node output index
392 // desc_output_idx: this index use to add json
GenDescJson(const AnfNodePtr & anf_node,size_t node_out_idx,size_t desc_output_idx,nlohmann::json * output_desc)393 void TbeJsonCreator::GenDescJson(const AnfNodePtr &anf_node, size_t node_out_idx, size_t desc_output_idx,
394                                  nlohmann::json *output_desc) {
395   MS_EXCEPTION_IF_NULL(anf_node);
396   GenDesJsonCommon(output_desc);
397   std::vector<int64_t> shape;
398   std::vector<int64_t> ori_shape;
399   shape = TbeJsonUtils::GetOutputDeviceShapeForTbeBuild(anf_node, node_out_idx);
400   ori_shape = TbeJsonUtils::GetOutputOriShapeForTbeBuild(anf_node, node_out_idx);
401   if (shape.empty()) {
402     shape.emplace_back(1);
403   }
404   if (ori_shape.empty()) {
405     ori_shape.emplace_back(1);
406   }
407 
408   auto full_name = anf_node->fullname_with_scope();
409   auto output_desc_name = node_out_idx > 0 ? (full_name + "_" + std::to_string(node_out_idx)) : full_name;
410 
411   // !! Note: format: only data node's output use it
412   auto format = AnfAlgo::GetOutputFormat(anf_node, node_out_idx);
413   format = tbe::TbeAdapter::FormatPass(format, ori_shape.size());
414   auto def_format = TbeJsonUtils::IsNeedChangeDefaultFormat(anf_node) ? kOpFormat_NCDHW : kOpFormat_NCHW;
415   format =
416     (def_format == kOpFormat_NCDHW && k3DFormatSet.find(format) == k3DFormatSet.end()) ? kOpFormat_NCDHW : format;
417 
418   (*output_desc)[kJDataType] = tbe::TypeIdToString(AnfAlgo::GetOutputDeviceDataType(anf_node, node_out_idx));
419   (*output_desc)[kJDtype] = GetJsonValue<std::string>(*output_desc, kJDataType);
420   (*output_desc)[kJFormat] = format;
421   (*output_desc)[kJOriFormat] = def_format;
422   (*output_desc)[kJOriShape] = ori_shape;
423   (*output_desc)[kJShape] = shape;
424   (*output_desc)[kJName] = output_desc_name;
425   // !! Note: output_index, only node's output use it
426   (*output_desc)[kJOutputIndex] = desc_output_idx;
427 }
428 
GenDesJsonCommon(nlohmann::json * output_desc)429 void TbeJsonCreator::GenDesJsonCommon(nlohmann::json *output_desc) {
430   MS_EXCEPTION_IF_NULL(output_desc);
431   (*output_desc)[kJL1AddrOffset] = 0;
432   (*output_desc)[kJL1FusionType] = -1;
433   (*output_desc)[kJL1WorkspaceSize] = -1;
434   (*output_desc)[kJAddrType] = 0;
435   (*output_desc)[kJSliceOffset] = nlohmann::json::array();
436   (*output_desc)[kJSplitIndex] = 0;
437   (*output_desc)[kJTotalShape] = nlohmann::json::array();
438   (*output_desc)[kJValidShape] = nlohmann::json::array();
439 }
440 
ParseConstValue(const mindspore::ValuePtr & value,nlohmann::json * json_obj)441 void ParseConstValue(const mindspore::ValuePtr &value, nlohmann::json *json_obj) {
442   if (value->isa<tensor::Tensor>()) {
443     auto tensor = value->cast<tensor::TensorPtr>();
444     MS_EXCEPTION_IF_NULL(tensor);
445     TypePtr data_type = tensor->Dtype();
446     MS_EXCEPTION_IF_NULL(data_type);
447     TypeId type_id = data_type->type_id();
448     (*json_obj)[kJConstValueDtype] = tbe::TypeIdToString(type_id);
449     switch (type_id) {
450       case kNumberTypeInt8:
451         (*json_obj)[kJConstValue] = TensorValueToVector<int8_t>(tensor);
452         break;
453 
454       case kNumberTypeUInt8:
455         (*json_obj)[kJConstValue] = TensorValueToVector<uint8_t>(tensor);
456         break;
457 
458       case kNumberTypeInt16:
459         (*json_obj)[kJConstValue] = TensorValueToVector<int16_t>(tensor);
460         break;
461 
462       case kNumberTypeUInt16:
463         (*json_obj)[kJConstValue] = TensorValueToVector<uint16_t>(tensor);
464         break;
465 
466       case kNumberTypeInt32:
467         (*json_obj)[kJConstValue] = TensorValueToVector<int32_t>(tensor);
468         break;
469 
470       case kNumberTypeUInt32:
471         (*json_obj)[kJConstValue] = TensorValueToVector<uint32_t>(tensor);
472         break;
473 
474       case kNumberTypeInt64:
475         (*json_obj)[kJConstValue] = TensorValueToVector<int64_t>(tensor);
476         break;
477 
478       case kNumberTypeUInt64:
479         (*json_obj)[kJConstValue] = TensorValueToVector<uint64_t>(tensor);
480         break;
481 
482       case kNumberTypeFloat32:
483         (*json_obj)[kJConstValue] = TensorValueToVector<float>(tensor);
484         break;
485 
486       case kNumberTypeFloat64:
487         (*json_obj)[kJConstValue] = TensorValueToVector<double>(tensor);
488         break;
489 
490       default:
491         MS_LOG(EXCEPTION) << "When parse const input value, the value data type: " << data_type << " is not supported.";
492     }
493   } else {
494     MS_LOG(WARNING) << "Const value input is not a tensor.";
495   }
496 }
497 
GenInputConstValue(const AnfNodePtr & anf_node,size_t real_input_index,nlohmann::json * input_desc)498 void TbeJsonCreator::GenInputConstValue(const AnfNodePtr &anf_node, size_t real_input_index,
499                                         nlohmann::json *input_desc) {
500   MS_EXCEPTION_IF_NULL(anf_node);
501   MS_EXCEPTION_IF_NULL(input_desc);
502   auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
503   MS_EXCEPTION_IF_NULL(kernel_info);
504   auto build_info = kernel_info->select_kernel_build_info();
505   MS_EXCEPTION_IF_NULL(build_info);
506   auto value_depend = build_info->GetInputValueDepend(real_input_index);
507   if (value_depend.empty() || value_depend == kIgnored) {
508     return;
509   }
510   auto cnode = anf_node->cast<CNodePtr>();
511   MS_EXCEPTION_IF_NULL(cnode);
512   auto input_node = cnode->inputs()[real_input_index + 1];
513   if (AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimDepend)) {
514     input_node = AnfAlgo::VisitKernel(input_node, 0).first;
515   }
516   MS_EXCEPTION_IF_NULL(input_node);
517   if (input_node->isa<ValueNode>()) {
518     MS_LOG(INFO) << "Const Input value node info : " << GetValueNode(input_node)->ToString();
519     auto value_node = input_node->cast<ValueNodePtr>();
520     MS_EXCEPTION_IF_NULL(value_node);
521     auto value = value_node->value();
522     MS_EXCEPTION_IF_NULL(value);
523     ParseConstValue(value, input_desc);
524   } else {
525     MS_LOG(ERROR) << "The operator " << anf_node->fullname_with_scope() << "'s input" << real_input_index
526                   << "'s value depend is " << value_depend << ", but its input node is a " << input_node->type_name()
527                   << ", not a value node.";
528   }
529 }
530 
AttrsJsonPreProcessing(const AnfNodePtr & anf_node,std::vector<OpAttrPtr> * attrs_ptr,nlohmann::json * attrs_json)531 bool TbeJsonCreator::AttrsJsonPreProcessing(const AnfNodePtr &anf_node, std::vector<OpAttrPtr> *attrs_ptr,
532                                             nlohmann::json *attrs_json) {
533   tbe::TbeAdapter::CastAttrJsonPrePass(anf_node, attrs_ptr, attrs_json);
534   return true;
535 }
GenOutputDataDescJson(const AnfNodePtr & anf_node,nlohmann::json * compute_json)536 void TbeJsonCreator::GenOutputDataDescJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json) {
537   MS_EXCEPTION_IF_NULL(anf_node);
538   MS_EXCEPTION_IF_NULL(compute_json);
539   auto op_desc = AnfAlgo::GetOutputDataDesc(anf_node);
540   // get output_data_desc from prebuild
541   if (!op_desc.empty() && op_desc.at(0).find(kJListArgs) != op_desc.at(0).end()) {
542     (*compute_json)[kJOutputDataDesc] = GetJsonValue<nlohmann::json>(op_desc.at(0), kJListArgs);
543   } else {
544     auto outputs_desc = GetJsonValue<std::vector<nlohmann::json>>(*compute_json, kJOutputDesc);
545     std::vector<nlohmann::json> outputs_data_desc;
546     for (auto output_desc : outputs_desc) {
547       if (output_desc.find(kJOriShape) != output_desc.end()) {
548         output_desc.erase(kJName);
549         outputs_data_desc.push_back(output_desc);
550       }
551     }
552     (*compute_json)[kJOutputDataDesc] = outputs_data_desc;
553   }
554 }
555 
AttrsJsonPostProcessing(const AnfNodePtr & anf_node,const OpInfoPtr & op_info_ptr,nlohmann::json * attrs_json)556 bool TbeJsonCreator::AttrsJsonPostProcessing(const AnfNodePtr &anf_node, const OpInfoPtr &op_info_ptr,
557                                              nlohmann::json *attrs_json) {
558   return true;
559 }
560 }  // namespace mindspore::kernel
561