1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/device/ascend/kernel_select_ascend.h"
18 #include "backend/session/anf_runtime_algorithm.h"
19 #include "runtime/device/kernel_info.h"
20 #include "ir/func_graph.h"
21 #include "backend/kernel_compiler/common_utils.h"
22 #include "backend/kernel_compiler/kernel_query.h"
23 #include "backend/kernel_compiler/kernel_build_info.h"
24
25 namespace mindspore {
26 namespace device {
27 namespace ascend {
28 namespace {
29 // sort format according the number of occurrences.
cmp_format_num(const std::pair<std::string,size_t> & a,const std::pair<std::string,size_t> & b)30 bool cmp_format_num(const std::pair<std::string, size_t> &a, const std::pair<std::string, size_t> &b) {
31 if (a.second != b.second) {
32 return a.second > b.second;
33 } else if (a.first == kOpFormat_DEFAULT) {
34 return a.second + 1 > b.second;
35 } else if (b.first == kOpFormat_DEFAULT) {
36 return a.second > b.second + 1;
37 }
38 return a.second > b.second;
39 }
40
GetPrimitivePrecision(const CNodePtr & cnode)41 TypeId GetPrimitivePrecision(const CNodePtr &cnode) {
42 auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
43 MS_EXCEPTION_IF_NULL(primitive);
44
45 TypeId except_type = kTypeUnknown;
46 if (primitive->GetAttr(kAttrFixPrecision) != nullptr) {
47 auto strExceptDtype = GetValue<std::string>(primitive->GetAttr(kAttrFixPrecision));
48 if (strExceptDtype == "float16") {
49 except_type = kNumberTypeFloat16;
50 } else if (strExceptDtype == "float32") {
51 except_type = kNumberTypeFloat32;
52 } else {
53 MS_LOG(EXCEPTION) << "The fix precision must be float16 or float32, but got" << strExceptDtype;
54 }
55 }
56
57 return except_type;
58 }
59 } // namespace
60
ResetKernelBuildInfo(const CNodePtr & kernel_node)61 void ResetKernelBuildInfo(const CNodePtr &kernel_node) {
62 size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
63 for (size_t input_index = 0; input_index < input_num; ++input_index) {
64 auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
65 MS_EXCEPTION_IF_NULL(input_kernel_node);
66 auto kernel_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
67 if (!kernel::IsWeightBoundary(kernel_with_index.first)) {
68 continue;
69 }
70 // reset format and dtype.
71 kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
72 builder.SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
73 builder.SetOutputsDeviceType(std::vector<TypeId>{kTypeUnknown});
74 AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_kernel_node.get());
75 }
76 }
77
UpdateKernelInfo(const std::vector<AnfNodePtr> & node_list)78 void UpdateKernelInfo(const std::vector<AnfNodePtr> &node_list) {
79 for (size_t i = 0; i < node_list.size(); ++i) {
80 // select nodes in subgraph.
81 auto anf_node = node_list[i];
82 MS_EXCEPTION_IF_NULL(anf_node);
83 auto cnode = anf_node->cast<CNodePtr>();
84 MS_EXCEPTION_IF_NULL(cnode);
85 auto fix_precision_type = GetPrimitivePrecision(cnode);
86 if (fix_precision_type != kTypeUnknown) {
87 std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
88 kernel::KernelQuery(cnode, &kernel_info_list, KernelType::AKG_KERNEL);
89
90 for (size_t index = 0; index < kernel_info_list.size(); ++index)
91 // only math the first input
92 if (kernel_info_list[index]->GetInputDeviceType(0) == fix_precision_type &&
93 kernel_info_list[index]->GetInputFormat(0) == AnfAlgo::GetPrevNodeOutputFormat(cnode, 0) &&
94 AnfAlgo::GetInputDeviceDataType(cnode, 0) != fix_precision_type) {
95 auto selected_kernel_info_ptr = kernel_info_list[index];
96 ResetKernelBuildInfo(cnode);
97 AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info_ptr, cnode.get());
98 SetTensorDeviceInfo(cnode);
99 break;
100 }
101 }
102 }
103 }
104
CanConvertDefaultShapeToNZ(const std::vector<size_t> & shape)105 bool CanConvertDefaultShapeToNZ(const std::vector<size_t> &shape) {
106 for (size_t i = 1; i <= shape.size(); ++i) {
107 if (i > 2) {
108 break;
109 }
110 if (shape[shape.size() - i] != 1 && shape[shape.size() - i] % kCubeSize != 0) {
111 return false;
112 }
113 }
114 return true;
115 }
116
// Maps reduction axes expressed on a default-format shape onto the equivalent
// axes of the FRAC_NZ (fractal) layout. In FRAC_NZ the last two default axes
// are each split across two fractal axes, so reducing either of them requires
// reducing a pair of axes in the fractal shape; any other axis maps to itself
// (after normalizing negative indices).
// NOTE(review): assumes ori_shape is non-empty (shape_len is used as a
// modulus) and that axis entries are valid for ori_shape — TODO confirm with
// callers.
std::vector<int64_t> DefaultToFracNZAxis(const std::vector<size_t> &ori_shape, const std::vector<int64_t> &axis) {
  std::vector<int64_t> frac_nz_axis = axis;
  auto shape_len = SizeToLong(ori_shape.size());
  for (size_t i = 0; i < axis.size(); ++i) {
    // Normalize possibly-negative axis into [0, shape_len).
    auto axis_idx = (frac_nz_axis[i] + shape_len) % shape_len;
    if (axis_idx == shape_len - SizeToLong(kIndex1)) {
      // Last default axis: reduce fractal axes (idx - 1) and (idx + 2).
      frac_nz_axis[i] = axis_idx - SizeToLong(kIndex1);
      frac_nz_axis.push_back(axis_idx + SizeToLong(kIndex2));
    } else if (axis_idx == shape_len - SizeToLong(kIndex2)) {
      // Second-to-last default axis: reduce fractal axes (idx + 1) and (idx + 2).
      frac_nz_axis[i] = axis_idx + SizeToLong(kIndex1);
      frac_nz_axis.push_back(axis_idx + SizeToLong(kIndex2));
    } else {
      frac_nz_axis[i] = axis_idx;
    }
  }
  return frac_nz_axis;
}
134
GetReducedFracNZShape(const std::vector<size_t> & ori_shape,const std::vector<int64_t> & axis,bool keep_dims)135 std::vector<size_t> GetReducedFracNZShape(const std::vector<size_t> &ori_shape, const std::vector<int64_t> &axis,
136 bool keep_dims) {
137 std::vector<size_t> result;
138 std::set<size_t> positive_idx;
139 for (const auto &a : axis) {
140 positive_idx.insert(a >= 0 ? LongToSize(a) : ori_shape.size() + LongToSize(a));
141 }
142 for (size_t i = 0; i < ori_shape.size(); ++i) {
143 if (positive_idx.count(i) == 0) {
144 result.push_back(ori_shape[i]);
145 } else if (keep_dims) {
146 result.push_back(1);
147 }
148 }
149 return result;
150 }
151
UpdateFracNZReduceOp(const CNodePtr & cnode)152 void UpdateFracNZReduceOp(const CNodePtr &cnode) {
153 MS_EXCEPTION_IF_NULL(cnode);
154 auto input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, 0);
155 if (input_format == kOpFormat_FRAC_NZ) {
156 // Clone primitive to modify it
157 auto prim = GetCNodePrimitive(cnode);
158 auto new_prim = std::make_shared<Primitive>(*prim);
159 auto new_prim_node = NewValueNode(new_prim);
160 cnode->set_input(0, new_prim_node);
161
162 auto axis_value = new_prim->GetAttr(kAttrAxis);
163 std::vector<int64_t> default_axis;
164 if (axis_value->isa<ValueList>()) {
165 auto value_list = dyn_cast<ValueList>(axis_value);
166 for (const auto &item : value_list->value()) {
167 if (item->isa<Int64Imm>()) {
168 default_axis.push_back(GetValue<int64_t>(item));
169 } else {
170 MS_LOG(EXCEPTION) << "GetValue type should be int64";
171 }
172 }
173 } else if (axis_value->isa<ValueTuple>()) {
174 auto value_tuple = dyn_cast<ValueTuple>(axis_value);
175 for (const auto &item : value_tuple->value()) {
176 if (item->isa<Int64Imm>()) {
177 default_axis.push_back(GetValue<int64_t>(item));
178 } else {
179 MS_LOG(EXCEPTION) << "GetValue type should be int64";
180 }
181 }
182 } else {
183 MS_LOG(ERROR) << "Axis attr type is not correct!";
184 }
185 auto infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
186 std::vector<int64_t> frac_nz_axis = DefaultToFracNZAxis(infer_shape, default_axis);
187 AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue<std::vector<int64_t>>(frac_nz_axis), cnode);
188 auto output_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
189 if (output_shape.size() == 1) {
190 AnfAlgo::SetNodeAttr(kAttrOutputDefault, MakeValue<bool>(true), cnode);
191 }
192 }
193 }
194
GetDefaultFormat(const CNodePtr & kernel_node,std::string * default_format,bool * use_same_format)195 void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format, bool *use_same_format) {
196 MS_EXCEPTION_IF_NULL(kernel_node);
197 MS_EXCEPTION_IF_NULL(default_format);
198 MS_EXCEPTION_IF_NULL(use_same_format);
199 std::unordered_map<std::string, size_t> all_input_formats;
200 size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
201 for (size_t i = 0; i < input_num; ++i) {
202 auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first;
203 MS_EXCEPTION_IF_NULL(input_kernel_node);
204 if (!input_kernel_node->isa<Parameter>()) {
205 ++all_input_formats[AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i)];
206 continue;
207 }
208 auto para = input_kernel_node->cast<ParameterPtr>();
209 if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) {
210 ++all_input_formats[AnfAlgo::GetOutputFormat(para, 0)];
211 continue;
212 }
213 *use_same_format = false;
214 }
215
216 if (all_input_formats.empty()) {
217 // all inputs are parameter.
218 *default_format = kOpFormat_NC1HWC0;
219 } else {
220 std::vector<std::pair<std::string, size_t>> pairs;
221 for (auto iter = all_input_formats.begin(); iter != all_input_formats.end(); ++iter) {
222 pairs.emplace_back(std::make_pair(iter->first, iter->second));
223 }
224
225 std::sort(pairs.begin(), pairs.end(), cmp_format_num);
226 *default_format = pairs.begin()->first;
227 }
228
229 for (size_t i = 0; i < input_num; ++i) {
230 auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first;
231 MS_EXCEPTION_IF_NULL(input_kernel_node);
232 if (!input_kernel_node->isa<Parameter>() ||
233 AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) != kTypeUnknown) {
234 continue;
235 }
236 auto weight_infer_shape = AnfAlgo::GetOutputInferShape(input_kernel_node, 0);
237 if (weight_infer_shape.size() < kShape2dDims && *default_format == kOpFormat_FRAC_NZ) {
238 *default_format = kOpFormat_DEFAULT;
239 *use_same_format = true;
240 break;
241 }
242 }
243 }
244
UpdateInputsKernelInfo(const CNodePtr & kernel_node,const std::vector<AnfNodePtr> & input_list,const std::string & default_format,bool use_same_format,std::vector<std::string> * graph_input_format,std::vector<TypeId> * graph_input_type)245 void UpdateInputsKernelInfo(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list,
246 const std::string &default_format, bool use_same_format,
247 std::vector<std::string> *graph_input_format, std::vector<TypeId> *graph_input_type) {
248 MS_EXCEPTION_IF_NULL(graph_input_format);
249 MS_EXCEPTION_IF_NULL(graph_input_type);
250 // We set same format to all inputs of graph kernel subgraph, and process this latter.
251 // We set dtype to inputs of graph kernel subgraph same as infer dtypes.
252 size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
253 for (size_t i = 0; i < input_num; ++i) {
254 auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first;
255 MS_EXCEPTION_IF_NULL(input_kernel_node);
256 if (use_same_format) {
257 bool can_convert = true;
258 if (default_format == kOpFormat_FRAC_NZ) {
259 auto infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
260 if (!CanConvertDefaultShapeToNZ(infer_shape)) {
261 MS_LOG(WARNING) << "Shape can't be converted to frac nz shape, so use default format instead";
262 can_convert = false;
263 }
264 }
265 if (can_convert) {
266 graph_input_format->emplace_back(default_format);
267 } else {
268 graph_input_format->emplace_back(kOpFormat_DEFAULT);
269 }
270 graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i));
271 continue;
272 }
273
274 if (!input_kernel_node->isa<Parameter>()) {
275 // subgraph parameter from output of other nodes.
276 graph_input_format->push_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i));
277 graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i));
278 continue;
279 }
280
281 auto para = input_kernel_node->cast<ParameterPtr>();
282 MS_EXCEPTION_IF_NULL(para);
283 if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) {
284 // parameter already selected.
285 graph_input_format->push_back(AnfAlgo::GetOutputFormat(para, 0));
286 graph_input_type->push_back(AnfAlgo::GetOutputDeviceDataType(para, 0));
287 continue;
288 }
289
290 // weight parameter.
291 graph_input_format->push_back(default_format);
292 graph_input_type->push_back(AnfAlgo::GetOutputInferDataType(input_kernel_node, 0));
293 }
294
295 for (size_t i = 0; i < input_num; ++i) {
296 kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
297 std::vector<std::string> outputs_format = {(*graph_input_format)[i]};
298 std::vector<TypeId> outputs_device_type = {(*graph_input_type)[i]};
299 builder.SetOutputsFormat(outputs_format);
300 builder.SetOutputsDeviceType(outputs_device_type);
301 AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
302 }
303 }
304
UpdateEquivFormat(const std::vector<AnfNodePtr> & node_list,const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & mng)305 void UpdateEquivFormat(const std::vector<AnfNodePtr> &node_list, const FuncGraphPtr &func_graph,
306 const FuncGraphManagerPtr &mng) {
307 MS_EXCEPTION_IF_NULL(mng);
308 for (size_t i = 0; i < node_list.size(); ++i) {
309 // select nodes in subgraph.
310 auto anf_node = node_list[i];
311 MS_EXCEPTION_IF_NULL(anf_node);
312 auto cnode = anf_node->cast<CNodePtr>();
313 MS_EXCEPTION_IF_NULL(cnode);
314 cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
315 SelectKernelInfo(cnode, KernelType::AKG_KERNEL);
316 // Update ReduceSum
317 if (!IsPrimitiveCNode(cnode, prim::kPrimReduceSum)) {
318 continue;
319 }
320 UpdateFracNZReduceOp(cnode);
321 // If ReduceSum's output is 1d and not Default format, convert it to Default format
322 auto out_format = AnfAlgo::GetOutputFormat(cnode, 0);
323 if (out_format == kOpFormat_DEFAULT || !AnfAlgo::HasNodeAttr(kAttrOutputDefault, cnode)) {
324 continue;
325 }
326 // Insert EquivFormat node, then select kernel info again
327 std::vector<AnfNodePtr> trans_inputs;
328 trans_inputs.push_back(NewValueNode(prim::kPrimEquivFormat));
329 trans_inputs.push_back(cnode);
330 CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
331 AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0)},
332 {AnfAlgo::GetOutputInferShape(cnode, 0)}, trans_node.get());
333 AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue<std::vector<std::string>>({"x"}), trans_node);
334
335 if (trans_node->kernel_info() == nullptr) {
336 trans_node->set_kernel_info(std::make_shared<device::KernelInfo>());
337 }
338 SelectKernelInfo(trans_node, KernelType::AKG_KERNEL);
339 mng->Replace(cnode, trans_node);
340 }
341 }
342
// Cross-checks each graph input's chosen format/dtype against what the inner
// user nodes actually selected. On a format mismatch the input falls back to
// default_format; on a dtype mismatch it falls back to the input's inferred
// dtype. Entries changed here are flagged in need_update so
// UpdateFormatsAndDtypes can re-apply them.
void CheckFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list,
                           const FuncGraphManagerPtr &mng, const std::string &default_format,
                           std::vector<std::string> *graph_input_format, std::vector<TypeId> *graph_input_type,
                           std::vector<bool> *need_update) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(mng);
  MS_EXCEPTION_IF_NULL(graph_input_format);
  MS_EXCEPTION_IF_NULL(graph_input_type);
  MS_EXCEPTION_IF_NULL(need_update);
  // check graph input format and dtype use inner ops.
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  if (graph_input_format->size() != input_num || graph_input_type->size() != input_num ||
      need_update->size() != input_num) {
    MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString()
                      << "], [" << graph_input_format->size() << "] != [" << input_num << "]";
  }
  auto &node_users = mng->node_users();
  for (size_t i = 0; i < input_num; ++i) {
    auto &input = input_list[i];
    auto iter = node_users.find(input);
    if (iter == node_users.end() || iter->second.empty()) {
      // Input is unused inside the subgraph; nothing to validate.
      continue;
    }
    for (auto &node_user : iter->second) {
      if (node_user.first->kernel_info() == nullptr || !node_user.first->kernel_info()->has_build_info()) {
        // maybe not a real kernel.
        continue;
      }
      // node_user.second is 1-based (input 0 is the primitive), hence the -1.
      auto user_format = AnfAlgo::GetInputFormat(node_user.first, IntToSize(node_user.second - 1));
      if (user_format != (*graph_input_format)[i]) {
        MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString() << " of ["
                        << kernel_node->DebugString()
                        << "] selected different format. we use default: " << default_format;
        (*graph_input_format)[i] = default_format;
        (*need_update)[i] = true;
      }

      // Parameters keep their dtype; for other inputs, skip when the user's
      // selected dtype already matches the recorded one.
      if (kernel_node->input(i + 1)->isa<Parameter>() ||
          AnfAlgo::GetInputDeviceDataType(node_user.first, IntToSize(node_user.second - 1)) == (*graph_input_type)[i]) {
        continue;
      }

      // Dtype conflict: fall back to the input's inferred dtype.
      TypeId default_dtype = AnfAlgo::GetOutputInferDataType(input, 0);
      MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString() << " of ["
                      << kernel_node->DebugString()
                      << "] selected different dtype. we use default: " << TypeIdLabel(default_dtype);
      (*graph_input_type)[i] = default_dtype;
      (*need_update)[i] = true;
    }
  }
}
394
UpdateFormatsAndDtypes(const CNodePtr & kernel_node,const std::vector<AnfNodePtr> & node_list,const std::vector<AnfNodePtr> & input_list,const std::vector<bool> & need_update,const std::vector<std::string> & graph_input_format,const std::vector<TypeId> & graph_input_type)395 void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &node_list,
396 const std::vector<AnfNodePtr> &input_list, const std::vector<bool> &need_update,
397 const std::vector<std::string> &graph_input_format,
398 const std::vector<TypeId> &graph_input_type) {
399 MS_EXCEPTION_IF_NULL(kernel_node);
400 // update graph input format and dtype use inner ops.
401 size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
402 if (graph_input_format.size() != input_num || graph_input_type.size() != input_num ||
403 need_update.size() != input_num) {
404 MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString()
405 << "], [" << graph_input_format.size() << "] != [" << input_num << "]";
406 }
407 for (size_t i = 0; i < input_num; ++i) {
408 if (!need_update[i]) {
409 continue;
410 }
411
412 MS_LOG(DEBUG) << "Update input format: " << i << " of: [" << kernel_node->DebugString()
413 << "] to: " << graph_input_format[i];
414 MS_LOG(DEBUG) << "Update input dtype: " << i << " of: [" << kernel_node->DebugString()
415 << "] to: " << TypeIdLabel(graph_input_type[i]);
416 kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
417 std::vector<std::string> outputs_format = {graph_input_format[i]};
418 std::vector<TypeId> outputs_device_type = {graph_input_type[i]};
419 builder.SetOutputsFormat(outputs_format);
420 builder.SetOutputsDeviceType(outputs_device_type);
421 AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
422 }
423
424 ResetKernelBuildInfo(kernel_node);
425 // select nodes in subgraph again.
426 for (size_t i = 0; i < node_list.size(); ++i) {
427 auto anf_node = node_list[i];
428 MS_EXCEPTION_IF_NULL(anf_node);
429 auto cnode = anf_node->cast<CNodePtr>();
430 MS_EXCEPTION_IF_NULL(cnode);
431 kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
432 size_t cnode_input_num = AnfAlgo::GetInputTensorNum(cnode);
433 for (size_t j = 0; j < cnode_input_num; ++j) {
434 auto input_node = cnode->input(j + 1);
435 MS_EXCEPTION_IF_NULL(input_node);
436 if (!IsValueNode<tensor::Tensor>(input_node)) {
437 continue;
438 }
439 // reset format and dtype of const tensor.
440 builder.SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
441 builder.SetOutputsDeviceType(std::vector<TypeId>{kTypeUnknown});
442 AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_node.get());
443 }
444 SelectKernelInfo(node_list[i]->cast<CNodePtr>(), KernelType::AKG_KERNEL);
445 }
446 }
447
SetGraphKernelInfo(const CNodePtr & kernel_node,const std::vector<std::pair<AnfNodePtr,size_t>> & output_index,const std::vector<std::string> & graph_input_format,const std::vector<TypeId> & graph_input_type)448 void SetGraphKernelInfo(const CNodePtr &kernel_node, const std::vector<std::pair<AnfNodePtr, size_t>> &output_index,
449 const std::vector<std::string> &graph_input_format,
450 const std::vector<TypeId> &graph_input_type) {
451 MS_EXCEPTION_IF_NULL(kernel_node);
452 std::vector<std::string> graph_output_format;
453 std::vector<TypeId> graph_output_type;
454 for (size_t i = 0; i < output_index.size(); ++i) {
455 auto const &output = output_index[i];
456 graph_output_format.push_back(AnfAlgo::GetOutputFormat(output.first, output.second));
457 TypeId output_type(kTypeUnknown);
458 if (output.first->isa<CNode>()) {
459 output_type = AnfAlgo::GetCNodeOutputPrecision(output.first);
460 }
461 if (output_type == kTypeUnknown) {
462 output_type = AnfAlgo::GetOutputDeviceDataType(output.first, output.second);
463 }
464 graph_output_type.push_back(output_type);
465 }
466
467 kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
468 graph_info_builder.SetInputsFormat(graph_input_format);
469 graph_info_builder.SetInputsDeviceType(graph_input_type);
470 graph_info_builder.SetOutputsFormat(graph_output_format);
471 graph_info_builder.SetOutputsDeviceType(graph_output_type);
472 graph_info_builder.SetProcessor(kernel::Processor::AICORE);
473 graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
474 graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
475 auto graph_selected_info = graph_info_builder.Build();
476 MS_EXCEPTION_IF_NULL(graph_selected_info);
477 AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get());
478 SetTensorDeviceInfo(kernel_node);
479 }
480
SelectGraphKernelInfo(const CNodePtr & kernel_node,const FuncGraphPtr & func_graph)481 void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph) {
482 MS_EXCEPTION_IF_NULL(kernel_node);
483 MS_EXCEPTION_IF_NULL(func_graph);
484
485 // collect input info of funcgraph
486 std::vector<AnfNodePtr> node_list;
487 std::vector<AnfNodePtr> input_list;
488 std::vector<AnfNodePtr> output_list;
489 kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
490 if (input_list.size() != kernel_node->inputs().size() - 1) {
491 MS_EXCEPTION(ArgumentError) << "Input num of funcgraph[" << func_graph->ToString() << "] not equal input of cnode["
492 << kernel_node->DebugString() << "], [%" << input_list.size() << "] != ["
493 << kernel_node->inputs().size() << "]";
494 }
495
496 std::string default_format;
497 bool use_same_format = true;
498 GetDefaultFormat(kernel_node, &default_format, &use_same_format);
499 MS_LOG(DEBUG) << "GraphKernel[" << func_graph->ToString() << "] use same input format[" << default_format
500 << "] for ParameterWeight.";
501
502 std::vector<std::string> graph_input_format;
503 std::vector<TypeId> graph_input_type;
504 UpdateInputsKernelInfo(kernel_node, input_list, default_format, use_same_format, &graph_input_format,
505 &graph_input_type);
506
507 auto mng = func_graph->manager();
508 if (mng == nullptr) {
509 mng = Manage(func_graph, true);
510 }
511 UpdateEquivFormat(node_list, func_graph, mng);
512 node_list.clear();
513 input_list.clear();
514 output_list.clear();
515 kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
516
517 // update graph input format and dtype use inner ops.
518 std::vector<bool> need_update(AnfAlgo::GetInputTensorNum(kernel_node), false);
519 CheckFormatsAndDtypes(kernel_node, input_list, mng, default_format, &graph_input_format, &graph_input_type,
520 &need_update);
521 UpdateFormatsAndDtypes(kernel_node, node_list, input_list, need_update, graph_input_format, graph_input_type);
522
523 // set fix_precision for kernel when the me prim has fix_precision attr
524 UpdateKernelInfo(node_list);
525
526 auto output_index = kernel::GetOutputIndex(node_list, input_list, output_list);
527 SetGraphKernelInfo(kernel_node, output_index, graph_input_format, graph_input_type);
528 }
529 } // namespace ascend
530 } // namespace device
531 } // namespace mindspore
532