• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "kernel/graph_kernel/graph_kernel_json_generator.h"
18 
19 #include <set>
20 #include <functional>
21 #include <algorithm>
22 #include "abstract/dshape.h"
23 #include "mindspore/core/ops/sequence_ops.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "mindspore/core/ops/array_ops.h"
26 #include "ir/func_graph.h"
27 #include "utils/anf_utils.h"
28 #include "utils/ms_context.h"
29 #include "backend/common/graph_kernel/core/graph_builder.h"
30 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
31 #include "backend/common/graph_kernel/graph_kernel_flags.h"
32 #include "kernel/graph_kernel/graph_kernel_json_flags.h"
33 #ifdef MSLITE_ENABLE_GRAPH_KERNEL
34 #ifdef ENABLE_GPU
35 #include <cuda.h>
36 #endif
37 #else
38 #include "kernel/oplib/oplib.h"
39 #include "runtime/hardware/device_context_manager.h"
40 #endif
41 
42 namespace mindspore::graphkernel {
43 using kernel::OpAttr;
44 using kernel::OpImplyType;
45 using kernel::OpInfo;
46 using kernel::OpIOInfo;
47 namespace {
48 constexpr int kCurrentInfoVersion = 2;
49 constexpr auto kAttrParallelDimInfoSize = 2;
50 constexpr auto kDebugStrDepth = 2;
51 
GetDynInputSizes(const AnfNodePtr & anf_node)52 std::vector<int64_t> GetDynInputSizes(const AnfNodePtr &anf_node) {
53   std::vector<int64_t> dyn_input_sizes;
54   auto primitive = GetCNodePrimitive(anf_node);
55   MS_EXCEPTION_IF_NULL(primitive);
56   if (primitive->HasAttr(kAttrDynInputSizes)) {
57     dyn_input_sizes = GetValue<const std::vector<int64_t>>(primitive->GetAttr(kAttrDynInputSizes));
58   }
59   return dyn_input_sizes;
60 }
61 
GetKernelInput(const AnfNodePtr & anf_node,size_t index)62 std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index) {
63   MS_EXCEPTION_IF_NULL(anf_node);
64   auto inputs_num = AnfUtils::GetInputTensorNum(anf_node);
65   if (index >= inputs_num) {
66     MS_EXCEPTION(ArgumentError) << "Input index " << index << " is out of range [0, " << inputs_num << ") in node ["
67                                 << anf_node->DebugString() << "]";
68   }
69   auto cnode = anf_node->cast<CNodePtr>();
70   if (cnode == nullptr) {
71     return AnfUtils::VisitKernel(anf_node, 0);
72   } else {
73     return AnfUtils::VisitKernel(cnode->input(index + 1), 0);
74   }
75 }
76 
// For each graph input, locates the kernel in `node_list` that consumes it and
// the input slot where it is used. Each result is
// (consumer node, (input index, offset)): for ordinary inputs the pair is
// (flat argument index, 0); for nodes with dynamic inputs the first element is
// the dynamic-input group index and the second the offset inside that group.
// Throws ArgumentError when an input has no users or no consumer is found.
std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list,
                                                                            const std::vector<AnfNodePtr> &input_list) {
  std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> input_index;
  for (size_t i = 0; i < input_list.size(); ++i) {
    auto const &input = input_list[i];
    MS_EXCEPTION_IF_NULL(input);
    MS_EXCEPTION_IF_NULL(input->func_graph());
    auto mng = input->func_graph()->manager();
    MS_EXCEPTION_IF_NULL(mng);
    const NodeUsersMap &users = mng->node_users();
    auto input_users = users.find(input);
    if (input_users == users.end() || input_users->second.empty()) {
      MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(kDebugStrDepth) << "] of ["
                                  << input->func_graph()->ToString() << "] has no users.";
    }
    bool found = false;
    for (auto const &input_user : input_users->second) {
      for (auto const &anf_node : node_list) {
        if (anf_node != input_user.first) {
          continue;
        }
        auto dyn_input_sizes = GetDynInputSizes(anf_node);
        if (dyn_input_sizes.empty()) {
          // input_user.second is 1-based (slot 0 is the primitive), hence -1.
          input_index.push_back(std::make_pair(anf_node, std::make_pair(IntToSize(input_user.second - 1), 0)));
          found = true;
          break;
        }
        // With dynamic inputs, translate the flat argument index into
        // (group index, offset within group) using the per-group sizes.
        int64_t used_as_idx = IntToLong(input_user.second - 1);
        int64_t accum_idx = 0;
        for (size_t dyn_i = 0; dyn_i < dyn_input_sizes.size(); ++dyn_i) {
          accum_idx += dyn_input_sizes[dyn_i];
          if (used_as_idx < accum_idx) {
            auto tmp_dyn_i = dyn_i;  // to evade pclint warning "for statement index variable modified in body."
            input_index.push_back(std::make_pair(
              anf_node, std::make_pair(tmp_dyn_i, LongToSize(used_as_idx - (accum_idx - dyn_input_sizes[dyn_i])))));
            found = true;
            break;
          }
        }
        if (found) {
          break;
        }
      }
      if (found) {
        break;
      }
    }
    if (found) {
      continue;
    }
    MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(kDebugStrDepth) << "] of ["
                                << input->func_graph()->ToString() << "] found no related kernel info.";
  }
  return input_index;
}
132 
GetOutputIndex(const std::vector<AnfNodePtr> & node_list,const std::vector<AnfNodePtr> & input_list,const std::vector<AnfNodePtr> & output_list)133 std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
134                                                           const std::vector<AnfNodePtr> &input_list,
135                                                           const std::vector<AnfNodePtr> &output_list) {
136   std::vector<std::pair<AnfNodePtr, size_t>> output_index;
137   for (size_t i = 0; i < output_list.size(); ++i) {
138     bool found = false;
139     auto const &output = output_list[i];
140     MS_EXCEPTION_IF_NULL(output);
141     auto pree_node = AnfUtils::VisitKernel(output, 0);
142     auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first);
143     if (pos != std::end(node_list)) {
144       (void)output_index.emplace_back(pree_node);
145       continue;
146     }
147     auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first);
148     if (ret != std::end(input_list)) {
149       (void)output_index.emplace_back(std::make_pair(pree_node.first, 0));
150       found = true;
151     }
152     if (!found) {
153       MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(kDebugStrDepth) << "] of ["
154                                   << output->func_graph()->ToString() << "] found no related kernel info.";
155     }
156   }
157   return output_index;
158 }
159 
// Builds an OpInfo (inputs/outputs/attrs description) for a node by inspecting
// the CNode itself, used when no registered op info is available.
class OpInfoExtractor {
 public:
  OpInfoExtractor() = default;
  ~OpInfoExtractor() = default;
  // Entry point: extracts an OpInfo from `anf_node` (must be a CNode).
  OpInfoPtr Run(const AnfNodePtr &anf_node) {
    MS_EXCEPTION_IF_NULL(anf_node);
    cnode_ = anf_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode_);
    auto op_info = std::make_shared<OpInfo>();
    op_info->set_op_name(AnfUtils::GetCNodeName(cnode_));
    const auto &flags = GraphKernelFlags::GetInstance();
    // The AKG_V2 generator uses the dynamic AKG imply type.
    if (flags.kernel_generator == "AKG_V2") {
      op_info->set_imply_type(OpImplyType::kImplyDynamicAKG);
    } else {
      op_info->set_imply_type(OpImplyType::kImplyAKG);
    }
    ExtractInputs(op_info);
    ExtractOutputs(op_info);
    ExtractAttrs(op_info);
    return op_info;
  }

 private:
  // Adds one IO entry per real input; with dynamic inputs, one entry per
  // dynamic-input group, marked with param_type "dynamic".
  void ExtractInputs(const OpInfoPtr &op_info) const {
    auto dyn_input_sizes = GetDynInputSizes(cnode_);
    if (dyn_input_sizes.empty()) {
      // cnode_->input(0) is the primitive, so real inputs start at slot 1.
      for (size_t i = 1; i < cnode_->size(); i++) {
        auto io_info = std::make_shared<OpIOInfo>();
        io_info->set_name("input_" + std::to_string(i - 1));
        op_info->add_inputs_ptr(io_info);
      }
    } else {
      for (size_t i = 0; i < dyn_input_sizes.size(); i++) {
        auto io_info = std::make_shared<OpIOInfo>();
        io_info->set_name("input_" + std::to_string(i));
        io_info->set_param_type("dynamic");
        op_info->add_inputs_ptr(io_info);
      }
    }
  }

  // Adds one IO entry per output tensor of the node.
  void ExtractOutputs(const OpInfoPtr &op_info) const {
    size_t output_tensor_num = AnfUtils::GetOutputTensorNum(cnode_);
    for (size_t i = 0; i < output_tensor_num; i++) {
      auto io_info = std::make_shared<OpIOInfo>();
      io_info->set_name("output_" + std::to_string(i));
      op_info->add_outputs_ptr(io_info);
    }
  }

  // Bookkeeping-only attributes that must not appear in the op json.
  bool ExcludeAttr(const std::string &name) const {
    const std::set<std::string> black_list = {"IsFeatureMapInputList", "IsFeatureMapOutput", kAttrOutputNames,
                                              kAttrInputNames, "is_load"};
    return black_list.count(name) != 0;
  }

  // Converts each supported primitive attribute into an OpAttr; attributes
  // with unsupported value types are skipped with a debug log.
  void ExtractAttrs(const OpInfoPtr &op_info) {
    auto prim = GetCNodePrimitive(cnode_);
    if (prim == nullptr) {
      return;
    }
    for (const auto &[name, v] : prim->attrs()) {
      if (ExcludeAttr(name)) {
        continue;
      }
      auto op_attr = std::make_shared<OpAttr>();
      op_attr->set_name(name);
      op_attr->set_param_type("required");
      // Only support the following types in op json.
      if (v->isa<Int32Imm>() || v->isa<Int64Imm>()) {
        op_attr->set_type("int");
      } else if (v->isa<FP32Imm>() || v->isa<FP64Imm>()) {
        op_attr->set_type("float");
      } else if (v->isa<BoolImm>()) {
        op_attr->set_type("bool");
      } else if (v->isa<StringImm>()) {
        op_attr->set_type("str");
      } else if (v->isa<Type>()) {
        // convert the TypeId to string
        op_attr->set_type("str");
      } else if (v->isa<ValueSequence>()) {
        // An empty sequence defaults to listInt. A sequence whose first element
        // is neither int nor string leaves the type empty and is dropped below.
        const auto &vec = v->cast<ValueSequencePtr>()->value();
        if (vec.empty()) {
          op_attr->set_type("listInt");
        } else if (vec[0]->isa<Int32Imm>() || vec[0]->isa<Int64Imm>()) {
          op_attr->set_type("listInt");
        } else if (vec[0]->isa<StringImm>()) {
          op_attr->set_type("listStr");
        }
      }
      if (op_attr->type().empty()) {
        MS_LOG(DEBUG) << "Unknown type, ignore attr: " << name;
        continue;
      }
      op_info->add_attrs_ptr(op_attr);
    }
  }

  // The CNode currently being extracted; set by Run().
  CNodePtr cnode_;
};
260 }  // namespace
261 
SetValueList(nlohmann::json * node_json,void * data,size_t data_size,TypeId type_id,const AnfNodePtr & cnode)262 void SetValueList(nlohmann::json *node_json, void *data, size_t data_size, TypeId type_id, const AnfNodePtr &cnode) {
263   if (type_id == kNumberTypeInt32) {
264     std::vector<int32_t> vec;
265     auto data_ptr = static_cast<int32_t *>(data);
266     for (size_t i = 0; i < data_size; i++) {
267       vec.push_back(data_ptr[i]);
268     }
269     (*node_json)["value"] = vec;
270   } else if (type_id == kNumberTypeInt64) {
271     std::vector<int64_t> vec;
272     auto data_ptr = static_cast<int64_t *>(data);
273     for (size_t i = 0; i < data_size; i++) {
274       vec.push_back(data_ptr[i]);
275     }
276     (*node_json)["value"] = vec;
277   } else {
278     MS_LOG(EXCEPTION) << "The input data of node " << cnode->DebugString() << " should be an int tensor.";
279   }
280 }
281 
// Writes the first element of the raw tensor buffer `data` into
// (*node_json)["value"], dispatching on `type_id` to reinterpret the buffer
// correctly. `input_idx` is only used in the error message. Raises when the
// type is not one of the supported scalar types listed in the message.
void SetSingleValue(nlohmann::json *node_json, void *data, TypeId type_id, const AnfNodePtr &cnode, size_t input_idx) {
  if (type_id == kFloat64->type_id()) {
    (*node_json)["value"] = static_cast<double *>(data)[0];
  } else if (type_id == kFloat32->type_id()) {
    (*node_json)["value"] = static_cast<float *>(data)[0];
  } else if (type_id == kFloat16->type_id()) {
    // float16 has no native json mapping; widen to float first.
    float16 *val = static_cast<float16 *>(data);
    (*node_json)["value"] = static_cast<float>(val[0]);
  } else if (type_id == kUInt64->type_id()) {
    (*node_json)["value"] = static_cast<uint64_t *>(data)[0];
  } else if (type_id == kUInt32->type_id()) {
    (*node_json)["value"] = static_cast<uint32_t *>(data)[0];
  } else if (type_id == kUInt16->type_id()) {
    (*node_json)["value"] = static_cast<uint16_t *>(data)[0];
  } else if (type_id == kUInt8->type_id()) {
    (*node_json)["value"] = static_cast<uint8_t *>(data)[0];
  } else if (type_id == kInt64->type_id()) {
    (*node_json)["value"] = static_cast<int64_t *>(data)[0];
  } else if (type_id == kInt32->type_id()) {
    (*node_json)["value"] = static_cast<int32_t *>(data)[0];
  } else if (type_id == kInt16->type_id()) {
    (*node_json)["value"] = static_cast<int16_t *>(data)[0];
  } else if (type_id == kInt8->type_id()) {
    (*node_json)["value"] = static_cast<int8_t *>(data)[0];
  } else if (type_id == kBool->type_id()) {
    (*node_json)["value"] = static_cast<bool *>(data)[0];
  } else {
    MS_LOG(EXCEPTION) << "Fail to parse the input value of [" << cnode->DebugString() << "], the input index is "
                      << input_idx << ", because the value type: " << TypeIdToString(type_id, true)
                      << " is not in supported list: [float64, float32, float16, uint64, uint32, uint16, uint8, int64, "
                         "int32, int16, int8, bool].";
  }
}
315 
// If input `input_idx` of `anf_node` is a constant (ValueNode), records its
// value into *node_json and adjusts *input_shape to match. Returns true when a
// value was extracted, false when the input is not a constant.
// Supported constants: int sequences, int scalars, and tensors (int lists for
// multi-element tensors, any supported scalar type for single-element ones).
bool GraphKernelJsonGenerator::GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx,
                                                   ShapeVector *input_shape, nlohmann::json *node_json) const {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(node_json);
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (input_idx + 1 >= cnode->size()) {
    MS_EXCEPTION(ArgumentError) << "Input index " << input_idx << " is out of range [0, " << cnode->size()
                                << ") in node [" << cnode->DebugString() << "]";
  }

  // input(0) is the primitive, so the real input lives at slot input_idx + 1.
  auto input_node = cnode->input(input_idx + 1);
  if (!input_node->isa<ValueNode>()) {
    return false;
  }
  auto input_value_node = input_node->cast<ValueNodePtr>();
  auto &input_value = input_value_node->value();
  if (input_value->isa<ValueSequence>()) {
    // A tuple/list constant: must be all ints; stored as a 1-D value.
    auto value_seq = input_value->cast<ValueSequencePtr>()->value();
    ShapeVector vec;
    vec.reserve(value_seq.size());
    for (auto v : value_seq) {
      if (v->isa<Int32Imm>() || v->isa<Int64Imm>()) {
        (void)vec.emplace_back(AnfUtils::GetIntValue(v));
      } else {
        MS_LOG(EXCEPTION) << "Element in valuenode must be int" << input_node->fullname_with_scope();
      }
    }
    (*node_json)["value"] = vec;
    *input_shape = {static_cast<ShapeValueDType>(vec.size())};
    return true;
  } else if (input_value->isa<Int32Imm>() || input_value->isa<Int64Imm>()) {
    // A scalar int constant.
    (*node_json)["value"] = AnfUtils::GetIntValue(input_value);
    *input_shape = {1};
    return true;
  } else if (input_value->isa<tensor::Tensor>()) {
    auto tensor = input_value->cast<tensor::TensorPtr>();
    if (tensor == nullptr) {
      MS_LOG(DEBUG) << "Value of input node is nullptr, op: [" << input_node->DebugString() << "]";
      return false;
    }

    auto type_id = tensor->data_type();
    auto *data = tensor->data_c();
    auto data_size = tensor->DataSize();
    MS_EXCEPTION_IF_NULL(data);
    if (data_size > 1) {
      SetValueList(node_json, data, data_size, type_id, cnode);
      *input_shape = ShapeVector{static_cast<ShapeValueDType>(data_size)};
      return true;
    }
    SetSingleValue(node_json, data, type_id, cnode, input_idx);
    // NOTE(review): for AKG_V2 the original tensor shape is kept on purpose —
    // only the non-V2 path flattens a single-element tensor to shape {1}.
    if (GraphKernelFlags::GetInstance().kernel_generator != "AKG_V2") {
      *input_shape = {1};
    }
    return true;
  }
  return false;
}
375 
QuerySymbolicShapeStr(const AnfNodePtr & node)376 std::vector<std::string> GraphKernelJsonGenerator::QuerySymbolicShapeStr(const AnfNodePtr &node) {
377   auto node_abs = node->abstract();
378   ListSymbolPtr sym_shape = node_abs->GetSymbolicShape();
379   if (sym_shape == nullptr) {
380     sym_shape = node_abs->GetShape()->BuildSymbolicShape();
381     MS_EXCEPTION_IF_NULL(sym_shape);
382   }
383   if (sym_shape->size() == 0) {
384     return {"1"};
385   }
386   std::vector<std::string> res;
387   res.reserve(sym_shape->size());
388   (void)std::transform(sym_shape->symbols().cbegin(), sym_shape->symbols().cend(), std::back_inserter(res),
389                        [](const SymbolPtr &s) { return s->ToRawString(); });
390   return res;
391 }
392 
SaveShape(const AnfNodePtr & node,nlohmann::json * kernel_json,const ShapeVector & shape)393 void GraphKernelJsonGenerator::SaveShape(const AnfNodePtr &node, nlohmann::json *kernel_json,
394                                          const ShapeVector &shape) {
395   if (symbol_engine_ == nullptr) {
396     (*kernel_json)[kJsonKeyShape] = shape;
397     return;
398   }
399   std::vector<std::string> symbol_shape;
400   auto new_shape = shape;
401   if (!IsDynamic(shape)) {
402     symbol_shape.resize(shape.size());
403     (void)std::transform(shape.begin(), shape.end(), symbol_shape.begin(), [](int64_t v) { return std::to_string(v); });
404   } else {
405     symbol_shape = QuerySymbolicShapeStr(node);
406     if (shape.size() != symbol_shape.size()) {
407       MS_LOG(EXCEPTION) << "The length of tensor shape and symbol shape should be equal but got " << shape.size()
408                         << " and " << symbol_shape.size() << ". node: " << node->DebugString() << ", shape: " << shape
409                         << ", symbol_shape: " << symbol_shape;
410     }
411     for (size_t i = 0; i < new_shape.size(); i++) {
412       auto symbol = symbol_shape[i];
413       if (new_shape[i] == abstract::Shape::kShapeDimAny &&
414           std::all_of(symbol.begin(), symbol.end(), [](char c) { return std::isdigit(c); })) {
415         new_shape[i] = std::stoi(symbol);
416       }
417     }
418   }
419   (*kernel_json)[kJsonKeyShape] = new_shape;
420   (*kernel_json)[kJsonKeySymbolicShape] = symbol_shape;
421   (void)symbol_engine_->QuerySymbolExpr(node, &symbol_calc_exprs_);
422 }
423 
// Builds the "input_desc" json for `anf_node`: one json list per op-info input,
// containing one entry per real input tensor (dynamic inputs expand a single
// op-info input into several tensors). Returns false on missing/invalid info.
bool GraphKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, const OpInfoPtr &op_info,
                                                   nlohmann::json *inputs_json) {
  auto inputs_ptr = op_info->inputs_ptr();
  if (inputs_ptr.empty()) {
    MS_LOG(ERROR) << "Kernel [" << anf_node->fullname_with_scope() << "] info has no input info";
    return false;
  }

  // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input.
  auto dyn_input_sizes = GetDynInputSizes(anf_node);
  // Flat index over the node's real input tensors, advanced across groups.
  size_t real_input_index = 0;
  for (size_t i = 0; i < inputs_ptr.size(); i++) {
    auto input_ptr = inputs_ptr[i];
    if (input_ptr == nullptr) {
      MS_LOG(ERROR) << "Kernel [" << anf_node->fullname_with_scope() << "] input[" << i << "] is nullptr";
      return false;
    }

    // Number of real tensors covered by this op-info input (1 unless dynamic).
    size_t input_tensor_num = dyn_input_sizes.empty() ? 1 : LongToSize(dyn_input_sizes[i]);
    std::vector<nlohmann::json> input_list;
    for (size_t input_i = 0; input_i < input_tensor_num; input_i++) {
      auto type_id = this->cb_->GetInputType(anf_node, real_input_index);
      std::string dtype = TypeIdToString(type_id, true);
      if (dtype.empty()) {
        MS_LOG(ERROR) << "Op [" << anf_node->fullname_with_scope() << "] input [" << real_input_index
                      << "] data type is null. ";
        return false;
      }
      nlohmann::json input_desc_json;
      input_desc_json[kJsonKeyDataType] = dtype;
      input_desc_json[kJsonKeyFormat] = this->cb_->GetInputFormat(anf_node, real_input_index);
      input_desc_json[kJsonKeyName] = input_ptr->name();
      input_desc_json[kJsonKeyTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index));
      auto input_shape = this->cb_->GetInputShape(anf_node, real_input_index);
      // For fused (non-basic) ops, fold constant inputs into the json as values.
      if (!is_basic_op_ && GetInputTensorValue(anf_node, real_input_index, &input_shape, &input_desc_json)) {
        MS_LOG(DEBUG) << "Pick value [" << input_desc_json[kJsonKeyValue] << "] from input[" << real_input_index
                      << "] of node [" << anf_node->DebugString(kDebugStrDepth);
      }
      // Scalars are represented with shape [1].
      if (input_shape.empty()) {
        input_shape.push_back(1);
      }
      SaveShape(anf_node->cast<CNodePtr>()->input(real_input_index + 1), &input_desc_json, input_shape);
      (void)input_list.emplace_back(input_desc_json);
      real_input_index++;
    }
    (void)inputs_json->emplace_back(input_list);
  }
  return true;
}
474 
CreateOutputDescJson(const AnfNodePtr & anf_node,const OpInfoPtr & op_info,nlohmann::json * outputs_json)475 bool GraphKernelJsonGenerator::CreateOutputDescJson(const AnfNodePtr &anf_node, const OpInfoPtr &op_info,
476                                                     nlohmann::json *outputs_json) {
477   MS_EXCEPTION_IF_NULL(anf_node);
478   MS_EXCEPTION_IF_NULL(op_info);
479   MS_EXCEPTION_IF_NULL(outputs_json);
480   size_t output_tensor_num = AnfUtils::GetOutputTensorNum(anf_node);
481 
482   auto outputs = op_info->outputs_ptr();
483   for (size_t i = 0; i < output_tensor_num; i++) {
484     nlohmann::json output_json;
485     auto type_id = this->cb_->GetOutputType(anf_node, i);
486     std::string dtype = TypeIdToString(type_id, true);
487     if (dtype.empty()) {
488       MS_LOG(ERROR) << "Op [" << anf_node->fullname_with_scope() << "] output [" << i << "] data type is null. ";
489       return false;
490     }
491 
492     std::string output_name = outputs[i]->name();
493     output_json[kJsonKeyDataType] = dtype;
494     output_json[kJsonKeyFormat] = this->cb_->GetOutputFormat(anf_node, i);
495     output_json[kJsonKeyName] = output_name;
496     output_json[kJsonKeyTensorName] = "output_" + std::to_string(i) + "_" + std::to_string(GetOutputTensorIdxInc());
497     auto output_shape = this->cb_->GetOutputShape(anf_node, i);
498     if (output_shape.empty()) {
499       output_shape.push_back(1);
500     }
501     SaveShape(anf_node, &output_json, output_shape);
502     outputs_json->push_back(output_json);
503   }
504   return true;
505 }
506 
// Serializes one op attribute into *attr_json according to the attr's declared
// type ("int", "str", "bool", "float", "listInt", "listStr"). Cast's "dtype"
// attr is special-cased to "dst_type" with the type name as value, and the
// "data_format" listStr attr is filled from the node's actual input formats.
// Unsupported types only log a warning (no value is written).
void GraphKernelJsonGenerator::GetAttrJson(const AnfNodePtr &anf_node, const std::vector<int64_t> &dyn_input_sizes,
                                           const OpAttrPtr &op_attr, nlohmann::json *attr_json,
                                           const ValuePtr &attr_value) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(op_attr);
  MS_EXCEPTION_IF_NULL(attr_json);

  // Accepts both Int64Imm and Int32Imm stored values.
  auto get_int_value = [](const ValuePtr &value) -> int {
    return value->isa<Int64Imm>() ? static_cast<int>(GetValue<int64_t>(value)) : GetValue<int>(value);
  };
  if (IsPrimitiveCNode(anf_node, prim::kPrimCast) && op_attr->name() == "dtype") {
    // Cast's dtype attr is renamed and converted to the type's string name.
    (*attr_json)[kJsonKeyName] = "dst_type";
    (*attr_json)[kJsonKeyValue] = TypeIdToString(TypeId(get_int_value(attr_value)), true);
    return;
  }
  std::string type = op_attr->type();
  (*attr_json)[kJsonKeyDataType] = type;
  if (type == "int") {
    (*attr_json)[kJsonKeyValue] = get_int_value(attr_value);
  } else if (type == "str") {
    if (attr_value->isa<Type>()) {
      (*attr_json)[kJsonKeyValue] = TypeIdToString(attr_value->cast<TypePtr>()->type_id(), true);
    } else {
      (*attr_json)[kJsonKeyValue] = GetValue<std::string>(attr_value);
    }
  } else if (type == "bool") {
    (*attr_json)[kJsonKeyValue] = GetValue<bool>(attr_value);
  } else if (type == "float") {
    (*attr_json)[kJsonKeyValue] = GetValue<float>(attr_value);
  } else if (type == "listInt") {
    // NOTE(review): assumes attr_value is a ValueSequence here — the cast is
    // not null-checked; verify callers only tag sequences as listInt.
    std::vector<int> list_int;
    const auto &vals = attr_value->cast<ValueSequencePtr>()->value();
    (void)std::transform(vals.begin(), vals.end(), std::back_inserter(list_int), get_int_value);
    (*attr_json)[kJsonKeyValue] = list_int;
  } else if (type == "listStr") {
    std::vector<std::string> data_format;
    if (op_attr->name() == kJsonKeyDataformat) {
      // data_format is derived from the real input formats, one per tensor arg
      // (or per dynamic-input group when dyn_input_sizes is set).
      size_t tensor_args_num =
        !dyn_input_sizes.empty() ? dyn_input_sizes.size() : AnfUtils::GetInputTensorNum(anf_node);
      for (size_t format_i = 0; format_i < tensor_args_num; format_i++) {
        auto input_format = this->cb_->GetInputFormat(anf_node, format_i);
        data_format.push_back(input_format);
      }
    } else {
      data_format = GetValue<std::vector<std::string>>(attr_value);
    }
    (*attr_json)[kJsonKeyValue] = data_format;
  } else {
    MS_LOG(WARNING) << "Invalid attr " << op_attr->name() << " found in node " << anf_node->fullname_with_scope()
                    << ", because its type: " << type
                    << " is not in supported list: [str, int, bool, float, listInt, listStr].";
  }
}
560 
CreateAttrDescJson(const AnfNodePtr & anf_node,const OpInfoPtr & op_info,nlohmann::json * attrs_json)561 bool GraphKernelJsonGenerator::CreateAttrDescJson(const AnfNodePtr &anf_node, const OpInfoPtr &op_info,
562                                                   nlohmann::json *attrs_json) {
563   auto attrs = op_info->attrs_ptr();
564   if (attrs.empty()) {
565     MS_LOG(DEBUG) << "Apply kernel [" << anf_node->fullname_with_scope() << "] op info attrs is empty";
566     return true;
567   }
568   auto dyn_input_sizes = GetDynInputSizes(anf_node);
569   auto primitive = GetCNodePrimitive(anf_node);
570 
571   // create input name list for "x_shape" in attr with "x" in primitive.
572   auto inputs = op_info->inputs_ptr();
573   std::map<std::string, size_t> op_info_shape_name;
574   for (size_t i = 0; i < inputs.size(); i++) {
575     op_info_shape_name[inputs[i]->name() + "_shape"] = i;
576   }
577 
578   for (const auto &op_attr : attrs) {
579     nlohmann::json attr_json;
580     ValuePtr attr_value = primitive->GetAttr(op_attr->name());
581     if (attr_value == nullptr && op_attr->name() != kJsonKeyDataformat) {
582       if (op_attr->param_type() != "required") {
583         continue;
584       }
585       // match "x_shape" in attr with "x" in primitive.
586       auto find_item = op_info_shape_name.find(op_attr->name());
587       if (find_item != op_info_shape_name.end()) {
588         if (!dyn_input_sizes.empty()) {
589           if (find_item->second >= dyn_input_sizes.size() - 1) {
590             MS_LOG(EXCEPTION) << "dyn_input_sizes list index " << find_item->second << " is out of range [0, "
591                               << dyn_input_sizes.size() - 1 << ") in node [" << anf_node->fullname_with_scope() << "]";
592             return false;
593           }
594           size_t tensor_idx = LongToSize(std::accumulate(&dyn_input_sizes[0], &dyn_input_sizes[find_item->second], 0));
595           for (int64_t input_i = 0; input_i < dyn_input_sizes[find_item->second]; input_i++) {
596             attr_json[kJsonKeyValue] = this->cb_->GetInputInferShape(anf_node, tensor_idx);
597             attr_json[kJsonKeyName] = op_attr->name();
598             attrs_json->push_back(attr_json);
599             tensor_idx++;
600           }
601         } else {
602           attr_json[kJsonKeyValue] = this->cb_->GetInputInferShape(anf_node, find_item->second);
603           attr_json[kJsonKeyName] = op_attr->name();
604           attrs_json->push_back(attr_json);
605         }
606       } else {
607         MS_LOG(ERROR) << "Can not find attr '" << op_attr->name() << "' in node [" << anf_node->fullname_with_scope()
608                       << "]";
609         return false;
610       }
611     } else {
612       attr_json[kJsonKeyName] = op_attr->name();
613       GetAttrJson(anf_node, dyn_input_sizes, op_attr, &attr_json, attr_value);
614       attrs_json->push_back(attr_json);
615     }
616   }
617   return true;
618 }
619 
GetInputTensorIdxInc(const AnfNodePtr & anf_node,size_t input_idx)620 size_t GraphKernelJsonGenerator::GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx) {
621   MS_EXCEPTION_IF_NULL(anf_node);
622   auto cnode = anf_node->cast<CNodePtr>();
623   MS_EXCEPTION_IF_NULL(cnode);
624   if (input_idx + 1 >= cnode->size()) {
625     MS_EXCEPTION(ArgumentError) << "Input index " << input_idx << " is out of range [0, " << cnode->size()
626                                 << ") in node [" << cnode->DebugString() << "]";
627   }
628 
629   auto input_node = cnode->input(input_idx + 1);
630   if (input_tensor_idx_.find(input_node) == input_tensor_idx_.end()) {
631     size_t index = input_tensor_idx_.size();
632     input_tensor_idx_[input_node] = index;
633   }
634 
635   return input_tensor_idx_[input_node];
636 }
637 
GetOutputTensorIdxInc()638 size_t GraphKernelJsonGenerator::GetOutputTensorIdxInc() {
639   size_t idx = output_tensor_idx_++;
640   return idx;
641 }
642 
// Reads the tensor name stored under node_json[tag][position.first]
// [position.second]["tensor_name"]. For the output-desc tag the outer list is
// not doubly nested, so position.first is skipped. Returns "" (with an error
// log) on any missing key or out-of-range index.
std::string GraphKernelJsonGenerator::GetTensorName(const nlohmann::json &node_json, const std::string &tag,
                                                    const std::pair<size_t, size_t> &position) const {
  if (node_json.count(tag) == 0) {
    MS_LOG(ERROR) << "Node [" << node_json.dump() << "] has no key [" << tag << "].";
    return "";
  }

  auto const &tag_desc = node_json[tag];
  nlohmann::json first_index;
  if (tag == kJsonKeyOutputDesc) {
    // Output desc is a flat list: use it directly, ignoring position.first.
    first_index = tag_desc;
  } else if (!tag_desc.is_array() || tag_desc.size() <= position.first) {
    MS_LOG(ERROR) << "Access index is out of range: "
                  << " trying to access index " << position.first << " of node: " << tag_desc.dump();
    return "";
  } else {
    first_index = tag_desc[position.first];
  }

  if (!first_index.is_array() || first_index.size() <= position.second) {
    MS_LOG(ERROR) << "Access index is out of range: "
                  << " trying to access index " << position.second << " of node: " << first_index.dump();
    return "";
  }
  auto const &second_index = first_index[position.second];
  if (second_index.count(kJsonKeyTensorName) == 0) {
    MS_LOG(ERROR) << "Node [" << second_index.dump() << "] has no key [" << kJsonKeyTensorName << "].";
    return "";
  }

  return second_index[kJsonKeyTensorName];
}
675 
SetTensorName(const std::string & tag,const std::string & new_name,const std::pair<size_t,size_t> & position,nlohmann::json * node_json) const676 void GraphKernelJsonGenerator::SetTensorName(const std::string &tag, const std::string &new_name,
677                                              const std::pair<size_t, size_t> &position,
678                                              nlohmann::json *node_json) const {
679   MS_EXCEPTION_IF_NULL(node_json);
680   if (node_json->count(tag) == 0) {
681     MS_LOG(ERROR) << "Node [" << node_json->dump() << "] has no key [" << tag << "].";
682     return;
683   }
684 
685   nlohmann::json *tag_desc = &((*node_json)[tag]);
686   nlohmann::json *first_index;
687   if (tag == kJsonKeyOutputDesc) {
688     first_index = tag_desc;
689   } else if (!tag_desc->is_array() || tag_desc->size() <= position.first) {
690     MS_LOG(ERROR) << "Access index is out of range: "
691                   << " trying to access index " << position.first << " of node: " << tag_desc->dump();
692     return;
693   } else {
694     first_index = &((*tag_desc)[position.first]);
695   }
696 
697   if (!first_index->is_array() || first_index->size() <= position.second) {
698     MS_LOG(ERROR) << "Access index is out of range: "
699                   << " trying to access index " << position.second << " of node: " << first_index->dump();
700     return;
701   }
702   nlohmann::json *second_index = &((*first_index)[position.second]);
703   if (second_index->count(kJsonKeyTensorName) == 0) {
704     MS_LOG(ERROR) << "Node [" << second_index->dump() << "] has no key [" << kJsonKeyTensorName << "].";
705     return;
706   }
707   (*second_index)[kJsonKeyTensorName] = new_name;
708   return;
709 }
710 
SaveNodeAddress(const AnfNodePtr & anf_node,nlohmann::json * node_json)711 void GraphKernelJsonGenerator::SaveNodeAddress(const AnfNodePtr &anf_node, nlohmann::json *node_json) {
712   if (dump_option_.save_ptr_address) {
713     std::ostringstream get_the_address;
714     get_the_address << anf_node.get();
715     auto address = anf_node->UniqueName();
716     (*node_json)[kJsonKeyPtrAddress] = address;
717     address_node_map_[address] = anf_node;
718   }
719 }
720 
// Returns the OpInfo describing this node's kernel interface.
// When dump_option_.extract_opinfo_from_anfnode is set, the info is built
// directly from the anf node; otherwise it is looked up in the OpLib registry
// (unavailable in the lite build, which raises instead).
OpInfoPtr GraphKernelJsonGenerator::ExtractOpInfo(const AnfNodePtr &anf_node) const {
  if (dump_option_.extract_opinfo_from_anfnode) {
    OpInfoExtractor e;
    return e.Run(anf_node);
  } else {
#ifdef MSLITE_ENABLE_GRAPH_KERNEL
    MS_LOG(EXCEPTION) << "OpLib is not supported.";
#else
    OpImplyType imply_type;
    const auto &flags = GraphKernelFlags::GetInstance();

    // The AKG_V2 generator registers ops under the dynamic-AKG imply type.
    if (flags.kernel_generator == "AKG_V2") {
      imply_type = OpImplyType::kImplyDynamicAKG;
    } else {
      imply_type = OpImplyType::kImplyAKG;
    }
    return kernel::OpLib::FindOp(AnfUtils::GetCNodeName(anf_node), imply_type);
#endif
  }
}
741 
GetProcessorByTarget() const742 std::string GraphKernelJsonGenerator::GetProcessorByTarget() const {
743   auto target = cb_->GetTargetFromContext();
744   if (target == kGPUDevice) {
745     return "cuda";
746   }
747   if (target == kAscendDevice) {
748     return "aicore";
749   }
750   return "cpu";
751 }
752 
GenerateSingleKernelJson(const AnfNodePtr & anf_node,nlohmann::json * node_json)753 bool GraphKernelJsonGenerator::GenerateSingleKernelJson(const AnfNodePtr &anf_node, nlohmann::json *node_json) {
754   MS_EXCEPTION_IF_NULL(anf_node);
755   MS_EXCEPTION_IF_NULL(node_json);
756   OpInfoPtr op_info = ExtractOpInfo(anf_node);
757   MS_EXCEPTION_IF_NULL(op_info);
758 
759   // get basic params from currentNodeOpDesc
760   std::string op_name;
761   if (IsPrimitiveCNode(anf_node, prim::kPrimCustom)) {
762     auto primitive = GetCNodePrimitive(anf_node);
763     MS_EXCEPTION_IF_NULL(primitive);
764     op_name = primitive->name();
765   } else {
766     op_name = op_info->op_name();
767   }
768   if (all_ops_name_.empty()) {
769     all_ops_name_ = op_name;
770   } else {
771     static_cast<void>(all_ops_name_.append("_").append(op_name));
772   }
773   (*node_json)[kJsonKeyName] = op_name;
774   (*node_json)[kJsonKeyImplPath] = op_info->impl_path();
775   SaveNodeAddress(anf_node, node_json);
776 
777   // input desc
778   nlohmann::json inputs_json;
779   if (!CreateInputDescJson(anf_node, op_info, &inputs_json)) {
780     MS_LOG(ERROR) << "Create input desc json failed, op[" << anf_node->fullname_with_scope() << "].";
781     return false;
782   }
783   (*node_json)[kJsonKeyInputDesc] = inputs_json;
784   MS_LOG(DEBUG) << "The kernel compiler create input desc json success.";
785 
786   // output desc
787   nlohmann::json outputs_json;
788   if (!CreateOutputDescJson(anf_node, op_info, &outputs_json)) {
789     MS_LOG(ERROR) << "Create output desc json failed, op[" << anf_node->fullname_with_scope() << "].";
790     return false;
791   }
792   (*node_json)[kJsonKeyOutputDesc] = outputs_json;
793   MS_LOG(DEBUG) << "The kernel compiler create output desc json success.";
794 
795   // attribute desc
796   nlohmann::json attrs_json;
797   if (!CreateAttrDescJson(anf_node, op_info, &attrs_json)) {
798     MS_LOG(ERROR) << "Create attr desc json failed, op[" << anf_node->fullname_with_scope() << "].";
799     return false;
800   }
801   (*node_json)[kJsonKeyAttr] = attrs_json;
802   return true;
803 }
804 
GetTensorSize(const nlohmann::json & node_json) const805 size_t GraphKernelJsonGenerator::GetTensorSize(const nlohmann::json &node_json) const {
806   const ShapeVector &shape = node_json[kJsonKeyShape];
807   const std::string &dtype = node_json[kJsonKeyDataType];
808   auto type_ptr = StringToType(dtype);
809   MS_EXCEPTION_IF_NULL(type_ptr);
810   auto num_ptr = type_ptr->cast<NumberPtr>();
811   MS_EXCEPTION_IF_NULL(num_ptr);
812   size_t nbyte = IntToSize(num_ptr->nbits() / static_cast<int>(BitsNum::eBits8));
813   return std::accumulate(shape.begin(), shape.end(), nbyte, std::multiplies<size_t>());
814 }
815 
GetIOSize(const nlohmann::json & node_json,std::vector<size_t> * input_size,std::vector<size_t> * output_size) const816 void GraphKernelJsonGenerator::GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *input_size,
817                                          std::vector<size_t> *output_size) const {
818   input_size->clear();
819   output_size->clear();
820   for (size_t i = 0; i < node_json[kJsonKeyInputDesc].size(); i++) {
821     for (size_t m = 0; m < node_json[kJsonKeyInputDesc][i].size(); m++) {
822       input_size->push_back(GetTensorSize(node_json[kJsonKeyInputDesc][i][m]));
823     }
824   }
825   for (size_t i = 0; i < node_json[kJsonKeyOutputDesc].size(); i++) {
826     output_size->push_back(GetTensorSize(node_json[kJsonKeyOutputDesc][i]));
827   }
828 }
829 
GenHashId(const std::string & info) const830 size_t GraphKernelJsonGenerator::GenHashId(const std::string &info) const {
831   if (!dump_option_.save_ptr_address) {
832     return std::hash<std::string>()(info);
833   }
834   // gen hash id without node address
835   // the format is like {"ptr_address":"0x12345678"}
836   std::string key = std::string("\"") + kJsonKeyPtrAddress + "\"";
837   std::ostringstream result;
838   size_t begin = 0;
839   size_t pos;
840   while ((pos = info.find(key, begin)) != std::string::npos) {
841     result << info.substr(begin, pos - begin);
842     // skip the address
843     auto addr_begin = info.find('\"', pos + key.size());
844     auto addr_end = info.find('\"', addr_begin + 1);
845     begin = addr_end + 1;
846   }
847   result << info.substr(begin);
848   return std::hash<std::string>()(result.str());
849 }
850 
CollectJson(const AnfNodePtr & anf_node,nlohmann::json * kernel_json)851 bool GraphKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::json *kernel_json) {
852   MS_EXCEPTION_IF_NULL(anf_node);
853   MS_EXCEPTION_IF_NULL(kernel_json);
854   std::string op_name = AnfUtils::GetCNodeName(anf_node);
855   MS_LOG(DEBUG) << "The kernel compiler start generate kernel json desc, full scope name is : "
856                 << anf_node->fullname_with_scope();
857   is_basic_op_ = true;
858   if (!GenerateSingleKernelJson(anf_node, kernel_json)) {
859     MS_LOG(ERROR) << "Op[" << anf_node->fullname_with_scope() << "] create single kernel json failed.";
860     return false;
861   }
862   if (dump_option_.get_target_info) {
863     TargetInfoSetter::Set(kernel_json);
864   }
865   (*kernel_json)[kJsonKeyProcess] = GetProcessorByTarget();
866   (*kernel_json)[kJsonKeyVersion] = kCurrentInfoVersion;
867 
868   // gen hash id with the above info.
869   size_t hash_id = GenHashId(kernel_json->dump());
870   kernel_name_ = op_name + "_" + std::to_string(hash_id);
871   if (dump_option_.gen_kernel_name_only) {
872     return true;
873   }
874   (*kernel_json)[kJsonKeyId] = 0;  // unused key
875   (*kernel_json)[kJsonKeyOp] = kernel_name_;
876   const auto &flags = GraphKernelFlags::GetInstance();
877   (*kernel_json)[kJsonKeyPlatform] = flags.kernel_generator;
878   (*kernel_json)[kJsonKeyComposite] = false;
879 
880   GetIOSize(*kernel_json, &input_size_list_, &output_size_list_);
881 
882   MS_LOG(DEBUG) << "The kernel compiler create kernel json desc success, full scope name is : "
883                 << anf_node->fullname_with_scope() << ", json info name is : " << kernel_name_;
884   return true;
885 }
886 
GenStitchJson(const std::vector<AnfNodePtr> & anf_nodes,std::map<AnfNodePtr,nlohmann::json> * node_json_map,nlohmann::json * kernel_json) const887 void GraphKernelJsonGenerator::GenStitchJson(const std::vector<AnfNodePtr> &anf_nodes,
888                                              std::map<AnfNodePtr, nlohmann::json> *node_json_map,
889                                              nlohmann::json *kernel_json) const {
890   std::vector<std::string> stitchs;
891   for (auto const &anf_node : anf_nodes) {
892     auto prim = GetCNodePrimitive(anf_node);
893     MS_EXCEPTION_IF_NULL(prim);
894     auto stitch_attr = prim->GetAttr(kAttrStitch);
895     if (stitch_attr != nullptr && GetValue<std::string>(stitch_attr) == "common") {
896       auto name = GetTensorName((*node_json_map)[anf_node], kJsonKeyOutputDesc, {0, 0});
897       if (std::find(stitchs.begin(), stitchs.end(), name) == stitchs.end()) {
898         (void)stitchs.emplace_back(name);
899       }
900     }
901   }
902   if (!stitchs.empty()) {
903     std::vector<nlohmann::json> v;
904     for (auto &s : stitchs) {
905       std::vector<std::string> t(1, s);
906       (void)v.emplace_back(std::move(t));
907     }
908     nlohmann::json stitch_json;
909     stitch_json[kJsonKeyStitchOp] = v;
910     (*kernel_json)[kJsonKeyBufferStitch] = stitch_json;
911   }
912 }
913 
// Builds kernel_name_ as "Fused_<ops_name>_<hash_id>", truncating over-long
// op-name chains so the resulting ".info" file name stays within limits.
void GraphKernelJsonGenerator::GenKernelName(const FuncGraphPtr &fg, size_t hash_id, nlohmann::json *kernel_json) {
  MS_EXCEPTION_IF_NULL(fg);
  // The final kernel name carries a hash_id and may have a "_more" suffix.
  // Total length is up to about 105; the file name (with ".info") up to 110.
  constexpr size_t name_len_limited = 80;
  kernel_name_ = "Fused_";
  auto attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
  std::string ops_name = (attr_val != nullptr) ? GetValue<std::string>(attr_val) : all_ops_name_;
  if (ops_name.size() > name_len_limited) {
    // Keep the untruncated name in the json for reference before shortening.
    (*kernel_json)[kJsonKeyOpFullName] = kernel_name_ + ops_name;
    auto suffix_pos = ops_name.find_last_of("_");
    if (suffix_pos != std::string::npos && ops_name.size() - suffix_pos < name_len_limited) {
      // Preserve the trailing "_<last-op>" segment and splice "_more" before it.
      ops_name =
        ops_name.substr(0, name_len_limited - (ops_name.size() - suffix_pos)) + "_more" + ops_name.substr(suffix_pos);
    } else {
      // No usable suffix: hard-truncate and append "_more".
      ops_name = ops_name.substr(0, name_len_limited) + "_more";
    }
  }
  (void)kernel_name_.append(ops_name).append("_");
  (void)kernel_name_.append(std::to_string(hash_id));
}
935 
CollectFusedJson(const std::vector<AnfNodePtr> & anf_nodes,const std::vector<AnfNodePtr> & input_list,const std::vector<AnfNodePtr> & output_list,nlohmann::json * kernel_json,const bool is_akg_cce)936 bool GraphKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes,
937                                                 const std::vector<AnfNodePtr> &input_list,
938                                                 const std::vector<AnfNodePtr> &output_list, nlohmann::json *kernel_json,
939                                                 const bool is_akg_cce) {
940   if (anf_nodes.empty()) {
941     MS_LOG(ERROR) << "anf_nodes list is empty";
942     return false;
943   }
944   MS_LOG(DEBUG) << "Fusion nodes: [" << output_list.size() << "], input_list: [" << anf_nodes.size()
945                 << "], output_list: [" << input_list.size() << "].";
946   std::map<AnfNodePtr, nlohmann::json> node_json_map;
947   is_basic_op_ = false;
948   dump_option_.extract_opinfo_from_anfnode = true;  // always extract from anfnode for composite ops.
949   if (!GenSingleJsons(anf_nodes, &node_json_map)) {
950     return false;
951   }
952 
953   UpdateTensorName(anf_nodes, &node_json_map);
954 
955   std::vector<nlohmann::json> node_json_desc;
956   (void)std::transform(anf_nodes.begin(), anf_nodes.end(), std::back_inserter(node_json_desc),
957                        [&node_json_map](const AnfNodePtr &anf_node) { return node_json_map[anf_node]; });
958   (*kernel_json)[kJsonKeyOpDesc] = node_json_desc;
959 
960   auto inputs_json = CreateInputsJson(anf_nodes, input_list, node_json_map);
961   (*kernel_json)[kJsonKeyInputDesc] = inputs_json;
962   (*kernel_json)[kJsonKeyOutputDesc] =
963     CreateOutputsJson(anf_nodes, input_list, output_list, inputs_json, node_json_map);
964   if (!symbol_calc_exprs_.empty()) {
965     nlohmann::json symbols_json;
966     for (const auto &it : symbol_calc_exprs_) {
967       symbols_json[it.first] = it.second;
968     }
969     (*kernel_json)[kJsonKeySymbolCalcExpr] = symbols_json;
970   }
971   // Add parallel fusion information.
972   GenParallelJson(anf_nodes, input_list, output_list, node_json_map, kernel_json);
973   GenStitchJson(anf_nodes, &node_json_map, kernel_json);
974   if (dump_option_.get_target_info) {
975     TargetInfoSetter::Set(kernel_json);
976   }
977   (*kernel_json)[kJsonKeyProcess] = GetProcessorByTarget();
978   (*kernel_json)[kJsonKeyVersion] = kCurrentInfoVersion;
979   auto fg = anf_nodes[0]->func_graph();
980   MS_EXCEPTION_IF_NULL(fg);
981   if (fg->has_attr("dynamic_input_index")) {
982     (*kernel_json)[kJsonKeyDynamicInputIndex] = GetValue<std::string>(fg->get_attr("dynamic_input_index"));
983   }
984 
985   // gen hash id with the above info.
986   size_t hash_id = GenHashId(kernel_json->dump());
987   GenKernelName(fg, hash_id, kernel_json);
988   if (dump_option_.gen_kernel_name_only) {
989     return true;
990   }
991   (*kernel_json)[kJsonKeyId] = 0;  // unused key
992   (*kernel_json)[kJsonKeyOp] = kernel_name_;
993   const auto &flags = GraphKernelFlags::GetInstance();
994   (*kernel_json)[kJsonKeyPlatform] = flags.kernel_generator;
995   (*kernel_json)[kJsonKeyComposite] = true;
996   (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString();
997   if (fg->has_attr(kAttrNodeName)) {
998     (*kernel_json)[kJsonKeyNodeName] = GetValue<std::string>(fg->get_attr(kAttrNodeName));
999   }
1000 
1001   if (is_akg_cce) {
1002     (kernel_json_)["enable_cce_lib"] = true;
1003   }
1004   GetIOSize(*kernel_json, &input_size_list_, &output_size_list_);
1005 
1006   return true;
1007 }
1008 
GenSingleJsons(const std::vector<AnfNodePtr> & anf_nodes,std::map<AnfNodePtr,nlohmann::json> * node_json_map)1009 bool GraphKernelJsonGenerator::GenSingleJsons(const std::vector<AnfNodePtr> &anf_nodes,
1010                                               std::map<AnfNodePtr, nlohmann::json> *node_json_map) {
1011   for (auto const &anf_node : anf_nodes) {
1012     nlohmann::json node_json;
1013     if (!GenerateSingleKernelJson(anf_node, &node_json)) {
1014       MS_LOG(ERROR) << "Op [" << anf_node->fullname_with_scope() << "] create single kernel json failed.";
1015       return false;
1016     }
1017 
1018     auto primitive = GetCNodePrimitive(anf_node);
1019     MS_EXCEPTION_IF_NULL(primitive);
1020 
1021     (*node_json_map)[anf_node] = node_json;
1022   }
1023   return true;
1024 }
1025 
// Rewrites each node's input tensor names so inputs produced by another node
// inside the fused graph reuse that producer's output tensor name; inputs
// coming from outside the graph keep their original names.
void GraphKernelJsonGenerator::UpdateTensorName(const std::vector<AnfNodePtr> &anf_nodes,
                                                std::map<AnfNodePtr, nlohmann::json> *node_json_map) const {
  for (auto const &anf_node : anf_nodes) {
    auto dyn_input_sizes = GetDynInputSizes(anf_node);
    bool is_dynamic_input = !dyn_input_sizes.empty();
    // With dynamic inputs, logical input i expands to dyn_input_sizes[i] tensors.
    size_t input_num = is_dynamic_input ? dyn_input_sizes.size() : AnfUtils::GetInputTensorNum(anf_node);
    size_t real_input_index = 0;
    for (size_t i = 0; i < input_num; ++i) {
      size_t input_tensor_num = is_dynamic_input ? LongToSize(dyn_input_sizes[i]) : 1;
      for (size_t j = 0; j < input_tensor_num; ++j) {
        // tmp_input = (producer node, producer output index) for this input slot.
        auto tmp_input = GetKernelInput(anf_node, real_input_index);
        auto tmpi = i;
        auto tmpj = j;  // use tmpi and tmpj to evade pclint warning "for statement index variable modified in body."
        std::string tensor_name =
          GetTensorName((*node_json_map)[anf_node], kJsonKeyInputDesc, std::make_pair(tmpi, tmpj));
        if (node_json_map->find(tmp_input.first) != node_json_map->end()) {
          // Producer is inside the fused graph: link to its output tensor name.
          std::string new_tensor_name =
            GetTensorName((*node_json_map)[tmp_input.first], kJsonKeyOutputDesc, std::make_pair(0, tmp_input.second));
          SetTensorName(kJsonKeyInputDesc, new_tensor_name, std::make_pair(tmpi, tmpj), &((*node_json_map)[anf_node]));
          MS_LOG(DEBUG) << "Update [" << real_input_index << "] input [" << tensor_name << "] of ["
                        << anf_node->fullname_with_scope() << "] to [" << tmp_input.second << "] output ["
                        << new_tensor_name << "] of [" << tmp_input.first->fullname_with_scope() << "].";
        } else {
          // Producer is outside the fused graph; the name stays as generated.
          MS_LOG(DEBUG) << "[" << real_input_index << "] input " << tensor_name << "] of ["
                        << anf_node->fullname_with_scope() << "] is out input.";
        }
        real_input_index++;
      }
    }
  }
}
1057 
CreateInputsJson(const std::vector<AnfNodePtr> & anf_nodes,const std::vector<AnfNodePtr> & input_list,const std::map<AnfNodePtr,nlohmann::json> & node_json_map)1058 nlohmann::json GraphKernelJsonGenerator::CreateInputsJson(const std::vector<AnfNodePtr> &anf_nodes,
1059                                                           const std::vector<AnfNodePtr> &input_list,
1060                                                           const std::map<AnfNodePtr, nlohmann::json> &node_json_map) {
1061   nlohmann::json inputs_json;
1062   auto input_index = GetInputIndex(anf_nodes, input_list);
1063   for (size_t i = 0; i < input_index.size(); ++i) {
1064     auto tmp_input = input_index[i];
1065     auto type_id = this->cb_->GetInputType(tmp_input.first, tmp_input.second.first);
1066     std::string dtype = TypeIdToString(type_id, true);
1067     nlohmann::json input_desc_json;
1068     input_desc_json[kJsonKeyTensorName] =
1069       GetTensorName(node_json_map.at(tmp_input.first), kJsonKeyInputDesc, tmp_input.second);
1070     input_desc_json[kJsonKeyDataType] = dtype;
1071     input_desc_json[kJsonKeyFormat] = this->cb_->GetInputFormat(tmp_input.first, tmp_input.second.first);
1072     auto input_shape = this->cb_->GetInputShape(tmp_input.first, tmp_input.second.first);
1073     if (input_shape.empty()) {
1074       input_shape.push_back(1);
1075     }
1076     auto cnode = tmp_input.first->cast<CNodePtr>();
1077     MS_EXCEPTION_IF_NULL(cnode);
1078     SaveShape(cnode->input(tmp_input.second.first + 1), &input_desc_json, input_shape);
1079     (void)inputs_json.emplace_back(std::vector<nlohmann::json>{input_desc_json});
1080   }
1081   return inputs_json;
1082 }
1083 
// Collects parallel-fusion metadata from the attributes of output-producing
// nodes (sub-graph tensor groupings, per-sub-graph core counts, fusion type)
// and, if any were found, stores it under kJsonKeyParallelFusion.
void GraphKernelJsonGenerator::GenParallelJson(const std::vector<AnfNodePtr> &anf_nodes,
                                               const std::vector<AnfNodePtr> &input_list,
                                               const std::vector<AnfNodePtr> &output_list,
                                               const std::map<AnfNodePtr, nlohmann::json> &node_json_map,
                                               nlohmann::json *kernel_json) const {
  // sub-graph id -> (core number, tensor names belonging to that sub-graph)
  std::map<size_t, std::pair<size_t, std::vector<std::string>>> sub_graphs_info;
  std::string fusion_type;
  std::vector<std::vector<int>> type_info;

  auto output_index = GetOutputIndex(anf_nodes, input_list, output_list);
  for (size_t i = 0; i < output_index.size(); ++i) {
    auto tmp_output = output_index[i].first;
    auto tmp_output_index = output_index[i].second;
    // Outputs that are directly forwarded graph inputs carry no fusion attrs.
    bool found = std::any_of(input_list.cbegin(), input_list.cend(),
                             [&tmp_output](const AnfNodePtr &in) { return tmp_output == in; });
    if (!found) {
      auto tcnode = tmp_output->cast<CNodePtr>();
      if (tcnode == nullptr) {
        // NOTE(review): this returns from the whole function on the first
        // non-CNode output, skipping parallel info entirely — confirm intended
        // (a `continue` would only skip this output).
        return;
      }
      auto prim = GetCNodePrimitive(tcnode);
      MS_EXCEPTION_IF_NULL(prim);
      // Get dim info.
      if (prim->HasAttr(kAttrParallelDimInfo)) {
        auto info = GetValue<std::vector<size_t>>(prim->GetAttr(kAttrParallelDimInfo));
        auto info_size = info.size();
        if (info_size != kAttrParallelDimInfoSize) {
          MS_LOG(EXCEPTION) << "The size of attr " << kAttrParallelDimInfo << " in node ["
                            << tcnode->fullname_with_scope() << "] should be " << kAttrParallelDimInfoSize
                            << ", but got " << info_size;
        }
        // info[0] = sub-graph id, info[1] = core number for that sub-graph.
        auto tensor_name =
          GetTensorName(node_json_map.at(tmp_output), kJsonKeyOutputDesc, std::make_pair(0, tmp_output_index));
        sub_graphs_info[info[0]].second.push_back(tensor_name);
        sub_graphs_info[info[0]].first = info[1];
      }
      // Get fusion type.
      if (prim->HasAttr(kAttrParallelFusionType)) {
        fusion_type = GetValue<std::string>(prim->GetAttr(kAttrParallelFusionType));
      }
      // Get fusion type info.
      if (prim->HasAttr(kAttrParallelTypeInfo)) {
        type_info = GetValue<std::vector<std::vector<int>>>(prim->GetAttr(kAttrParallelTypeInfo));
      }
    }
  }

  if (!sub_graphs_info.empty()) {
    nlohmann::json parallel_fusion_json;
    parallel_fusion_json[kJsonKeyFusionType] = fusion_type;
    parallel_fusion_json[kJsonKeyTypeInfo] = type_info;
    // Split the map into parallel arrays of tensor-name groups and core counts.
    std::vector<std::vector<std::string>> sgraphs;
    std::vector<size_t> cnums;
    (void)std::for_each(
      sub_graphs_info.cbegin(), sub_graphs_info.cend(),
      [&sgraphs, &cnums](const std::pair<size_t, std::pair<size_t, std::vector<std::string>>> &sg_info) {
        sgraphs.push_back(sg_info.second.second);
        cnums.push_back(sg_info.second.first);
      });
    parallel_fusion_json[kJsonKeySubGraph] = sgraphs;
    parallel_fusion_json[kJsonKeyCoreNum] = cnums;

    (*kernel_json)[kJsonKeyParallelFusion] = parallel_fusion_json;
  }
}
1149 
// Builds the kernel-level output desc list. An output that aliases a graph
// input reuses that input's desc json; all others are described from the
// callback's type/format/shape information.
nlohmann::json GraphKernelJsonGenerator::CreateOutputsJson(const std::vector<AnfNodePtr> &anf_nodes,
                                                           const std::vector<AnfNodePtr> &input_list,
                                                           const std::vector<AnfNodePtr> &output_list,
                                                           const nlohmann::json &inputs_json,
                                                           const std::map<AnfNodePtr, nlohmann::json> &node_json_map) {
  nlohmann::json outputs_json;
  auto output_index = GetOutputIndex(anf_nodes, input_list, output_list);
  for (size_t i = 0; i < output_index.size(); ++i) {
    auto tmp_output = output_index[i];
    bool found = false;
    nlohmann::json output_desc_json;
    // If this output is a directly forwarded graph input, copy that input's desc.
    for (size_t input_i = 0; input_i < input_list.size(); ++input_i) {
      if (tmp_output.first == input_list[input_i]) {
        output_desc_json = inputs_json[input_i][0];
        found = true;
        break;
      }
    }
    if (!found) {
      auto type_id = this->cb_->GetOutputType(tmp_output.first, tmp_output.second);
      std::string dtype = TypeIdToString(type_id, true);
      output_desc_json[kJsonKeyTensorName] =
        GetTensorName(node_json_map.at(tmp_output.first), kJsonKeyOutputDesc, std::make_pair(0, tmp_output.second));
      output_desc_json[kJsonKeyDataType] = dtype;
      output_desc_json[kJsonKeyFormat] = this->cb_->GetOutputFormat(tmp_output.first, tmp_output.second);
      auto output_shape = this->cb_->GetOutputShape(tmp_output.first, tmp_output.second);
      if (output_shape.empty()) {
        // Scalars are represented with shape [1].
        output_shape.push_back(1);
      }
      SaveShape(tmp_output.first, &output_desc_json, output_shape);
    }
    (void)outputs_json.emplace_back(output_desc_json);
  }
  return outputs_json;
}
1185 
// Convenience overload: regenerates the cached kernel_json_ for a single node.
bool GraphKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node) {
  // Start from a fresh json so content from a previous call cannot leak in.
  kernel_json_ = nlohmann::json();
  return CollectJson(anf_node, &kernel_json_);
}
1190 
// Convenience overload: regenerates the cached kernel_json_ for a fused graph.
bool GraphKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes,
                                                const std::vector<AnfNodePtr> &input_list,
                                                const std::vector<AnfNodePtr> &output_list,
                                                const bool use_akg_cce_lib) {
  // Start from a fresh json so content from a previous call cannot leak in.
  kernel_json_ = nlohmann::json();
  return CollectFusedJson(anf_nodes, input_list, output_list, &kernel_json_, use_akg_cce_lib);
}
1198 
// Wraps a single cnode into a one-op fused graph and generates its composite
// kernel json. Value-node inputs and duplicated inputs are replaced with fresh
// parameters so the generated graph has a clean, unique parameter list.
bool GraphKernelJsonGenerator::CollectFusedJsonWithSingleKernel(const CNodePtr &c_node) {
  kernel_json_ = nlohmann::json();
  std::vector<AnfNodePtr> node_list, input_list, output_list;
  FuncGraphPtr fg = std::get<0>(BuildGraphFromNodes({c_node}));
  FuncGraphManagerPtr mng = GkUtils::GetFuncGraphManager(fg);
  auto out_cnode = fg->output()->cast<CNodePtr>();
  if (out_cnode == nullptr) {
    MS_LOG(ERROR) << "Wrong graph generated for kernel [" << c_node->fullname_with_scope()
                  << "], output cnode is a null pointer";
    return false;
  }
  // check all inputs in the cnodes: if it is a valuenode, replace it by a parameter
  std::set<AnfNodePtr> value_nodes;
  auto &inputs = out_cnode->inputs();
  for (size_t i = 1; i < inputs.size(); ++i) {  // input 0 is the primitive, skip it
    const auto &tnode = inputs[i];
    auto tensor = GetValueNode(tnode);
    if (tensor) {
      (void)value_nodes.insert(tnode);
    }
  }

  // Replace each collected value node with a parameter carrying the same
  // abstract and kernel info.
  for (const auto &vnode : value_nodes) {
    auto parameter = fg->add_parameter();
    parameter->set_abstract(vnode->abstract());
    parameter->set_kernel_info(vnode->kernel_info_ptr());
    (void)mng->Replace(vnode, parameter);
  }

  // add new parameter for the same inputs
  std::set<AnfNodePtr> inputs_set;
  bool changed = false;
  for (size_t i = 1; i < out_cnode->size(); i++) {
    auto inp = out_cnode->input(i);
    if (inputs_set.count(inp) == 0) {
      (void)inputs_set.insert(inp);
    } else {
      // Duplicated input: give it its own parameter so parameters stay unique.
      auto p = fg->add_parameter();
      p->set_abstract(inp->abstract());
      p->set_kernel_info(inp->kernel_info_ptr());
      out_cnode->set_input(i, p);
      changed = true;
    }
  }
  if (changed) {
    GkUtils::UpdateFuncGraphManager(mng, fg);
  }

  node_list.push_back(out_cnode);
  auto out_cnode_inputs = out_cnode->inputs();
  (void)input_list.insert(input_list.cbegin(), out_cnode_inputs.cbegin() + 1, out_cnode_inputs.cend());
  // Multi-output ops are exposed through TupleGetItem nodes, one per output.
  auto output_num = static_cast<int64_t>(AnfUtils::GetOutputTensorNum(out_cnode));
  if (output_num > 1) {
    for (int64_t idx = 0; idx < output_num; idx++) {
      auto gt =
        out_cnode->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), out_cnode, NewValueNode(idx)});
      (void)output_list.emplace_back(std::move(gt));
    }
  } else {
    output_list.push_back(out_cnode);
  }

  if (c_node->HasAttr("use_akg_cce")) {
    // NOTE(review): the flag is written into kernel_json_ before generation,
    // so it becomes part of the dumped/hashed content — confirm intended.
    (kernel_json_)["enable_cce_lib"] = true;
  }

  return CollectFusedJson(node_list, input_list, output_list, &kernel_json_);
}
1267 
1268 namespace {
// Fills *target_info with validated CPU target fields (os/arch/feature/type),
// falling back to build-time or per-arch defaults when a flag is unset or
// invalid. Invalid values are logged as warnings, never fatal.
void GetCpuInfo(nlohmann::json *target_info) {
  const auto &flags = GraphKernelFlags::GetInstance();
  std::string target_os = flags.target_os;
  std::string arch = flags.cpu_arch;
  std::string feature = flags.cpu_feature;
  std::string type = flags.cpu_type;
  std::set<std::string> valid_os = {"linux", "windows"};
  // arch: <{supported-features}, default-feature>
  std::map<std::string, std::pair<std::set<std::string>, std::string>> valid_features = {
    {"arm", {{"neon"}, "neon"}},
    {"aarch64", {{"neon"}, "neon"}},
    {"x86_64", {{"sse", "avx", "avx512"}, "avx"}},
  };
  std::set<std::string> valid_cpu_types = {"core-avx2", "skylake-avx512", "core-avx-i", "haswell", "skylake"};

  if (valid_os.count(target_os) == 0) {
    MS_LOG(WARNING) << "GraphKernelFlag: unsupported \"target_os\": " << target_os;
    target_os = "linux";
  }
  if (valid_features.count(arch) == 0) {
    if (!arch.empty()) {
      MS_LOG(WARNING) << "GraphKernelFlag: unsupported \"cpu_arch\": " << arch;
    }
    // Fall back to the architecture this binary was compiled for.
#if defined(__arm__)
    arch = "arm";
#elif defined(__aarch64__)
    arch = "aarch64";
#else
    arch = "x86_64";
#endif
  }

  // `arch` is guaranteed valid here, so operator[] cannot insert a new entry.
  auto &features = valid_features[arch];
  if (features.first.count(feature) == 0) {
    if (!feature.empty()) {
      MS_LOG(WARNING) << "GraphKernelFlag: unsupported \"cpu_feature\": " << feature;
    }
    feature = features.second;  // per-arch default feature
  }

  if (valid_cpu_types.count(type) == 0) {
    if (!type.empty()) {
      MS_LOG(WARNING) << "GraphKernelFlag: unsupported \"cpu_type\": " << type;
      type = "";
    }
    // Derive a reasonable cpu type from the chosen feature when possible.
    if (feature == "avx512") {
      type = "skylake-avx512";
    } else if (feature == "avx") {
      type = "core-avx2";
    }
  }

  (*target_info)[kJsonKeySystem] = target_os;
  (*target_info)[kJsonKeyArch] = arch;
  (*target_info)[kJsonKeyCpuFeature] = feature;
  if (!type.empty()) {
    (*target_info)[kJsonKeyCpuType] = type;
  }
  return;
}
1329 
1330 #ifdef MSLITE_ENABLE_GRAPH_KERNEL
1331 #ifdef ENABLE_GPU
GetGpuInfo(nlohmann::json * target_info)1332 bool GetGpuInfo(nlohmann::json *target_info) {
1333   int major_version = -1;
1334   auto ret = cuDeviceGetAttribute(&major_version, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, 0);
1335   if (ret != CUDA_SUCCESS) {
1336     const char *msg = nullptr;
1337     cuGetErrorName(ret, &msg);
1338     MS_LOG(WARNING) << "Get CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR fail, error message: " << msg;
1339     return false;
1340   }
1341   int minor_version = -1;
1342   auto ret = cuDeviceGetAttribute(&minor_version, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, 0);
1343   if (ret != CUDA_SUCCESS) {
1344     const char *msg = nullptr;
1345     cuGetErrorName(ret, &msg);
1346     MS_LOG(WARNING) << "Get CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR fail, error message: " << msg;
1347     return false;
1348   }
1349   int sm_count = -1;
1350   auto ret = cuDeviceGetAttribute(&sm_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, 0);
1351   if (ret != CUDA_SUCCESS) {
1352     const char *msg = nullptr;
1353     cuGetErrorName(ret, &msg);
1354     MS_LOG(WARNING) << "Get CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT fail, error message: " << msg;
1355     return false;
1356   }
1357   if (major_version == -1 || minor_version == -1 || sm_count == -1) {
1358     return false;
1359   } else {
1360     (*target_info)[kJsonKeyComputeCapability] = std::to_string(major_version) + "." + std::to_string(minor_version);
1361     (*target_info)[kJsonKeySmCount] = sm_count;
1362   }
1363   return true;
1364 }
1365 #else
// GPU support not compiled in (ENABLE_GPU undefined): no target info is available.
bool GetGpuInfo(nlohmann::json *target_info) { return false; }
1367 #endif
1368 #else
GetGpuInfo(nlohmann::json * target_info)1369 bool GetGpuInfo(nlohmann::json *target_info) {
1370   const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
1371     {kGPUDevice, MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
1372   MS_EXCEPTION_IF_NULL(device_context);
1373   auto deprecated_ptr = device_context->GetDeprecatedInterface();
1374   MS_EXCEPTION_IF_NULL(deprecated_ptr);
1375   auto major_version = deprecated_ptr->GetGPUCapabilityMajor();
1376   auto minor_version = deprecated_ptr->GetGPUCapabilityMinor();
1377   auto sm_count = deprecated_ptr->GetGPUMultiProcessorCount();
1378   if (major_version == -1 || minor_version == -1 || sm_count == -1) {
1379     return false;
1380   } else {
1381     (*target_info)[kJsonKeyComputeCapability] = std::to_string(major_version) + "." + std::to_string(minor_version);
1382     (*target_info)[kJsonKeySmCount] = sm_count;
1383   }
1384   return true;
1385 }
1386 #endif
1387 }  // namespace
1388 
GetTargetInfo()1389 void TargetInfoSetter::GetTargetInfo() {
1390   auto target = Callback::Instance()->GetTargetFromContext(true);
1391   if (target == kGPUDevice) {
1392     has_info_ = GetGpuInfo(&target_info_);
1393     return;
1394   }
1395   if (target == kCPUDevice) {
1396     GetCpuInfo(&target_info_);
1397     return;
1398   }
1399   // ascend
1400   target_info_[kJsonKeyArch] = target;
1401 }
1402 
SetTargetInfo(nlohmann::json * kernel_info) const1403 void TargetInfoSetter::SetTargetInfo(nlohmann::json *kernel_info) const {
1404   if (has_info_) {
1405     (*kernel_info)[kJsonKeyTargetInfo] = target_info_;
1406   }
1407 }
1408 }  // namespace mindspore::graphkernel
1409