/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/ascend/ascend_helper.h"
#include <set>
#include "common/trans.h"
#include "utils/ms_utils.h"
#include "utils/check_convert_utils.h"
#include "backend/optimizer/common/helper.h"
#include "utils/utils.h"
#include "runtime/device/kernel_info.h"
#include "backend/kernel_compiler/oplib/oplib.h"
#include "backend/kernel_compiler/common_utils.h"
#include "base/core_ops.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
#include "utils/ms_context.h"
#include "utils/trace_base.h"
namespace mindspore {
namespace opt {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
namespace {
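// Returns true if a TransData node is needed to convert between `format` and the
// default format for a tensor with the given origin shape.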
bool NeedInsertTransData(const std::vector<size_t> &origin_shape, const std::string &format) {
  bool shape_check = origin_shape.size() > 1 || (origin_shape.size() == 1 && origin_shape[0] % kCubeSize != 0);
  return kCommonFormatSet.find(format) == kCommonFormatSet.end() && (shape_check || format == kOpFormat_ND_RNN_BIAS);
}

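// Creates a Reshape node that reshapes `input_node` to `dst_shape` and selects a kernel for it.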
AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
                             const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
  std::vector<AnfNodePtr> trans_inputs;
  auto prim = std::make_shared<Primitive>(prim::kPrimReshape->name());
  trans_inputs.emplace_back(NewValueNode(prim));
  trans_inputs.emplace_back(input_node);
  auto reshape = func_graph->NewCNode(trans_inputs);
  MS_EXCEPTION_IF_NULL(reshape);
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {dst_shape}, reshape.get());
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
  AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(dst_shape), reshape);
  reshape->set_scope(input_node->scope());
  kernel_select->SelectKernel(reshape);
  return reshape;
}

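// Sets the src_format/dst_format attrs on a TransData or TransDataRNN node,
// mapping the default format to NCHW (TransData) or ND (TransDataRNN).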
void SetTransNodeAttr(const CNodePtr &trans_node) {
  MS_EXCEPTION_IF_NULL(trans_node);
  auto trans_opname = AnfAlgo::GetCNodeName(trans_node);
  if (trans_opname == kTransDataOpName || trans_opname == kTransDataRNNOpName) {
    std::string input_format = AnfAlgo::GetInputFormat(trans_node, 0);
    std::string output_format = AnfAlgo::GetOutputFormat(trans_node, 0);
    if (input_format == kOpFormat_DEFAULT) {
      input_format = AnfAlgo::GetCNodeName(trans_node) == kTransDataOpName ? kOpFormat_NCHW : kOpFormat_ND;
    }
    if (output_format == kOpFormat_DEFAULT) {
      output_format = AnfAlgo::GetCNodeName(trans_node) == kTransDataOpName ? kOpFormat_NCHW : kOpFormat_ND;
    }
    AnfAlgo::SetNodeAttr(kAttrSrcFormat, MakeValue(input_format), trans_node);
    AnfAlgo::SetNodeAttr(kAttrDstFormat, MakeValue(output_format), trans_node);
  }
}

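// Refreshes the inferred output shape of `node` to 2-D when the inserted trans node is a Reshape
// and the real input node is BasicLSTMCellWeightGrad.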
void ReFreshInferShape(const AnfNodePtr &trans_node, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(trans_node);
  MS_EXCEPTION_IF_NULL(node);
  auto real_input_node = AnfAlgo::VisitKernelWithReturnType(node, 0).first;
  if (!real_input_node->isa<CNode>()) {
    return;
  }
  auto op_name = AnfAlgo::GetCNodeName(real_input_node);
  if (op_name == kBasicLSTMCellWeightGradOpName && AnfAlgo::GetCNodeName(trans_node) == prim::kPrimReshape->name()) {
    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(trans_node, 0);
    auto type = AnfAlgo::GetPrevNodeOutputInferDataType(trans_node, 0);
    AnfAlgo::SetOutputInferTypeAndShape({type}, {{shape[0], shape[1]}}, node.get());
  }
}

void SetGroupAttr(const ParameterPtr &param, const AnfNodePtr &out_trans, const AnfNodePtr &in_trans,
                  const std::string &dest_format) {
  MS_EXCEPTION_IF_NULL(param);
  auto fz_group = param->fracz_group();
  // In the scenario of gradient freezing or infer-while-training, the parameters are already set with
  // fracz_group in the first graph, so the inserted TransData pair converts from FracZ-with-group (param)
  // to default, and from default to FracZ-without-group (cnode, such as Conv2D, Opt). These paired TransData
  // nodes are not set with the groups attr and cannot be eliminated in EliminateRedundantOp. To solve this,
  // set the groups and fracz_group attrs here for these paired TransData nodes.
  if (fz_group > 1) {
    AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(fz_group), out_trans);
    AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(fz_group), out_trans);
    if (dest_format == kOpFormat_FRAC_Z) {
      AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(fz_group), in_trans);
      AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(fz_group), in_trans);
    }
  }
}

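// Returns the input of `node` at `index`, first inserting output trans ops for value-node or
// parameter inputs and then a TransData when the expected input format requires it.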
AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index,
                                const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(func_graph);
  auto input_node = AnfAlgo::GetInputNode(node, index);
  if (HasAbstractMonad(input_node)) {
    // No transfer for monad inputs.
    return input_node;
  }
  auto node_with_index = AnfAlgo::VisitKernel(input_node, 0);
  MS_EXCEPTION_IF_NULL(node_with_index.first);
  auto real_input = node_with_index.first;
  if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) {
    input_node = InsertTransOpForOutput(func_graph, input_node, input_node, kernel_select);
    MS_EXCEPTION_IF_NULL(input_node);
    AnfAlgo::SetNodeInput(node, input_node, index);
  }
  std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index);
  std::string dest_format = AnfAlgo::GetInputFormat(node, index);
  if (NeedInsertTransData(origin_shape, dest_format)) {
    MS_LOG(DEBUG) << node->DebugString() << " insert transdata " << AnfAlgo::GetInputFormat(node, index)
                  << " to default format, index: " << index;
    auto transdata = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
    if (real_input->isa<Parameter>()) {
      SetGroupAttr(real_input->cast<ParameterPtr>(), input_node, transdata, dest_format);
    }
    return transdata;
  }
  return input_node;
}

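// Inserts a TransData after `node` when its single output format requires it.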
AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                        const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(func_graph);
  std::string output_format = AnfAlgo::GetOutputFormat(node, 0);
  std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, 0);
  if (output_format == kOpFormat_NC1KHKWHWC0) {
    MS_LOG(EXCEPTION) << "Got the hw format " << output_format << " when inserting the transdata node "
                      << node->DebugString() << " trace: " << trace::DumpSourceLines(node);
  }
  if (NeedInsertTransData(origin_shape, output_format)) {
    MS_LOG(DEBUG) << "Inserted transdata " << output_format << " to default, index: 0";
    return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false);
  }
  return node;
}

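// Inserts TransData nodes for each output of a multi-output node and packs the results into a
// MakeTuple; UpdateState users of the original node are redirected to `node`.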
AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node,
                                          const AnfNodePtr &node, const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto update_states = AnfAlgo::GetUpdateStateUsers(manager, orig_node);
  for (auto &update_state : update_states) {
    manager->SetEdge(update_state.first, update_state.second, node);
  }
  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  size_t out_num = AnfAlgo::GetOutputTensorNum(node);
  for (size_t output_idx = 0; output_idx < out_num; ++output_idx) {
    std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
    if (output_format == kOpFormat_NC1KHKWHWC0) {
      MS_LOG(EXCEPTION) << "Got the special format " << output_format << " when inserting the transdata node "
                        << node->DebugString() << " trace: " << trace::DumpSourceLines(node);
    }
    auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
    std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
    if (NeedInsertTransData(origin_shape, output_format)) {
      auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false);
      if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, output_idx)) {
        kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0);
      }
      make_tuple_inputs.push_back(trans_op);
    } else {
      // No need to insert a trans op.
      make_tuple_inputs.push_back(tuple_getitem);
    }
  }
  AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
  return make_tuple;
}
}  // namespace
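// Inserts a TransData (or TransDataRNN) node, plus a Reshape when padding is needed, either for the
// input of `node` at insert_index (is_insert_input) or for its output.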
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                 const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
  AnfNodePtr trans_node = nullptr;
  CNodePtr trans_data = nullptr;
  MS_EXCEPTION_IF_NULL(node);
  // Init
  std::string default_format = kOpFormat_DEFAULT;
  AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast<CNodePtr>(), insert_index) : node;
  std::string input_format = is_insert_input ? default_format : AnfAlgo::GetOutputFormat(node, insert_index);
  std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : default_format;
  std::string padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index)
                                             : AnfAlgo::GetOutputReshapeType(node, insert_index);
  auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index)
                                              : AnfAlgo::GetOutputInferShape(input_node, insert_index);
  std::string spec_format = is_insert_input ? dst_format : input_format;
  bool need_padding = trans::IsNeedPadding(spec_format, input_node_out_shape.size());
  std::string trans_opname = (spec_format == kOpFormat_FRACTAL_ZN_RNN || spec_format == kOpFormat_ND_RNN_BIAS)
                               ? prim::kPrimTransDataRNN->name()
                               : prim::kPrimTransData->name();
  if (!need_padding) {
    // No padding is needed, so insert a transdata only.
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, trans_opname);
    trans_node = trans_data;
  } else if (is_insert_input) {
    // Padding is needed and the transdata is inserted for an input:
    // reshape[padding shape] -> transdata[padding shape] -> node
    auto padding_shape = trans::PaddingShape(input_node_out_shape, AnfAlgo::GetInputFormat(node, insert_index),
                                             AnfAlgo::GetInputReshapeType(node, insert_index));
    auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
    trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, trans_opname);
    trans_node = trans_data;
    trans_data->set_abstract(input_node->abstract());
  } else {
    // Padding is needed and the transdata is inserted for an output:
    // node -> transdata[padding shape] -> reshape[ori_shape]
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, trans_opname);
    auto reshape_node = CreateReshapeNode(func_graph, trans_data, kernel_select, input_node_out_shape);
    trans_node = reshape_node;
  }
  if (trans_opname == prim::kPrimTransDataRNN->name()) {
    AnfAlgo::CopyNodeAttr(kAttrHiddenSize, node, trans_data);
    AnfAlgo::CopyNodeAttr(kAttrInputSize, node, trans_data);
  }
  // Refresh the transdata's kernel build info to the origin format & dst format.
  RefreshKernelBuildInfo(input_format, dst_format, trans_data, padding_axis);
  if (!is_insert_input) {
    ReFreshInferShape(trans_node, node);
  }
  return trans_node;
}

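// Rewrites the kernel build info of `trans_data` with the given formats and reshape type, optionally
// its device type, then refreshes its src/dst format attrs.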
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
                            const AnfNodePtr &trans_data, const std::string &reshape_type, const TypeId &type_id) {
  MS_EXCEPTION_IF_NULL(trans_data);
  auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data);
  MS_EXCEPTION_IF_NULL(ori_build_info);
  auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(ori_build_info);
  MS_EXCEPTION_IF_NULL(builder);
  builder->SetInputsFormat({input_format});
  builder->SetInputsReshapeType({reshape_type});
  builder->SetOutputsReshapeType({reshape_type});
  builder->SetOutputsFormat({output_format});
  if (type_id != kTypeUnknown) {
    builder->SetOutputsDeviceType({type_id});
    builder->SetInputsDeviceType({type_id});
  }
  AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get());
  SetTransNodeAttr(trans_data->cast<CNodePtr>());
}

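// Creates a new trans op node (TransData, TransDataRNN or Transpose) on `input`, padding its output
// shape when required and propagating dynamic-shape information.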
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
                        const bool need_padding, const std::string &op_name, const std::vector<int64_t> &perm) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(input);
  MS_EXCEPTION_IF_NULL(kernel_select);
  CNodePtr trans_node = func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(op_name)), input});
  MS_EXCEPTION_IF_NULL(trans_node);
  auto infer_type = AnfAlgo::GetOutputInferDataType(input, 0);

  auto out_shape_base = AnfAlgo::GetOutputDetailShape(input, 0);
  MS_EXCEPTION_IF_NULL(out_shape_base);
  ShapeVector out_shape;
  ShapeVector out_shape_min;
  ShapeVector out_shape_max;
  bool is_dynamic_shape = false;
  if (out_shape_base->isa<abstract::Shape>()) {
    auto out_shape_ptr = out_shape_base->cast<abstract::ShapePtr>();
    MS_EXCEPTION_IF_NULL(out_shape_ptr);
    out_shape = out_shape_ptr->shape();
    if (out_shape_ptr->IsDynamic()) {
      out_shape_min = out_shape_ptr->min_shape();
      out_shape_max = out_shape_ptr->max_shape();
      is_dynamic_shape = true;
    }
  }

  if (need_padding) {
    // If padding is needed, set the trans node's output shape to the padded shape.
    auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);

    abstract::ShapePtr pad_shape_ptr;
    ShapeVector pad_shape = trans::PaddingShape(out_shape, AnfAlgo::GetOutputFormat(input, 0), padding_axis);
    if (is_dynamic_shape) {
      ShapeVector pad_shape_min = trans::PaddingShape(out_shape_min, AnfAlgo::GetOutputFormat(input, 0), padding_axis);
      ShapeVector pad_shape_max = trans::PaddingShape(out_shape_max, AnfAlgo::GetOutputFormat(input, 0), padding_axis);
      pad_shape_ptr = std::make_shared<abstract::Shape>(pad_shape, pad_shape_min, pad_shape_max);
    } else {
      pad_shape_ptr = std::make_shared<abstract::Shape>(pad_shape);
    }
    AnfAlgo::SetOutputTypeAndDetailShape({infer_type}, {pad_shape_ptr}, trans_node.get());
  } else {
    AnfAlgo::SetOutputTypeAndDetailShape({infer_type}, {out_shape_base}, trans_node.get());
  }
  // Special handling for UT: create the kernel info if it is missing.
  if (trans_node->kernel_info() == nullptr) {
    auto kernel_info = std::make_shared<device::KernelInfo>();
    trans_node->set_kernel_info(kernel_info);
  }
  if (op_name == prim::kPrimTranspose->name()) {
    AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(perm), trans_node);
  }
  if (is_dynamic_shape) {
    AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), trans_node);
    AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), trans_node);
    AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), trans_node);
  }
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node);
  AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), trans_node);
  trans_node->set_scope(input->scope());
  kernel_select->SelectKernel(trans_node);
  return trans_node;
}

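// Creates a Cast node on `input` with the given format and device types, restoring the original
// inferred type and shape on its output.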
CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
                              const TypeId &input_type, const TypeId &output_type,
                              const abstract::BaseShapePtr &origin_shape, const TypeId &origin_type,
                              const std::string &reshape_type) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(origin_shape);
  std::string input_format = format;
  std::string output_format = format;
  CNodePtr cast = func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name())), input});
  MS_EXCEPTION_IF_NULL(cast);
  // Set kernel build info.
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  builder.SetInputsFormat({input_format});
  builder.SetOutputsFormat({output_format});
  builder.SetInputsReshapeType({reshape_type});
  builder.SetOutputsReshapeType({reshape_type});
  builder.SetInputsDeviceType({input_type});
  builder.SetOutputsDeviceType({output_type});
  builder.SetFusionType(kernel::FusionType::OPAQUE);
  builder.SetProcessor(kernel::Processor::AICORE);
  if (kernel::OpLib::FindOp(prim::kPrimCast->name(), kernel::kTBE) != nullptr) {
    builder.SetKernelType(KernelType::TBE_KERNEL);
  } else {
    builder.SetKernelType(KernelType::AKG_KERNEL);
  }
  // If the kernel info is null, this function is running in a unit test.
  if (cast->kernel_info() == nullptr) {
    auto kernel_info = std::make_shared<device::KernelInfo>();
    cast->set_kernel_info(kernel_info);
  }
  if (origin_shape->IsDynamic()) {
    AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), cast);
    AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), cast);
    AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), cast);
  }
  AnfAlgo::SetNodeAttr("dst_type", TypeIdToType(origin_type), cast);
  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());
  AnfAlgo::SetOutputTypeAndDetailShape({origin_type}, {origin_shape}, cast.get());
  AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
  AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), cast);
  return cast;
}

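// Inserts trans ops for the output(s) of `node`, dispatching to the single- or multiple-output helper
// and keeping the kernel graph's internal-output records up to date.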
AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node, const AnfNodePtr &node,
                                  const KernelSelectPtr &kernel_select) {
  size_t outputs_num = AnfAlgo::GetOutputTensorNum(node);
  if (outputs_num == 0) {
    return node;
  }
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  // Single output
  if (outputs_num == 1 && (!AnfAlgo::IsTupleOutput(node))) {
    auto new_node = InsertTransOpForSingleOutput(func_graph, node, kernel_select);
    if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, 0)) {
      kernel_graph->ReplaceInternalOutput(node, new_node);
    }
    return new_node;
  }
  // Multiple output
  return InsertTransOpForMultipleOutput(func_graph, orig_node, node, kernel_select);
}

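// Rebuilds `node` as a new CNode whose non-monad inputs have trans ops inserted when needed.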
AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                 const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
  size_t in_num = AnfAlgo::GetInputNum(cnode);  // include monads.
  for (size_t input_index = 0; input_index < in_num; ++input_index) {
    // Monad inputs are returned unchanged by GetTransInputNodePtr().
    AnfNodePtr input_node = GetTransInputNodePtr(func_graph, cnode, input_index, kernel_select);
    MS_EXCEPTION_IF_NULL(input_node);
    new_inputs.push_back(input_node);
  }
  CNodePtr new_cnode = nullptr;
  // The cnode's inputs changed, so make a new cnode to keep it distinct from the original one.
  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  if (kernel_graph == nullptr) {
    new_cnode = std::make_shared<CNode>(*cnode);
  } else {
    new_cnode = kernel_graph->NewCNode(cnode);
  }
  MS_EXCEPTION_IF_NULL(new_cnode);
  new_cnode->set_inputs(new_inputs);
  return new_cnode;
}

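// Rebuilds `cnode` as a new CNode, inserting Cast nodes for inputs whose origin type differs from
// the expected input device type; monad inputs are kept as-is.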
CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
  size_t in_num = AnfAlgo::GetInputNum(cnode);  // include monads.
  for (size_t input_index = 0; input_index < in_num; ++input_index) {
    auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
    MS_EXCEPTION_IF_NULL(cur_input);
    if (HasAbstractMonad(cur_input)) {
      // No cast for monad inputs.
      new_inputs.push_back(cur_input);
      continue;
    }
    auto prev_node = AnfAlgo::GetPrevNodeOutput(cnode, input_index);
    const auto infer_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second);
    TypeId origin_type(kTypeUnknown);

    auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(cur_input, 0);
    auto real_input_node = kernel_with_index.first;
    MS_EXCEPTION_IF_NULL(real_input_node);
    if (kernel::IsWeightBoundary(real_input_node)) {
      // weight
      origin_type = AnfAlgo::GetPrevNodeOutputPrecision(cnode, input_index);
      if (origin_type == kTypeUnknown) {
        origin_type = AnfAlgo::GetOutputDeviceDataType(prev_node.first, prev_node.second);
      }
    } else {
      // feature map
      origin_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second);
    }
    const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
    const abstract::BaseShapePtr origin_shape = AnfAlgo::GetOutputDetailShape(prev_node.first, prev_node.second);
    // In graph kernel, the parameter is checked here; the eliminate pass will not eliminate this case,
    // so we simply do not insert the unused cast.
    if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) {
      auto cast =
        AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, infer_type);
      MS_EXCEPTION_IF_NULL(cast);
      cast->set_scope(cnode->scope());
      AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast);
      new_inputs.push_back(cast);
    } else {
      new_inputs.push_back(cur_input);
    }
  }
  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  CNodePtr new_node = nullptr;
  if (kernel_graph == nullptr) {
    new_node = std::make_shared<CNode>(*cnode);
  } else {
    new_node = kernel_graph->NewCNode(cnode);
  }
  MS_EXCEPTION_IF_NULL(new_node);
  new_node->set_inputs(new_inputs);
  return new_node;
}

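// Creates a TensorMove node that takes `node` as its input and inherits its abstract and scope.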
AnfNodePtr CreateTensorMoveOp(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto prim = std::make_shared<Primitive>(kTensorMoveOpName);
  std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(prim), node};
  auto new_node = graph->NewCNode(new_node_inputs);
  MS_EXCEPTION_IF_NULL(new_node);
  new_node->set_abstract(node->abstract());
  new_node->set_scope(node->scope());
  AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), new_node);
  return new_node;
}
}  // namespace opt
}  // namespace mindspore