/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/ascend/ascend_helper.h"
#include <set>
#include "common/trans.h"
#include "utils/ms_utils.h"
#include "utils/check_convert_utils.h"
#include "backend/optimizer/common/helper.h"
#include "utils/utils.h"
#include "runtime/device/kernel_info.h"
#include "backend/kernel_compiler/oplib/oplib.h"
#include "backend/kernel_compiler/common_utils.h"
#include "base/core_ops.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
#include "utils/ms_context.h"
#include "utils/trace_base.h"

namespace mindspore {
namespace opt {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
namespace {
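// A TransData node is needed when the format is not one of the common (default-compatible) formats
// and either the shape cannot be trivially reinterpreted (rank > 1, or a 1-D shape whose length is
// not a multiple of the cube size) or the format is ND_RNN_BIAS, which always needs conversion.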
bool NeedInsertTransData(const std::vector<size_t> &origin_shape, const std::string &format) {
  bool shape_check = origin_shape.size() > 1 || (origin_shape.size() == 1 && origin_shape[0] % kCubeSize != 0);
  return kCommonFormatSet.find(format) == kCommonFormatSet.end() && (shape_check || format == kOpFormat_ND_RNN_BIAS);
}

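// Creates a Reshape cnode that reshapes input_node to dst_shape, marks it visited, and selects its kernel.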
AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
                             const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
  std::vector<AnfNodePtr> trans_inputs;
  auto prim = std::make_shared<Primitive>(prim::kPrimReshape->name());
  trans_inputs.emplace_back(NewValueNode(prim));
  trans_inputs.emplace_back(input_node);
  auto reshape = func_graph->NewCNode(trans_inputs);
  MS_EXCEPTION_IF_NULL(reshape);
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {dst_shape}, reshape.get());
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
  AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(dst_shape), reshape);
  reshape->set_scope(input_node->scope());
  kernel_select->SelectKernel(reshape);
  return reshape;
}

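// For TransData/TransDataRNN nodes, records the source and destination formats as node attributes,
// mapping the default format to NCHW (TransData) or ND (TransDataRNN).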
void SetTransNodeAttr(const CNodePtr &trans_node) {
  MS_EXCEPTION_IF_NULL(trans_node);
  auto trans_opname = AnfAlgo::GetCNodeName(trans_node);
  if (trans_opname == kTransDataOpName || trans_opname == kTransDataRNNOpName) {
    std::string input_format = AnfAlgo::GetInputFormat(trans_node, 0);
    std::string output_format = AnfAlgo::GetOutputFormat(trans_node, 0);
    if (input_format == kOpFormat_DEFAULT) {
      input_format = trans_opname == kTransDataOpName ? kOpFormat_NCHW : kOpFormat_ND;
    }
    if (output_format == kOpFormat_DEFAULT) {
      output_format = trans_opname == kTransDataOpName ? kOpFormat_NCHW : kOpFormat_ND;
    }
    AnfAlgo::SetNodeAttr(kAttrSrcFormat, MakeValue(input_format), trans_node);
    AnfAlgo::SetNodeAttr(kAttrDstFormat, MakeValue(output_format), trans_node);
  }
}

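// Special case: when node's real kernel is BasicLSTMCellWeightGrad and the inserted trans node is a
// Reshape, refresh node's inferred output shape to the first two dimensions of the reshape's input shape.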
void ReFreshInferShape(const AnfNodePtr &trans_node, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(trans_node);
  MS_EXCEPTION_IF_NULL(node);
  auto real_input_node = AnfAlgo::VisitKernelWithReturnType(node, 0).first;
  if (!real_input_node->isa<CNode>()) {
    return;
  }
  auto op_name = AnfAlgo::GetCNodeName(real_input_node);
  if (op_name == kBasicLSTMCellWeightGradOpName && AnfAlgo::GetCNodeName(trans_node) == prim::kPrimReshape->name()) {
    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(trans_node, 0);
    auto type = AnfAlgo::GetPrevNodeOutputInferDataType(trans_node, 0);
    AnfAlgo::SetOutputInferTypeAndShape({type}, {{shape[0], shape[1]}}, node.get());
  }
}

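// Propagates the parameter's fracz_group to the paired TransData nodes so that later passes can
// recognize and eliminate them; see the comment inside for the freezing/infer-while-training scenario.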
void SetGroupAttr(const ParameterPtr &param, const AnfNodePtr &out_trans, const AnfNodePtr &in_trans,
                  const std::string &dest_format) {
  MS_EXCEPTION_IF_NULL(param);
  auto fz_group = param->fracz_group();
  // In the scenario of gradient freezing or inference while training, the parameters already have
  // fracz_group set in the first graph, so the inserted TransData nodes will convert the format from
  // FracZ-with-group (param) to default, and from default to FracZ-without-group (cnode, such as Conv2D, Opt).
  // These paired TransData nodes lack the groups attr and therefore cannot be eliminated in
  // EliminateRedundantOp. To solve this, set the groups and fracz_group attrs here for these paired
  // TransData nodes.
  if (fz_group > 1) {
    AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(fz_group), out_trans);
    AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(fz_group), out_trans);
    if (dest_format == kOpFormat_FRAC_Z) {
      AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(fz_group), in_trans);
      AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(fz_group), in_trans);
    }
  }
}

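// Returns the node to use as the index-th input of `node`, inserting a TransData when the expected
// input format differs from a common format; parameter/value inputs are first normalized through
// InsertTransOpForOutput.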
AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index,
                                const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(func_graph);
  auto input_node = AnfAlgo::GetInputNode(node, index);
  if (HasAbstractMonad(input_node)) {
    // No transfer for monad inputs.
    return input_node;
  }
  auto node_with_index = AnfAlgo::VisitKernel(input_node, 0);
  MS_EXCEPTION_IF_NULL(node_with_index.first);
  auto real_input = node_with_index.first;
  if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) {
    input_node = InsertTransOpForOutput(func_graph, input_node, input_node, kernel_select);
    MS_EXCEPTION_IF_NULL(input_node);
    AnfAlgo::SetNodeInput(node, input_node, index);
  }
  std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index);
  std::string dest_format = AnfAlgo::GetInputFormat(node, index);
  if (NeedInsertTransData(origin_shape, dest_format)) {
    MS_LOG(DEBUG) << node->DebugString() << " insert transdata " << AnfAlgo::GetInputFormat(node, index)
                  << " to default format, index: " << index;
    auto transdata = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
    if (real_input->isa<Parameter>()) {
      SetGroupAttr(real_input->cast<ParameterPtr>(), input_node, transdata, dest_format);
    }
    return transdata;
  }
  return input_node;
}

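// Inserts a TransData after a single-output node when its output format requires conversion back to
// the default format; returns the trans node, or the node itself when no conversion is needed.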
AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                        const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(func_graph);
  std::string output_format = AnfAlgo::GetOutputFormat(node, 0);
  std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, 0);
  if (output_format == kOpFormat_NC1KHKWHWC0) {
    MS_LOG(EXCEPTION) << "Got the hw format " << output_format << " when inserting the transdata node "
                      << node->DebugString() << " trace: " << trace::DumpSourceLines(node);
  }
  if (NeedInsertTransData(origin_shape, output_format)) {
    MS_LOG(DEBUG) << "Inserted transdata " << output_format << " to default, index: 0";
    return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false);
  }
  return node;
}

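// For a multi-output node, creates a TupleGetItem per output, inserts a TransData where the output
// format requires it, redirects UpdateState users to the new node, and packs the results in a MakeTuple.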
AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node,
                                          const AnfNodePtr &node, const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto update_states = AnfAlgo::GetUpdateStateUsers(manager, orig_node);
  for (auto &update_state : update_states) {
    manager->SetEdge(update_state.first, update_state.second, node);
  }
  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  size_t out_num = AnfAlgo::GetOutputTensorNum(node);
  for (size_t output_idx = 0; output_idx < out_num; ++output_idx) {
    std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
    if (output_format == kOpFormat_NC1KHKWHWC0) {
      MS_LOG(EXCEPTION) << "Got the special format " << output_format << " when inserting the transdata node "
                        << node->DebugString() << " trace: " << trace::DumpSourceLines(node);
    }
    auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
    std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
    if (NeedInsertTransData(origin_shape, output_format)) {
      auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false);
      if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, output_idx)) {
        kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0);
      }
      make_tuple_inputs.push_back(trans_op);
    } else {
      // No need to insert a trans op.
      make_tuple_inputs.push_back(tuple_getitem);
    }
  }
  AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
  return make_tuple;
}
}  // namespace
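// Inserts a TransData (or TransDataRNN for RNN formats) on the insert_index-th input (is_insert_input)
// or output of `node`. When the special format requires padding, a Reshape is paired with the
// TransData: reshape -> transdata on the input side, transdata -> reshape on the output side.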
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                 const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
  AnfNodePtr trans_node = nullptr;
  CNodePtr trans_data = nullptr;
  MS_EXCEPTION_IF_NULL(node);
  // Init
  std::string default_format = kOpFormat_DEFAULT;
  AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast<CNodePtr>(), insert_index) : node;
  std::string input_format = is_insert_input ? default_format : AnfAlgo::GetOutputFormat(node, insert_index);
  std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : default_format;
  std::string padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index)
                                             : AnfAlgo::GetOutputReshapeType(node, insert_index);
  auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index)
                                              : AnfAlgo::GetOutputInferShape(input_node, insert_index);
  std::string spec_format = is_insert_input ? dst_format : input_format;
  bool need_padding = trans::IsNeedPadding(spec_format, input_node_out_shape.size());
  std::string trans_opname = (spec_format == kOpFormat_FRACTAL_ZN_RNN || spec_format == kOpFormat_ND_RNN_BIAS)
                               ? prim::kPrimTransDataRNN->name()
                               : prim::kPrimTransData->name();
  if (!need_padding) {
    // No padding is needed; insert the TransData only.
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, trans_opname);
    trans_node = trans_data;
  } else if (is_insert_input) {
    // Padding is needed on an input, so insert:
    // reshape[padding shape] -> transdata[padding shape] -> node
    auto padding_shape = trans::PaddingShape(input_node_out_shape, AnfAlgo::GetInputFormat(node, insert_index),
                                             AnfAlgo::GetInputReshapeType(node, insert_index));
    auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
    trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, trans_opname);
    trans_node = trans_data;
    trans_data->set_abstract(input_node->abstract());
  } else {
    // Padding is needed on an output, so insert:
    // node -> transdata[padding shape] -> reshape[ori_shape]
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, trans_opname);
    auto reshape_node = CreateReshapeNode(func_graph, trans_data, kernel_select, input_node_out_shape);
    trans_node = reshape_node;
  }
  if (trans_opname == prim::kPrimTransDataRNN->name()) {
    AnfAlgo::CopyNodeAttr(kAttrHiddenSize, node, trans_data);
    AnfAlgo::CopyNodeAttr(kAttrInputSize, node, trans_data);
  }
  // Refresh the transdata's kernel build info with the original format and the dst format.
  RefreshKernelBuildInfo(input_format, dst_format, trans_data, padding_axis);
  if (!is_insert_input) {
    ReFreshInferShape(trans_node, node);
  }
  return trans_node;
}

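// Rebuilds the trans node's kernel build info with the given input/output formats and reshape type,
// optionally overriding the device type, then records src/dst format attrs via SetTransNodeAttr.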
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
                            const AnfNodePtr &trans_data, const std::string &reshape_type, const TypeId &type_id) {
  MS_EXCEPTION_IF_NULL(trans_data);
  auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data);
  MS_EXCEPTION_IF_NULL(ori_build_info);
  auto builder = std::make_shared<KernelBuildInfoBuilder>(ori_build_info);
  MS_EXCEPTION_IF_NULL(builder);
  builder->SetInputsFormat({input_format});
  builder->SetInputsReshapeType({reshape_type});
  builder->SetOutputsReshapeType({reshape_type});
  builder->SetOutputsFormat({output_format});
  if (type_id != kTypeUnknown) {
    builder->SetOutputsDeviceType({type_id});
    builder->SetInputsDeviceType({type_id});
  }
  AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get());
  SetTransNodeAttr(trans_data->cast<CNodePtr>());
}

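// Creates a single-input trans op (TransData, TransDataRNN, or Transpose) on `input`. Dynamic shapes
// keep their min/max bounds, padded shapes are computed via trans::PaddingShape when requested, and
// the kernel is selected before returning.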
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
                        const bool need_padding, const std::string &op_name, const std::vector<int64_t> &perm) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(input);
  MS_EXCEPTION_IF_NULL(kernel_select);
  CNodePtr trans_node = func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(op_name)), input});
  MS_EXCEPTION_IF_NULL(trans_node);
  auto infer_type = AnfAlgo::GetOutputInferDataType(input, 0);

  auto out_shape_base = AnfAlgo::GetOutputDetailShape(input, 0);
  MS_EXCEPTION_IF_NULL(out_shape_base);
  ShapeVector out_shape;
  ShapeVector out_shape_min;
  ShapeVector out_shape_max;
  bool is_dynamic_shape = false;
  if (out_shape_base->isa<abstract::Shape>()) {
    auto out_shape_ptr = out_shape_base->cast<abstract::ShapePtr>();
    MS_EXCEPTION_IF_NULL(out_shape_ptr);
    out_shape = out_shape_ptr->shape();
    if (out_shape_ptr->IsDynamic()) {
      out_shape_min = out_shape_ptr->min_shape();
      out_shape_max = out_shape_ptr->max_shape();
      is_dynamic_shape = true;
    }
  }

  if (need_padding) {
    // If padding is needed, set the trans node's shape to the padded shape.
    auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);

    abstract::ShapePtr pad_shape_ptr;
    ShapeVector pad_shape = trans::PaddingShape(out_shape, AnfAlgo::GetOutputFormat(input, 0), padding_axis);
    if (is_dynamic_shape) {
      ShapeVector pad_shape_min = trans::PaddingShape(out_shape_min, AnfAlgo::GetOutputFormat(input, 0), padding_axis);
      ShapeVector pad_shape_max = trans::PaddingShape(out_shape_max, AnfAlgo::GetOutputFormat(input, 0), padding_axis);
      pad_shape_ptr = std::make_shared<abstract::Shape>(pad_shape, pad_shape_min, pad_shape_max);
    } else {
      pad_shape_ptr = std::make_shared<abstract::Shape>(pad_shape);
    }
    AnfAlgo::SetOutputTypeAndDetailShape({infer_type}, {pad_shape_ptr}, trans_node.get());
  } else {
    AnfAlgo::SetOutputTypeAndDetailShape({infer_type}, {out_shape_base}, trans_node.get());
  }
  // Special handling for UT: the node may not carry kernel info yet.
  if (trans_node->kernel_info() == nullptr) {
    auto kernel_info = std::make_shared<device::KernelInfo>();
    trans_node->set_kernel_info(kernel_info);
  }
  if (op_name == prim::kPrimTranspose->name()) {
    AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(perm), trans_node);
  }
  if (is_dynamic_shape) {
    AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), trans_node);
    AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), trans_node);
    AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), trans_node);
  }
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node);
  AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), trans_node);
  trans_node->set_scope(input->scope());
  kernel_select->SelectKernel(trans_node);
  return trans_node;
}

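// Creates a Cast cnode from input_type to output_type in the given format, preferring a TBE kernel
// when one is registered and falling back to AKG otherwise; the node keeps the origin shape/type as
// its inferred output.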
CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
                              const TypeId &input_type, const TypeId &output_type,
                              const abstract::BaseShapePtr &origin_shape, const TypeId &origin_type,
                              const std::string &reshape_type) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(origin_shape);
  std::string input_format = format;
  std::string output_format = format;
  CNodePtr cast = func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name())), input});
  MS_EXCEPTION_IF_NULL(cast);
  // Set the kernel build info.
  KernelBuildInfoBuilder builder;
  builder.SetInputsFormat({input_format});
  builder.SetOutputsFormat({output_format});
  builder.SetInputsReshapeType({reshape_type});
  builder.SetOutputsReshapeType({reshape_type});
  builder.SetInputsDeviceType({input_type});
  builder.SetOutputsDeviceType({output_type});
  builder.SetFusionType(kernel::FusionType::OPAQUE);
  builder.SetProcessor(kernel::Processor::AICORE);
  if (kernel::OpLib::FindOp(prim::kPrimCast->name(), kernel::kTBE) != nullptr) {
    builder.SetKernelType(KernelType::TBE_KERNEL);
  } else {
    builder.SetKernelType(KernelType::AKG_KERNEL);
  }
  // If the kernel info is null, this function is running in a unit test.
  if (cast->kernel_info() == nullptr) {
    auto kernel_info = std::make_shared<device::KernelInfo>();
    cast->set_kernel_info(kernel_info);
  }
  if (origin_shape->IsDynamic()) {
    AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), cast);
    AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), cast);
    AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), cast);
  }
  AnfAlgo::SetNodeAttr("dst_type", TypeIdToType(origin_type), cast);
  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());
  AnfAlgo::SetOutputTypeAndDetailShape({origin_type}, {origin_shape}, cast.get());
  AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
  AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), cast);
  return cast;
}

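// Dispatches output trans-op insertion: single (non-tuple) outputs go through
// InsertTransOpForSingleOutput, everything else through InsertTransOpForMultipleOutput, keeping the
// kernel graph's internal-output bookkeeping in sync.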
AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node, const AnfNodePtr &node,
                                  const KernelSelectPtr &kernel_select) {
  size_t outputs_num = AnfAlgo::GetOutputTensorNum(node);
  if (outputs_num == 0) {
    return node;
  }
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  // Single output
  if (outputs_num == 1 && (!AnfAlgo::IsTupleOutput(node))) {
    auto new_node = InsertTransOpForSingleOutput(func_graph, node, kernel_select);
    if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, 0)) {
      kernel_graph->ReplaceInternalOutput(node, new_node);
    }
    return new_node;
  }
  // Multiple outputs
  return InsertTransOpForMultipleOutput(func_graph, orig_node, node, kernel_select);
}

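// Rebuilds `node` with trans ops inserted on every non-monad input whose expected format requires a
// format transformation; returns the new cnode.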
AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                 const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
  size_t in_num = AnfAlgo::GetInputNum(cnode);  // Includes monads.
  for (size_t input_index = 0; input_index < in_num; ++input_index) {
    // Monad inputs are returned unchanged by GetTransInputNodePtr().
    AnfNodePtr input_node = GetTransInputNodePtr(func_graph, cnode, input_index, kernel_select);
    MS_EXCEPTION_IF_NULL(input_node);
    new_inputs.push_back(input_node);
  }
  CNodePtr new_cnode = nullptr;
  // The cnode's inputs changed, so make a new cnode to distinguish it from the original one.
  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  if (kernel_graph == nullptr) {
    new_cnode = std::make_shared<CNode>(*cnode);
  } else {
    new_cnode = kernel_graph->NewCNode(cnode);
  }
  MS_EXCEPTION_IF_NULL(new_cnode);
  new_cnode->set_inputs(new_inputs);
  return new_cnode;
}

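// Rebuilds `cnode` with a Cast inserted on every non-monad input whose inferred (or, for weights,
// device) type differs from the input's expected device type.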
CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
  size_t in_num = AnfAlgo::GetInputNum(cnode);  // Includes monads.
  for (size_t input_index = 0; input_index < in_num; ++input_index) {
    auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
    MS_EXCEPTION_IF_NULL(cur_input);
    if (HasAbstractMonad(cur_input)) {
      // No cast for monad inputs.
      new_inputs.push_back(cur_input);
      continue;
    }
    auto prev_node = AnfAlgo::GetPrevNodeOutput(cnode, input_index);
    const auto infer_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second);
    TypeId origin_type(kTypeUnknown);

    auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(cur_input, 0);
    auto real_input_node = kernel_with_index.first;
    MS_EXCEPTION_IF_NULL(real_input_node);
    if (kernel::IsWeightBoundary(real_input_node)) {
      // Weight
      origin_type = AnfAlgo::GetPrevNodeOutputPrecision(cnode, input_index);
      if (origin_type == kTypeUnknown) {
        origin_type = AnfAlgo::GetOutputDeviceDataType(prev_node.first, prev_node.second);
      }
    } else {
      // Feature map
      origin_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second);
    }
    const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
    const abstract::BaseShapePtr origin_shape = AnfAlgo::GetOutputDetailShape(prev_node.first, prev_node.second);
    // In graph kernel, we check the parameter: the eliminate pass will not remove this case, so we
    // simply do not insert the unused cast.
    if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) {
      auto cast =
        AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, infer_type);
      MS_EXCEPTION_IF_NULL(cast);
      cast->set_scope(cnode->scope());
      AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast);
      new_inputs.push_back(cast);
    } else {
      new_inputs.push_back(cur_input);
    }
  }
  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  CNodePtr new_node = nullptr;
  if (kernel_graph == nullptr) {
    new_node = std::make_shared<CNode>(*cnode);
  } else {
    new_node = kernel_graph->NewCNode(cnode);
  }
  MS_EXCEPTION_IF_NULL(new_node);
  new_node->set_inputs(new_inputs);
  return new_node;
}

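// Wraps `node` in a TensorMove cnode that inherits its abstract and scope, presumably to materialize
// an explicit copy of the node's output.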
AnfNodePtr CreateTensorMoveOp(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto prim = std::make_shared<Primitive>(kTensorMoveOpName);
  std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(prim), node};
  auto new_node = graph->NewCNode(new_node_inputs);
  MS_EXCEPTION_IF_NULL(new_node);
  new_node->set_abstract(node->abstract());
  new_node->set_scope(node->scope());
  AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), new_node);
  return new_node;
}
}  // namespace opt
}  // namespace mindspore