• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/common/pass/insert_type_transform_op.h"
18 
19 #include <algorithm>
20 #include <memory>
21 #include <vector>
22 #include "abstract/ops/primitive_infer_map.h"
23 #include "include/backend/anf_runtime_algorithm.h"
24 #include "include/common/utils/anfalgo.h"
25 #include "include/common/utils/convert_utils.h"
26 #include "include/common/utils/utils.h"
27 #include "ir/anf.h"
28 #include "kernel/common_utils.h"
29 #include "kernel/framework_utils.h"
30 #include "ops/arithmetic_ops.h"
31 #include "ops/nn_ops.h"
32 #include "ops/sequence_ops.h"
33 #include "ops/framework_ops.h"
34 #include "ops/op_def.h"
35 #include "ops/op_utils.h"
36 
37 namespace mindspore {
38 namespace opt {
39 namespace {
IsNewKernel(const AnfNodePtr & node)40 bool IsNewKernel(const AnfNodePtr &node) {
41   MS_EXCEPTION_IF_NULL(node);
42   if (!node->isa<CNode>() || common::AnfAlgo::IsCallNode(node)) {
43     return false;
44   }
45   const auto &primitive = GetCNodePrimitive(node);
46   MS_EXCEPTION_IF_NULL(primitive);
47   return mindspore::ops::GetOpDef(primitive->name()) != nullptr;
48 }
49 }  // namespace
SplitTupleInputsForInsertType(const FuncGraphPtr & graph,const AnfNodePtr & tuple_input,std::vector<AnfNodePtr> * plant_inputs)50 int64_t SplitTupleInputsForInsertType(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
51                                       std::vector<AnfNodePtr> *plant_inputs) {
52   MS_EXCEPTION_IF_NULL(graph);
53   MS_EXCEPTION_IF_NULL(tuple_input);
54   MS_EXCEPTION_IF_NULL(plant_inputs);
55 
56   if (!common::AnfAlgo::IsTupleOutput(tuple_input)) {
57     auto abs = tuple_input->abstract();
58     MS_EXCEPTION_IF_NULL(abs);
59     MS_LOG(WARNING) << "The Function only split the output type is tuple type but got" << abs->ToString();
60     return -1;
61   }
62 
63   auto input_size = AnfAlgo::GetOutputElementNum(tuple_input);
64   if (tuple_input->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(tuple_input, prim::kPrimMakeTuple)) {
65     auto make_tuple = tuple_input->cast<CNodePtr>();
66     MS_EXCEPTION_IF_NULL(make_tuple);
67     size_t tuple_input_num = common::AnfAlgo::GetInputTensorNum(make_tuple);
68     for (size_t j = 0; j < tuple_input_num; ++j) {
69       // using for graph kernel
70       auto dyn_input_node = common::AnfAlgo::GetInputNode(make_tuple, j);
71       MS_EXCEPTION_IF_NULL(dyn_input_node);
72       // Handle tuple nested scenes.
73       if (dyn_input_node->isa<CNode>() && (common::AnfAlgo::CheckPrimitiveType(dyn_input_node, prim::kPrimMakeTuple) ||
74                                            common::AnfAlgo::IsTupleOutput(dyn_input_node))) {
75         int64_t dyn_input_size = SplitTupleInputsForInsertType(graph, dyn_input_node, plant_inputs);
76         input_size += LongToSize(dyn_input_size);
77         continue;
78       }
79       (void)plant_inputs->emplace_back(dyn_input_node);
80     }
81     return SizeToLong(input_size);
82   }
83   for (size_t index = 0; index < input_size; ++index) {
84     auto dynamic_input_node = CreatTupleGetItemNode(graph, tuple_input, index);
85     MS_LOG(DEBUG) << "Create TupleGetItem node " << dynamic_input_node->fullname_with_scope() << " for tuple node "
86                   << tuple_input->fullname_with_scope();
87     // The virtual node's object types should be set.
88     SetKernelInfoForNewCNode(dynamic_input_node, false);
89     (void)plant_inputs->emplace_back(dynamic_input_node);
90   }
91   return SizeToLong(input_size);
92 }
93 
CreateNewNode(const FuncGraphPtr & func_graph,const AnfNodePtrList & input_list,const CNodePtr & origin_node)94 AnfNodePtr CreateNewNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &input_list,
95                          const CNodePtr &origin_node) {
96   MS_EXCEPTION_IF_NULL(func_graph);
97   MS_EXCEPTION_IF_NULL(origin_node);
98 
99   auto new_cnode = NewCNode(input_list, func_graph, {origin_node});
100   MS_EXCEPTION_IF_NULL(new_cnode);
101   // This pass should not have new node whose abstract differs from the original node. So set the original node's
102   // abstract.
103   new_cnode->set_abstract(origin_node->abstract());
104   new_cnode->set_scope(origin_node->scope());
105   new_cnode->set_primal_attrs(origin_node->primal_attrs());
106   new_cnode->set_attrs(origin_node->attrs());
107   auto kernel_graph = func_graph->cast<KernelGraphPtr>();
108   if (kernel_graph != nullptr) {
109     const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(origin_node);
110     if (front_node != nullptr && front_node->isa<CNode>()) {
111       const auto front_cnode = front_node->cast<CNodePtr>();
112       MS_EXCEPTION_IF_NULL(front_cnode);
113       MS_LOG(INFO) << "Add replace real kernel flag for front node:" << front_node->DebugString();
114       front_cnode->AddAttr(kAttrReplaceRealKernelInBackend, MakeValue(true));
115     }
116     kernel_graph->FrontBackendlMapUpdate(origin_node, new_cnode);
117   }
118 
119   // Inherit from origin kernel build info.
120   KernelBuildInfoPtr origin_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(origin_node);
121   MS_EXCEPTION_IF_NULL(origin_kernel_build_info);
122   auto new_kernel_builder = std::make_shared<KernelBuildInfoBuilder>(origin_kernel_build_info);
123   MS_EXCEPTION_IF_NULL(new_kernel_builder);
124 
125   auto kernel_info = std::make_shared<device::KernelInfo>();
126   MS_EXCEPTION_IF_NULL(kernel_info);
127   new_cnode->set_kernel_info(kernel_info);
128   AnfAlgo::SetSelectKernelBuildInfo(new_kernel_builder->Build(), new_cnode.get());
129 
130   // Need to reset new cnode's kernel build info because the inputs type and number could be changed after processing
131   // methods. Only reset input types.
132   auto new_prim = GetValueNode<PrimitivePtr>(new_cnode->input(kIndex0));
133   auto origin_prim = GetValueNode<PrimitivePtr>(origin_node->input(kIndex0));
134   if (IsPrimitiveEquals(new_prim, origin_prim) && !kernel::IsDynamicParamKernel(origin_prim->name()) &&
135       (origin_kernel_build_info->op_type() != kernel::OpType::SKIP)) {
136     SetKernelInfoForNewCNode(new_cnode, false);
137   } else {
138     SetKernelInfoForNewCNode(new_cnode, true);
139   }
140 
141   // If the primitive is not changed, this means only inputs are updated. So inherit output from origin node.
142   if (IsPrimitiveEquals(new_prim, origin_prim)) {
143     KernelBuildInfoPtr new_node_build_info = AnfAlgo::GetSelectKernelBuildInfo(new_cnode);
144     KernelBuildInfoPtr origin_node_build_info = AnfAlgo::GetSelectKernelBuildInfo(origin_node);
145     new_node_build_info->SetOutputsFormat(origin_node_build_info->GetAllOutputFormats());
146     new_node_build_info->SetOutputsDeviceType(origin_node_build_info->GetAllOutputDeviceTypes());
147     new_node_build_info->SetOutputsKernelObjectType(origin_node_build_info->GetAllOutputKernelObjectTypes());
148   }
149 
150   return new_cnode;
151 }
152 
CreateRealMakeTupleByMakeTuple(const FuncGraphPtr & func_graph,const CNodePtr & make_tuple_node)153 AnfNodePtr CreateRealMakeTupleByMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &make_tuple_node) {
154   MS_EXCEPTION_IF_NULL(func_graph);
155   MS_EXCEPTION_IF_NULL(make_tuple_node);
156 
157   // Create RealMakeTuple node and inherit inputs and abstract from MakeTuple node.
158   AnfNodePtrList inputs = make_tuple_node->inputs();
159   auto prim = NewValueNode(prim::kPrimRealMakeTuple);
160   MS_EXCEPTION_IF_NULL(prim);
161   if (inputs.empty()) {
162     MS_LOG(EXCEPTION) << "inputs is empty.";
163   }
164   inputs[kIndex0] = prim;
165   CNodePtr real_make_tuple = func_graph->NewCNode(inputs);
166   MS_EXCEPTION_IF_NULL(real_make_tuple);
167   real_make_tuple->set_scope(make_tuple_node->scope());
168   real_make_tuple->set_abstract(make_tuple_node->abstract());
169 
170   SetKernelInfoForNewCNode(real_make_tuple);
171 
172   // RealMakeTuple's inputs must be scalar or tensor. To avoid failing to select kernel, we must override
173   // RealMakeTuple's KernelObjectTypes to TENSOR, which is created from MakeTuple.
174   KernelBuildInfoPtr real_make_tuple_build_info = AnfAlgo::GetSelectKernelBuildInfo(real_make_tuple);
175   MS_EXCEPTION_IF_NULL(real_make_tuple_build_info);
176   auto inputs_obj_types = real_make_tuple_build_info->GetAllInputKernelObjectTypes();
177   if (!std::all_of(inputs_obj_types.begin(), inputs_obj_types.end(),
178                    [](const auto &obj_type) { return obj_type == KernelObjectType::TENSOR; }) &&
179       !std::all_of(inputs_obj_types.begin(), inputs_obj_types.end(),
180                    [](const auto &obj_type) { return obj_type == KernelObjectType::SCALAR; })) {
181     auto new_obj_types = inputs_obj_types;
182     (void)std::transform(new_obj_types.begin(), new_obj_types.end(), new_obj_types.begin(),
183                          [](const auto &) { return KernelObjectType::TENSOR; });
184     real_make_tuple_build_info->SetInputsKernelObjectType(new_obj_types);
185     MS_LOG(DEBUG) << "Override RealMakeTuple input kernel object types from " << inputs_obj_types << " "
186                   << new_obj_types;
187   }
188   return real_make_tuple;
189 }
190 
// Wraps a node whose output is TUPLE_UNFOLD into a single-input RealMakeTuple CNode that inherits the node's scope
// and abstract. The RealMakeTuple input object type is set to TUPLE_UNFOLD so the TupleUnfoldToTupleUnfold pattern
// will be matched. Its input formats/device types are replicated from input 0 once per tuple element.
// Throws if the node's abstract is missing or is not an AbstractTuple.
AnfNodePtr CreateRealMakeTupleByTupleUnfoldInput(const FuncGraphPtr &func_graph,
                                                 const AnfNodePtr &node_with_tuple_unfold_output) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node_with_tuple_unfold_output);
  // Validate the abstract up front: it is both inherited by the new node and needed for the element count.
  const auto &tuple_unfold_abstract = node_with_tuple_unfold_output->abstract();
  MS_EXCEPTION_IF_NULL(tuple_unfold_abstract);
  abstract::AbstractTuplePtr tuple_unfold_abs = tuple_unfold_abstract->cast<abstract::AbstractTuplePtr>();
  MS_EXCEPTION_IF_NULL(tuple_unfold_abs);

  auto prim = NewValueNode(prim::kPrimRealMakeTuple);
  MS_EXCEPTION_IF_NULL(prim);
  AnfNodePtrList inputs = {prim, node_with_tuple_unfold_output};
  CNodePtr real_make_tuple = func_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(real_make_tuple);
  real_make_tuple->set_scope(node_with_tuple_unfold_output->scope());
  // Inherit abstract from TupleUnfold output node.
  real_make_tuple->set_abstract(tuple_unfold_abstract);

  SetKernelInfoForNewCNode(real_make_tuple);

  // Set object type to TupleUnfold so TupleUnfoldToTupleUnfold pattern will be matched.
  KernelBuildInfoPtr build_info = AnfAlgo::GetSelectKernelBuildInfo(real_make_tuple);
  MS_EXCEPTION_IF_NULL(build_info);
  build_info->SetInputsKernelObjectType({KernelObjectType::TUPLE_UNFOLD});

  // Extend tuple_unfold inputs: one format/device type per tuple element. Use the (count, value) constructor
  // explicitly; brace-initialization would select the initializer_list constructor for element types that accept
  // a size_t.
  const size_t element_num = tuple_unfold_abs->size();
  std::vector<std::string> inputs_format(element_num, build_info->GetInputFormat(kIndex0));
  std::vector<TypeId> inputs_type(element_num, build_info->GetInputDeviceType(kIndex0));
  build_info->SetInputsFormat(inputs_format);
  build_info->SetInputsDeviceType(inputs_type);

  return real_make_tuple;
}
225 
IsBackOffOp(const CNodePtr & cnode)226 bool IsBackOffOp(const CNodePtr &cnode) {
227   std::vector<std::string> back_off_op_list = {prim::kPrimTupleToTensor->name(), prim::kPrimScalarToTensor->name(),
228                                                prim::kPrimTensorToTuple->name(), prim::kPrimTensorToScalar->name(),
229                                                prim::kPrimRealMakeTuple->name(), prim::kPrimRealTupleGetItem->name(),
230                                                prim::kPrimTupleSetItem->name()};
231   if (std::find(back_off_op_list.begin(), back_off_op_list.end(), common::AnfAlgo::GetCNodeName(cnode)) !=
232       back_off_op_list.end()) {
233     return true;
234   }
235   return false;
236 }
237 
SetBackOffFlag(const KernelBuildInfoPtr & build_info,const CNodePtr & cnode)238 void SetBackOffFlag(const KernelBuildInfoPtr &build_info, const CNodePtr &cnode) {
239   if (IsBackOffOp(cnode)) {
240     build_info->set_valid(false);
241   }
242 }
243 
SetKernelInfoForNewCNode(const CNodePtr & cnode,bool set_format_type)244 void SetKernelInfoForNewCNode(const CNodePtr &cnode, bool set_format_type) {
245   MS_EXCEPTION_IF_NULL(cnode);
246   // In some cases cnode is newly created and has no kernel info.
247   if (cnode->kernel_info() == nullptr ||
248       (!dynamic_cast<device::KernelInfo *>(cnode->kernel_info())->has_build_info())) {
249     auto kernel_info = std::make_shared<device::KernelInfo>();
250     MS_EXCEPTION_IF_NULL(kernel_info);
251     cnode->set_kernel_info(kernel_info);
252     auto builder = std::make_shared<KernelBuildInfoBuilder>();
253     MS_EXCEPTION_IF_NULL(builder);
254     AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), cnode.get());
255   }
256   KernelBuildInfoPtr build_info = AnfAlgo::GetSelectKernelBuildInfo(cnode);
257   MS_EXCEPTION_IF_NULL(build_info);
258   MS_LOG(DEBUG) << "Start setting kernel info for cnode " << cnode->DebugString() << " " << cnode->fullname_with_scope()
259                 << ",set_format_type: " << set_format_type;
260   // Set input and output object type for subsequent type matching process.
261   std::vector<KernelObjectType> input_obj_type;
262   std::vector<KernelObjectType> output_obj_type;
263   GenerateKernelObjectTypeForNewCNode(cnode, &input_obj_type, &output_obj_type);
264   build_info->SetInputsKernelObjectType(input_obj_type);
265   build_info->SetOutputsKernelObjectType(output_obj_type);
266 
267   if (set_format_type) {
268     // Set input and output format.
269     std::vector<std::string> inputs_format;
270     std::vector<TypeId> inputs_type;
271     size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
272     for (size_t input_index = 0; input_index < input_num; ++input_index) {
273       auto input_node = common::AnfAlgo::GetInputNode(cnode, input_index);
274       MS_EXCEPTION_IF_NULL(input_node);
275 
276       auto real_input = common::AnfAlgo::VisitKernelWithReturnType(input_node, kIndex0);
277       auto real_input_node = real_input.first;
278       MS_EXCEPTION_IF_NULL(real_input_node);
279       auto output_index = real_input.second;
280       if (real_input_node->kernel_info() == nullptr) {
281         (void)inputs_format.emplace_back(kOpFormat_DEFAULT);
282       } else {
283         (void)inputs_format.emplace_back(AnfAlgo::GetOutputFormat(real_input_node, output_index));
284       }
285       inputs_type.push_back(common::AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
286     }
287 
288     std::vector<std::string> outputs_format;
289     std::vector<TypeId> outputs_type;
290     size_t output_num;
291     if (output_obj_type[kIndex0] == KernelObjectType::TUPLE_UNFOLD) {
292       output_num = AnfAlgo::GetOutputElementNum(cnode);
293     } else {
294       output_num = kSizeOne;
295     }
296     for (size_t output_index = 0; output_index < output_num; ++output_index) {
297       (void)outputs_format.emplace_back(GenerateOutputFormatForNewCNode(cnode));
298       outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(cnode, output_index));
299     }
300 
301     build_info->SetInputsFormat(inputs_format);
302     build_info->SetInputsDeviceType(inputs_type);
303     build_info->SetOutputsFormat(outputs_format);
304     build_info->SetOutputsDeviceType(outputs_type);
305   }
306 
307   // The node may not be supported in the current device.
308   SetBackOffFlag(build_info, cnode);
309   MS_LOG(INFO) << "Set kernel info for cnode " << cnode->DebugString() << " " << cnode->fullname_with_scope() << " "
310                << build_info->ToString();
311 }
312 
SetKernelInfoForValueNode(const ValueNodePtr & value_node)313 void SetKernelInfoForValueNode(const ValueNodePtr &value_node) {
314   MS_EXCEPTION_IF_NULL(value_node);
315   auto kernel_info = std::make_shared<device::KernelInfo>();
316   MS_EXCEPTION_IF_NULL(kernel_info);
317   value_node->set_kernel_info(kernel_info);
318   auto builder = std::make_shared<KernelBuildInfoBuilder>();
319   MS_EXCEPTION_IF_NULL(builder);
320 
321   auto type_id = value_node->value()->type()->type_id();
322   std::vector<std::string> inputs_format = {kOpFormat_DEFAULT};
323   std::vector<TypeId> inputs_type = {type_id};
324   std::vector<std::string> outputs_format = {kOpFormat_DEFAULT};
325   std::vector<TypeId> outputs_type = {type_id};
326 
327   auto abs_type = AnfAlgo::GetAbstractObjectType(value_node->abstract());
328   std::vector<KernelObjectType> input_obj_type = {kernel::TypeIdToKernelObjectType(abs_type)};
329   std::vector<KernelObjectType> output_obj_type = {kernel::TypeIdToKernelObjectType(abs_type)};
330 
331   builder->SetInputsFormat(inputs_format);
332   builder->SetInputsDeviceType(inputs_type);
333   builder->SetOutputsFormat(outputs_format);
334   builder->SetOutputsDeviceType(outputs_type);
335   builder->SetInputsKernelObjectType(input_obj_type);
336   builder->SetOutputsKernelObjectType(output_obj_type);
337   AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), value_node.get());
338 }
339 
GenerateAbsByOpInfer(const PrimitivePtr & primitive,const AnfNodePtrList & input_list)340 abstract::AbstractBasePtr GenerateAbsByOpInfer(const PrimitivePtr &primitive, const AnfNodePtrList &input_list) {
341   MS_EXCEPTION_IF_NULL(primitive);
342   std::vector<AbstractBasePtr> input_args;
343   (void)std::for_each(input_list.begin(), input_list.end(),
344                       [&input_args](const auto &input) { (void)input_args.emplace_back(input->abstract()); });
345   auto abs_opt = abstract::TryInferAbstract(primitive, input_args);
346   if (!abs_opt.has_value()) {
347     MS_LOG(EXCEPTION) << primitive->name() << " infer is not registered.";
348   }
349   auto abs = abs_opt.value();
350   MS_EXCEPTION_IF_NULL(abs);
351   MS_LOG(DEBUG) << "Abstract for " << primitive->name() << " is " << abs->ToString();
352   return abs;
353 }
354 
GenerateOutputFormatForNewCNode(const CNodePtr & cnode)355 std::string GenerateOutputFormatForNewCNode(const CNodePtr &cnode) {
356   MS_EXCEPTION_IF_NULL(cnode);
357   if (IsPrimitiveCNode(cnode, prim::kPrimRealMakeTuple) || IsPrimitiveCNode(cnode, prim::kPrimTupleToTensor)) {
358     // We take first input format as the output format because multiple types and formats of
359     // RealMakeTuple/TupleToTensor are not supported.
360     std::string represent_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, kIndex0);
361     return represent_format;
362   }
363   return kOpFormat_DEFAULT;
364 }
365 
GenerateKernelObjectTypeForNewCNode(const CNodePtr & cnode,std::vector<KernelObjectType> * input_obj_type,std::vector<KernelObjectType> * output_obj_type)366 void GenerateKernelObjectTypeForNewCNode(const CNodePtr &cnode, std::vector<KernelObjectType> *input_obj_type,
367                                          std::vector<KernelObjectType> *output_obj_type) {
368   MS_EXCEPTION_IF_NULL(cnode);
369   MS_EXCEPTION_IF_NULL(input_obj_type);
370   MS_EXCEPTION_IF_NULL(output_obj_type);
371 
372   // Simply trasverse all inputs and get their object types.
373   // But if the input's object type is not set, this will throw exception so must pay attention when using this
374   // function.
375   auto general_input_obj_type_func = [&]() {
376     for (size_t i = kIndex1; i < cnode->size(); i++) {
377       auto input_node = cnode->input(i);
378       MS_EXCEPTION_IF_NULL(input_node);
379       // Set input kernel object type as input node's output kernel object type.
380       if (input_node->kernel_info() == nullptr ||
381           (!dynamic_cast<device::KernelInfo *>(input_node->kernel_info())->has_build_info())) {
382         auto abs_type = AnfAlgo::GetAbstractObjectType(input_node->abstract());
383         input_obj_type->push_back(kernel::TypeIdToKernelObjectType(abs_type));
384       } else {
385         input_obj_type->push_back(AnfAlgo::GetOutputKernelObjectType(input_node, kIndex0));
386       }
387     }
388   };
389 
390   if (IsPrimitiveCNode(cnode, prim::kPrimRealMakeTuple)) {
391     general_input_obj_type_func();
392     output_obj_type->push_back(KernelObjectType::TUPLE);
393   } else if (IsPrimitiveCNode(cnode, prim::kPrimTupleToTensor)) {
394     general_input_obj_type_func();
395     output_obj_type->push_back(KernelObjectType::TENSOR);
396   } else if (IsPrimitiveCNode(cnode, prim::kPrimTensorToTuple)) {
397     general_input_obj_type_func();
398     output_obj_type->push_back(KernelObjectType::TUPLE);
399   } else if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
400     // First input of TupleGetItem must be TUPLE_UNFOLD.
401     // Second is the index.
402     *input_obj_type = {KernelObjectType::TUPLE_UNFOLD, KernelObjectType::TENSOR};
403     // Get actual output type of TupleGetItem node.
404     auto abs_type = AnfAlgo::GetAbstractObjectType(cnode->abstract());
405     output_obj_type->push_back(kernel::TypeIdToKernelObjectType(abs_type));
406   } else if (IsPrimitiveCNode(cnode, prim::kPrimRealTupleGetItem)) {
407     general_input_obj_type_func();
408     // Get actual output type of RealTupleGetItem node.
409     auto abs_type = AnfAlgo::GetAbstractObjectType(cnode->abstract());
410     output_obj_type->push_back(kernel::TypeIdToKernelObjectType(abs_type));
411   } else if (IsPrimitiveCNode(cnode, prim::kPrimTensorToScalar)) {
412     general_input_obj_type_func();
413     output_obj_type->push_back(KernelObjectType::SCALAR);
414   } else {
415     // For other ops, set TENSOR as output object type by default.
416     general_input_obj_type_func();
417     output_obj_type->push_back(KernelObjectType::TENSOR);
418   }
419 
420   MS_LOG(DEBUG) << "Generate input and output object types for new node " << cnode->fullname_with_scope() << " "
421                 << cnode->DebugString() << ". Input object types: " << *input_obj_type
422                 << ". Output object types: " << *output_obj_type;
423 }
424 
// For a value node in a KernelGraph, creates a tensor input via CreateTensorInput when the value is a scalar, or a
// non-empty sequence whose elements are all scalars with the same type id as the first element.
// Returns nullptr when:
// - `func_graph` is not a KernelGraph;
// - `input` is null or not a value node;
// - the value does not qualify.
// Throws if the value itself is null, or if the first scalar element has no type.
AnfNodePtr ConstructInputByValueNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input) {
  auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph);
  if (kernel_graph == nullptr || input == nullptr || !input->isa<ValueNode>()) {
    return nullptr;
  }
  const auto &value_node = input->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  const auto &value = value_node->value();
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<Scalar>()) {
    return CreateTensorInput(kernel_graph, input);
  }
  if (!value->isa<ValueSequence>()) {
    return nullptr;
  }
  const auto &value_sequence = value->cast<ValueSequencePtr>();
  MS_EXCEPTION_IF_NULL(value_sequence);
  // Bind the element list once so every iterator below refers to the same container.
  const auto &elements = value_sequence->value();
  if (elements.empty()) {
    return nullptr;
  }
  const auto &value0 = elements.front();
  if (value0 == nullptr || !value0->isa<Scalar>()) {
    return nullptr;
  }
  const auto &scalar0 = value0->cast<ScalarPtr>();
  MS_EXCEPTION_IF_NULL(scalar0);
  const auto &type0 = scalar0->type();
  MS_EXCEPTION_IF_NULL(type0);
  const auto type_id0 = type0->type_id();
  const bool has_mismatch = std::any_of(elements.begin() + 1, elements.end(), [type_id0](const ValuePtr &element) {
    if (element == nullptr || !element->isa<Scalar>()) {
      return true;
    }
    const auto &element_type = element->cast<ScalarPtr>()->type();
    return element_type == nullptr || element_type->type_id() != type_id0;
  });
  if (has_mismatch) {
    return nullptr;
  }
  return CreateTensorInput(kernel_graph, input);
}
463 
464 // A map of kernel object type pairs to processing functions.
465 static std::map<ObjectTypePair, ProcessTypeTransformFunc> kTypePairToProcessFunc;
466 
467 // The nodes of which object types should be handled.
468 const std::vector<PrimitivePtr> need_handled_types = {prim::kPrimMakeTuple, prim::kPrimTupleGetItem};
469 
InsertTypeTransformOp(bool multigraph)470 InsertTypeTransformOp::InsertTypeTransformOp(bool multigraph)
471     : PatternProcessPass("insert_type_transform_op", multigraph) {
472   kTypePairToProcessFunc[{KernelObjectType::TUPLE_UNFOLD, KernelObjectType::TUPLE_UNFOLD}] =
473     std::bind(&InsertTypeTransformOp::ProcessTupleUnfoldToTupleUnfold, this, std::placeholders::_1,
474               std::placeholders::_2, std::placeholders::_3, std::placeholders::_4);
475   kTypePairToProcessFunc[{KernelObjectType::TUPLE_UNFOLD, KernelObjectType::TUPLE}] =
476     std::bind(&InsertTypeTransformOp::ProcessTupleUnfoldToTuple, this, std::placeholders::_1, std::placeholders::_2,
477               std::placeholders::_3, std::placeholders::_4);
478   kTypePairToProcessFunc[{KernelObjectType::TUPLE_UNFOLD, KernelObjectType::TENSOR}] =
479     std::bind(&InsertTypeTransformOp::ProcessTupleUnfoldToTensor, this, std::placeholders::_1, std::placeholders::_2,
480               std::placeholders::_3, std::placeholders::_4);
481   kTypePairToProcessFunc[{KernelObjectType::TUPLE, KernelObjectType::TUPLE_UNFOLD}] =
482     std::bind(&InsertTypeTransformOp::ProcessTupleToTupleUnfold, this, std::placeholders::_1, std::placeholders::_2,
483               std::placeholders::_3, std::placeholders::_4);
484   kTypePairToProcessFunc[{KernelObjectType::TUPLE, KernelObjectType::TENSOR}] =
485     std::bind(&InsertTypeTransformOp::ProcessTupleToTensor, this, std::placeholders::_1, std::placeholders::_2,
486               std::placeholders::_3, std::placeholders::_4);
487   kTypePairToProcessFunc[{KernelObjectType::SCALAR, KernelObjectType::TENSOR}] =
488     std::bind(&InsertTypeTransformOp::ProcessScalarToTensor, this, std::placeholders::_1, std::placeholders::_2,
489               std::placeholders::_3, std::placeholders::_4);
490   kTypePairToProcessFunc[{KernelObjectType::TENSOR, KernelObjectType::TUPLE}] =
491     std::bind(&InsertTypeTransformOp::ProcessTensorToTuple, this, std::placeholders::_1, std::placeholders::_2,
492               std::placeholders::_3, std::placeholders::_4);
493   kTypePairToProcessFunc[{KernelObjectType::TENSOR, KernelObjectType::SCALAR}] =
494     std::bind(&InsertTypeTransformOp::ProcessTensorToScalar, this, std::placeholders::_1, std::placeholders::_2,
495               std::placeholders::_3, std::placeholders::_4);
496 }
497 
Process(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const EquivPtr &) const498 const AnfNodePtr InsertTypeTransformOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
499                                                 const EquivPtr &) const {
500   MS_EXCEPTION_IF_NULL(func_graph);
501   MS_EXCEPTION_IF_NULL(node);
502   if (!node->isa<CNode>()) {
503     return nullptr;
504   }
505   if (IsPrimitiveCNode(node, prim::kPrimSwitch)) {
506     return nullptr;
507   }
508   if ((node->kernel_info() == nullptr) ||
509       (!dynamic_cast<device::KernelInfo *>(node->kernel_info())->has_build_info()) ||
510       (common::AnfAlgo::GetCNodeName(node) == "MakeTuple")) {
511     return nullptr;
512   }
513 
514   auto cnode = node->cast<CNodePtr>();
515   MS_EXCEPTION_IF_NULL(cnode);
516   AnfNodePtrList new_input_list = {common::AnfAlgo::GetCNodePrimitiveNode(cnode)};
517   // If kernel object types are matched, set this flag to true and new node will be created to replace original
518   // node.
519   bool matched = false;
520   for (size_t i = 0; i < common::AnfAlgo::GetInputNum(cnode); ++i) {
521     const auto &input_node = common::AnfAlgo::GetInputNode(cnode, i);
522     // Skip for monad input.
523     if (HasAbstractMonad(input_node) || (node->kernel_info() == nullptr) ||
524         !dynamic_cast<device::KernelInfo *>(node->kernel_info())) {
525       new_input_list.push_back(input_node);
526       continue;
527     }
528 
529     const auto &real_input_node =
530       common::AnfAlgo::VisitKernelWithReturnType(input_node, kIndex0, false, need_handled_types).first;
531     MS_EXCEPTION_IF_NULL(real_input_node);
532     if ((real_input_node->kernel_info() == nullptr) ||
533         (!dynamic_cast<device::KernelInfo *>(real_input_node->kernel_info())->has_build_info())) {
534       MS_LOG(DEBUG) << node->fullname_with_scope() << " input index:" << i
535                     << ", input node:" << real_input_node->fullname_with_scope() << " doesn't have build info.";
536       new_input_list.push_back(input_node);
537       continue;
538     }
539 
540     auto needed_input_type = AnfAlgo::GetInputKernelObjectType(node, i);
541     auto current_input_type = AnfAlgo::GetOutputKernelObjectType(real_input_node, kIndex0);
542     if ((kObjectTypeToString.count(needed_input_type) == 0) || (kObjectTypeToString.count(current_input_type) == 0)) {
543       MS_LOG(EXCEPTION) << "The current input object type " << current_input_type << " or needed input object type "
544                         << needed_input_type << " is not valid for node " << node->fullname_with_scope()
545                         << " input index:" << i << ", input node:" << real_input_node->fullname_with_scope();
546     }
547     MS_LOG(DEBUG) << "The current input object type:" << kObjectTypeToString[current_input_type]
548                   << ", needed input object type:" << kObjectTypeToString[needed_input_type]
549                   << " for node:" << node->fullname_with_scope() << " input index:" << i
550                   << ", input node:" << real_input_node->fullname_with_scope();
551 
552     ObjectTypePair type_pair = {current_input_type, needed_input_type};
553     if (kTypePairToProcessFunc.count(type_pair) != 0) {
554       MS_LOG(INFO) << "Kernel object type pair of input index " << i << " for node pair "
555                    << input_node->fullname_with_scope() << " to " << cnode->fullname_with_scope() << " is "
556                    << type_pair.to_string();
557       bool new_prim = false;
558       AnfNodePtrList processed_input_list = kTypePairToProcessFunc[type_pair](func_graph, input_node, cnode, &new_prim);
559       if (IsInputUpdated(input_node, processed_input_list)) {
560         matched = true;
561       }
562       if (new_prim) {
563         MS_LOG(DEBUG) << "New primtive is " << processed_input_list[kIndex0]->fullname_with_scope() << " to replace "
564                       << new_input_list[kIndex0]->fullname_with_scope();
565         // If new primitive is created, replace the old one, which is the first element of the input list.
566         new_input_list[kIndex0] = processed_input_list[kIndex0];
567         // Jump the primitive node the first one, and the rest is the new inputs.
568         (void)new_input_list.insert(new_input_list.end(), std::begin(processed_input_list) + kIndex1,
569                                     processed_input_list.end());
570       } else {
571         (void)new_input_list.insert(new_input_list.end(), processed_input_list.begin(), processed_input_list.end());
572       }
573     } else {
574       // If this input type is valid, just push back the origin input.
575       new_input_list.push_back(input_node);
576     }
577   }
578 
579   if (matched) {
580     // Create replacing node, update front-end node map, set kernel build info, inherit attributes, etc. These
581     // operations could rely on the origin CNode.
582     auto new_node = CreateNewNode(func_graph, new_input_list, cnode);
583     MS_LOG(INFO) << "Create new node " << new_node->fullname_with_scope() << " " << new_node->DebugString()
584                  << " to replace " << cnode->fullname_with_scope() << " " << cnode->DebugString();
585     return new_node;
586   }
587   return nullptr;
588 }
589 
IsInputUpdated(const AnfNodePtr & origin_input,const AnfNodePtrList & new_input_list) const590 bool InsertTypeTransformOp::IsInputUpdated(const AnfNodePtr &origin_input, const AnfNodePtrList &new_input_list) const {
591   MS_EXCEPTION_IF_NULL(origin_input);
592   if (new_input_list.empty()) {
593     MS_LOG(INFO) << "The new input list size should be at least 1, but got 0.";
594     return false;
595   }
596 
597   if (new_input_list.size() == kSizeOne && new_input_list[kIndex0] == origin_input) {
598     MS_LOG(DEBUG) << "Input node " << origin_input->fullname_with_scope() << " " << origin_input->DebugString()
599                   << " should not be updated.";
600     return false;
601   }
602   MS_LOG(DEBUG) << "Input node " << origin_input->fullname_with_scope() << " " << origin_input->DebugString()
603                 << " will be replaced.";
604   return true;
605 }
606 
ProcessTupleUnfoldToTupleUnfold(const FuncGraphPtr & func_graph,const AnfNodePtr & input,const CNodePtr & node,bool *)607 AnfNodePtrList InsertTypeTransformOp::ProcessTupleUnfoldToTupleUnfold(const FuncGraphPtr &func_graph,
608                                                                       const AnfNodePtr &input, const CNodePtr &node,
609                                                                       bool *) {
610   MS_EXCEPTION_IF_NULL(func_graph);
611   MS_EXCEPTION_IF_NULL(input);
612   MS_EXCEPTION_IF_NULL(node);
613 
614   // If the input needs to be skipped as ConvertTupleInputToDynamicInput does, return the input node itself for
615   // caller to construct input list.
616   bool is_bprop_cut = common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimBpropCut);
617   bool skip = (is_bprop_cut && input->abstract()->isa<abstract::AbstractSparseTensor>()) ||
618               IsPrimitiveCNode(node, prim::kPrimTupleGetItem);
619   if (skip) {
620     return {input};
621   }
622 
623   AnfNodePtrList plant_inputs;
624   int64_t unfold_num = SplitTupleInputsForInsertType(func_graph, input, &plant_inputs);
625   MS_LOG(DEBUG) << "Transform tuple unfold input: " << input->fullname_with_scope() << " to " << unfold_num
626                 << " inputs.";
627   return plant_inputs;
628 }
629 
ProcessTupleUnfoldToTuple(const FuncGraphPtr & func_graph,const AnfNodePtr & input,const CNodePtr & node,bool *)630 AnfNodePtrList InsertTypeTransformOp::ProcessTupleUnfoldToTuple(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
631                                                                 const CNodePtr &node, bool *) {
632   MS_EXCEPTION_IF_NULL(func_graph);
633   MS_EXCEPTION_IF_NULL(input);
634   MS_EXCEPTION_IF_NULL(node);
635 
636   AnfNodePtrList result;
637   AnfNodePtr real_make_tuple_node = nullptr;
638   // If TupleUnfold input is a MakeTuple node, replace it with RealMakeTuple node.
639   if (IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
640     real_make_tuple_node = CreateRealMakeTupleByMakeTuple(func_graph, input->cast<CNodePtr>());
641   } else {
642     real_make_tuple_node = CreateRealMakeTupleByTupleUnfoldInput(func_graph, input);
643   }
644   result.push_back(real_make_tuple_node);
645   return result;
646 }
647 
ProcessTupleUnfoldToTensor(const FuncGraphPtr & func_graph,const AnfNodePtr & input,const CNodePtr & node,bool *)648 AnfNodePtrList InsertTypeTransformOp::ProcessTupleUnfoldToTensor(const FuncGraphPtr &func_graph,
649                                                                  const AnfNodePtr &input, const CNodePtr &node,
650                                                                  bool *) {
651   MS_EXCEPTION_IF_NULL(func_graph);
652   MS_EXCEPTION_IF_NULL(input);
653   MS_EXCEPTION_IF_NULL(node);
654 
655   // Data type of the tensor should be set as an attr of TupleToTensor op.
656   size_t input_index = GetInputNodeIndex(input, node);
657   auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
658   // There might be nested tuples, we need to find one step further to get element's data type.
659   if (data_type == kObjectTypeTuple) {
660     auto seq_abs = input->abstract();
661     MS_EXCEPTION_IF_NULL(seq_abs);
662     if (!seq_abs->isa<abstract::AbstractSequence>()) {
663       MS_LOG(EXCEPTION) << "Input " << input->DebugString() << " is not tuple output";
664     }
665     data_type = seq_abs->cast<abstract::AbstractSequencePtr>()->ElementsType()[kIndex0]->type_id();
666     MS_LOG(DEBUG) << "Input " << input->DebugString() << " real data type is " << data_type;
667   }
668   auto type_id_value_node = AnfAlgo::CreateTypeIdValueNodeToKernelGraph(func_graph, data_type);
669   // Use TupleToTensor op as the input of this node. Then TupleUnfoldToTuple pattern will be matched.
670   auto prim = NewValueNode(std::make_shared<Primitive>(prim::kPrimTupleToTensor->name()));
671   MS_EXCEPTION_IF_NULL(prim);
672   AnfNodePtrList inputs = {prim, input, type_id_value_node};
673   CNodePtr tuple_to_tensor = func_graph->NewCNode(inputs);
674   MS_EXCEPTION_IF_NULL(tuple_to_tensor);
675   tuple_to_tensor->set_scope(input->scope());
676   // Set abstract for TupleToTensor op according to user node's input shape and type.
677   auto abs = GenerateAbsByOpInfer(GetCNodePrimitive(tuple_to_tensor), {input, type_id_value_node});
678   MS_EXCEPTION_IF_NULL(abs);
679   MS_LOG(DEBUG) << "Abstract for TupleToTensor op is " << abs->ToString();
680   tuple_to_tensor->set_abstract(abs);
681   SetKernelInfoForNewCNode(tuple_to_tensor);
682   // Set object type to TUPLE for TupleUnfoldToTuple pattern to be matched.
683   KernelBuildInfoPtr tuple_to_tensor_build_info = AnfAlgo::GetSelectKernelBuildInfo(tuple_to_tensor);
684   MS_EXCEPTION_IF_NULL(tuple_to_tensor_build_info);
685   tuple_to_tensor_build_info->SetInputsKernelObjectType({KernelObjectType::TUPLE, KernelObjectType::SCALAR});
686   return {tuple_to_tensor};
687 }
688 
ProcessTupleToTupleUnfold(const FuncGraphPtr & func_graph,const AnfNodePtr & input,const CNodePtr & node,bool * new_prim)689 AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTupleUnfold(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
690                                                                 const CNodePtr &node, bool *new_prim) {
691   MS_EXCEPTION_IF_NULL(func_graph);
692   MS_EXCEPTION_IF_NULL(input);
693   MS_EXCEPTION_IF_NULL(node);
694 
695   // This pattern only supports user node is a TupleGetItem node.
696   // If this pattern is matched but the user node is not TupleGetItem, throw exception.
697   if (!IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
698     // If this node supports any input types, do not process it.
699     KernelBuildInfoPtr build_info = AnfAlgo::GetSelectKernelBuildInfo(node);
700     MS_EXCEPTION_IF_NULL(build_info);
701     if (build_info->op_type() == kernel::OpType::SKIP) {
702       return ProcessTupleToTupleUnfoldForSkipOp(func_graph, input, node, new_prim);
703     }
704     MS_LOG(EXCEPTION) << "Tuple to TupleUnfold pattern should have TupleGetItem as user node, but got "
705                       << node->fullname_with_scope() << ", " << node->DebugString();
706   }
707   return ProcessTupleToTupleUnfoldForTupleGetItem(func_graph, input, node, new_prim);
708 }
709 
ProcessTupleToTupleUnfoldForSkipOp(const FuncGraphPtr & func_graph,const AnfNodePtr & input,const CNodePtr & node,bool * new_prim)710 AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTupleUnfoldForSkipOp(const FuncGraphPtr &func_graph,
711                                                                          const AnfNodePtr &input, const CNodePtr &node,
712                                                                          bool *new_prim) {
713   MS_EXCEPTION_IF_NULL(func_graph);
714   MS_EXCEPTION_IF_NULL(input);
715   MS_EXCEPTION_IF_NULL(node);
716   MS_EXCEPTION_IF_NULL(new_prim);
717   if (input->abstract() != nullptr && input->abstract()->isa<abstract::AbstractSequence>()) {
718     const auto &seq_abs = input->abstract()->cast<abstract::AbstractSequencePtr>();
719     MS_EXCEPTION_IF_NULL(seq_abs);
720     if (!seq_abs->dynamic_len()) {
721       AnfNodePtrList new_inputs;
722       for (const auto &node_input : node->inputs()) {
723         if (node_input != input) {
724           continue;
725         }
726         for (size_t i = 0; i < seq_abs->size(); ++i) {
727           CNodePtr get_item = CreatTupleGetItemNode(func_graph, input, i);
728           MS_EXCEPTION_IF_NULL(get_item);
729           auto kernel_info = std::make_shared<device::KernelInfo>();
730           MS_EXCEPTION_IF_NULL(kernel_info);
731           get_item->set_kernel_info(kernel_info);
732           auto builder = std::make_shared<KernelBuildInfoBuilder>();
733           MS_EXCEPTION_IF_NULL(builder);
734           AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), get_item.get());
735           KernelBuildInfoPtr build_info = AnfAlgo::GetSelectKernelBuildInfo(get_item);
736           MS_EXCEPTION_IF_NULL(build_info);
737           build_info->SetInputsFormat({AnfAlgo::GetOutputFormat(input, 0), kOpFormat_DEFAULT});
738           build_info->SetInputsDeviceType({AnfAlgo::GetOutputDeviceDataType(input, 0), TypeId::kNumberTypeInt64});
739           build_info->SetOutputsFormat({AnfAlgo::GetOutputFormat(input, 0)});
740           build_info->SetOutputsDeviceType({AnfAlgo::GetOutputDeviceDataType(input, 0)});
741           build_info->SetInputsKernelObjectType({KernelObjectType::TUPLE_UNFOLD, KernelObjectType::SCALAR});
742           build_info->SetOutputsKernelObjectType({KernelObjectType::TENSOR});
743           bool new_get_item_prim = false;
744           auto new_get_item_inputs =
745             ProcessTupleToTupleUnfoldForTupleGetItem(func_graph, input, get_item, &new_get_item_prim);
746           new_get_item_inputs.emplace_back(get_item->input(2));
747           auto new_get_item = CreateNewNode(func_graph, new_get_item_inputs, get_item);
748           MS_LOG(DEBUG) << "Create new node " << new_get_item->fullname_with_scope() << " "
749                         << new_get_item->DebugString(2) << " to replace " << get_item->fullname_with_scope() << " "
750                         << get_item->DebugString(2)
751                         << " build info:" << AnfAlgo::GetSelectKernelBuildInfo(new_get_item)->ToString();
752           new_inputs.emplace_back(new_get_item);
753         }
754       }
755       return new_inputs;
756     }
757   } else {
758     MS_LOG(WARNING) << "Invalid input:" << input->DebugString() << " for node:" << node->DebugString();
759   }
760   MS_LOG(INFO) << "Node " << node->fullname_with_scope() << " skip TupleToTupleUnfold type matching.";
761   *new_prim = false;
762   return {input};
763 }
764 
ProcessTupleToTupleUnfoldForTupleGetItem(const FuncGraphPtr & func_graph,const AnfNodePtr & input,const CNodePtr & node,bool * new_prim)765 AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTupleUnfoldForTupleGetItem(const FuncGraphPtr &func_graph,
766                                                                                const AnfNodePtr &input,
767                                                                                const CNodePtr &node, bool *new_prim) {
768   MS_EXCEPTION_IF_NULL(func_graph);
769   MS_EXCEPTION_IF_NULL(input);
770   MS_EXCEPTION_IF_NULL(node);
771   MS_EXCEPTION_IF_NULL(new_prim);
772   auto prim = NewValueNode(prim::kPrimRealTupleGetItem);
773   MS_EXCEPTION_IF_NULL(prim);
774   // Use original inputs except the primitive.
775   AnfNodePtrList new_inputs = {prim, input};
776 
777   // For TupleGetItem node, the second input value node's kernel info must be in case of nullptr.
778   if (common::AnfAlgo::GetInputTensorNum(node) != kSizeTwo) {
779     MS_LOG(EXCEPTION) << "Input number of TupleGetItem node " << node->DebugString() << " should be 2. But got "
780                       << common::AnfAlgo::GetInputTensorNum(node);
781   }
782   auto index_input = node->input(kIndex2);
783   MS_EXCEPTION_IF_NULL(index_input);
784   if (index_input->isa<ValueNode>()) {
785     SetKernelInfoForValueNode(index_input->cast<ValueNodePtr>());
786     // Because the index is used as real kernel RealTupleGetItem's second input, we must add TupleGetItem's index to
787     // kernel graph so that its device address will be allocated.
788     auto kg = func_graph->cast<KernelGraphPtr>();
789     MS_EXCEPTION_IF_NULL(kg);
790     MS_LOG(INFO) << "Add value:" << index_input->DebugString() << ", full name:" << index_input->fullname_with_scope()
791                  << " to kernel graph.";
792     kg->AddValueNodeToGraph(index_input->cast<ValueNodePtr>());
793   }
794 
795   auto abs = GenerateAbsByOpInfer(prim::kPrimRealTupleGetItem, {input, index_input});
796   MS_EXCEPTION_IF_NULL(abs);
797   MS_LOG(DEBUG) << "Abstract for RealTupleGetItem op is " << abs->ToString();
798   node->set_abstract(abs);
799 
800   // The primitive of user is changed.
801   *new_prim = true;
802   return new_inputs;
803 }
804 
ProcessTupleToTensor(const FuncGraphPtr & func_graph,const AnfNodePtr & input,const CNodePtr & node,bool *)805 AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTensor(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
806                                                            const CNodePtr &node, bool *) {
807   MS_EXCEPTION_IF_NULL(func_graph);
808   MS_EXCEPTION_IF_NULL(input);
809   MS_EXCEPTION_IF_NULL(node);
810   auto new_input = ConstructInputByValueNode(func_graph, input);
811   if (new_input != nullptr) {
812     MS_LOG(DEBUG) << "Create new value node:" << new_input->DebugString() << " by " << input->DebugString()
813                   << " for cnode:" << node->DebugString() << " in graph:" << func_graph->ToString();
814     return {new_input};
815   }
816 
817   if (IsNewKernel(node) && IsNewKernel(input)) {
818     MS_LOG(EXCEPTION) << "Insert TupleToTensor op for input:" << input->fullname_with_scope()
819                       << " of node:" << node->fullname_with_scope() << " in graph:" << func_graph->ToString();
820   }
821 
822   // Data type of the tensor should be set as an attr of TupleToTensor op.
823   size_t input_index = GetInputNodeIndex(input, node);
824   auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
825   if (data_type == TypeId::kTypeUnknown && input->abstract() != nullptr &&
826       input->abstract()->isa<abstract::AbstractSequence>() &&
827       input->abstract()->cast<abstract::AbstractSequencePtr>()->elements().size() == 0) {
828     data_type = TypeId::kNumberTypeInt64;
829   }
830   // There might be nested tuples, we need to find one step further to get element's data type.
831   if (data_type == kObjectTypeTuple) {
832     auto seq_abs = input->abstract();
833     MS_EXCEPTION_IF_NULL(seq_abs);
834     if (!seq_abs->isa<abstract::AbstractSequence>()) {
835       MS_LOG(EXCEPTION) << "Input " << input->DebugString() << " is not tuple output";
836     }
837     data_type = seq_abs->cast<abstract::AbstractSequencePtr>()->ElementsType()[kIndex0]->type_id();
838     MS_LOG(DEBUG) << "Input " << input->DebugString() << " real data type is " << data_type;
839   }
840   auto type_id_value_node = AnfAlgo::CreateTypeIdValueNodeToKernelGraph(func_graph, data_type);
841   // Simply insert TupleToTensor op between 'input' and 'node'.
842   auto prim = NewValueNode(std::make_shared<Primitive>(prim::kPrimTupleToTensor->name()));
843   MS_EXCEPTION_IF_NULL(prim);
844   AnfNodePtrList inputs = {prim, input, type_id_value_node};
845   CNodePtr tuple_to_tensor = func_graph->NewCNode(inputs);
846   MS_EXCEPTION_IF_NULL(tuple_to_tensor);
847 
848   // Set abstract for TupleToTensor op according to user node's input shape and type.
849   auto abs = GenerateAbsByOpInfer(GetCNodePrimitive(tuple_to_tensor), {input, type_id_value_node});
850   MS_EXCEPTION_IF_NULL(abs);
851   tuple_to_tensor->set_scope(input->scope());
852   MS_LOG(DEBUG) << "Abstract for TupleToTensor op is " << abs->ToString();
853   tuple_to_tensor->set_abstract(abs);
854   SetKernelInfoForNewCNode(tuple_to_tensor);
855   // Set object type to TUPLE for TupleUnfoldToTuple pattern to be matched.
856   KernelBuildInfoPtr tuple_to_tensor_build_info = AnfAlgo::GetSelectKernelBuildInfo(tuple_to_tensor);
857   MS_EXCEPTION_IF_NULL(tuple_to_tensor_build_info);
858   tuple_to_tensor_build_info->SetInputsKernelObjectType({KernelObjectType::TUPLE, KernelObjectType::SCALAR});
859   return {tuple_to_tensor};
860 }
861 
ProcessScalarToTensor(const FuncGraphPtr & func_graph,const AnfNodePtr & input,const CNodePtr & node,bool *)862 AnfNodePtrList InsertTypeTransformOp::ProcessScalarToTensor(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
863                                                             const CNodePtr &node, bool *) {
864   MS_EXCEPTION_IF_NULL(func_graph);
865   MS_EXCEPTION_IF_NULL(input);
866   MS_EXCEPTION_IF_NULL(node);
867   if (IsNewKernel(node) && IsNewKernel(input)) {
868     MS_LOG(EXCEPTION) << "Insert ScalarToTensor op for input:" << input->fullname_with_scope()
869                       << " of node:" << node->fullname_with_scope() << " in graph:" << func_graph->ToString();
870   }
871 
872   auto new_input = ConstructInputByValueNode(func_graph, input);
873   if (new_input != nullptr) {
874     MS_LOG(DEBUG) << "Create new value node:" << new_input->DebugString() << " by " << input->DebugString()
875                   << " for cnode:" << node->DebugString() << " in graph:" << func_graph->ToString();
876     return {new_input};
877   }
878 
879   // Data type of the tensor should be set as an attr of ScalarToTensor op.
880   size_t input_index = GetInputNodeIndex(input, node);
881   auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
882   auto type_id_value_node = AnfAlgo::CreateTypeIdValueNodeToKernelGraph(func_graph, data_type);
883   // Simply insert ScalarToTensor op between 'input' and 'node'.
884   auto prim = NewValueNode(std::make_shared<Primitive>(prim::kPrimScalarToTensor->name()));
885   MS_EXCEPTION_IF_NULL(prim);
886   AnfNodePtrList inputs = {prim, input, type_id_value_node};
887   CNodePtr scalar_to_tensor = func_graph->NewCNode(inputs);
888   MS_EXCEPTION_IF_NULL(scalar_to_tensor);
889   scalar_to_tensor->set_scope(input->scope());
890   auto abs = GenerateAbsByOpInfer(GetCNodePrimitive(scalar_to_tensor), {input, type_id_value_node});
891   MS_EXCEPTION_IF_NULL(abs);
892   MS_LOG(DEBUG) << "Abstract for ScalarToTensor op is " << abs->ToString();
893   scalar_to_tensor->set_abstract(abs);
894   SetKernelInfoForNewCNode(scalar_to_tensor);
895   // Set object type info
896   KernelBuildInfoPtr scalar_to_tensor_build_info = AnfAlgo::GetSelectKernelBuildInfo(scalar_to_tensor);
897   MS_EXCEPTION_IF_NULL(scalar_to_tensor_build_info);
898   scalar_to_tensor_build_info->SetInputsKernelObjectType({KernelObjectType::SCALAR, KernelObjectType::SCALAR});
899   return {scalar_to_tensor};
900 }
901 
ProcessTensorToTuple(const FuncGraphPtr & func_graph,const AnfNodePtr & input,const CNodePtr & node,bool *)902 AnfNodePtrList InsertTypeTransformOp::ProcessTensorToTuple(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
903                                                            const CNodePtr &node, bool *) {
904   MS_EXCEPTION_IF_NULL(func_graph);
905   MS_EXCEPTION_IF_NULL(input);
906   MS_EXCEPTION_IF_NULL(node);
907   if (IsNewKernel(node) && IsNewKernel(input)) {
908     MS_LOG(EXCEPTION) << "Insert TensorToTuple op for input:" << input->fullname_with_scope()
909                       << " of node:" << node->fullname_with_scope() << " in graph:" << func_graph->ToString();
910   }
911   // Create TensorToTuple op.
912   auto prim = NewValueNode(std::make_shared<Primitive>(prim::kPrimTensorToTuple->name()));
913   MS_EXCEPTION_IF_NULL(prim);
914   AnfNodePtrList inputs = {prim, input};
915   CNodePtr tensor_to_tuple = func_graph->NewCNode(inputs);
916   MS_EXCEPTION_IF_NULL(tensor_to_tuple);
917   tensor_to_tuple->set_scope(input->scope());
918   auto abs = GenerateAbsByOpInfer(GetCNodePrimitive(tensor_to_tuple), {input});
919   MS_EXCEPTION_IF_NULL(abs);
920   MS_LOG(DEBUG) << "Abstract for TensorToTuple op is " << abs->ToString();
921   tensor_to_tuple->set_abstract(abs);
922 
923   SetKernelInfoForNewCNode(tensor_to_tuple);
924   return {tensor_to_tuple};
925 }
926 
ProcessTensorToScalar(const FuncGraphPtr & func_graph,const AnfNodePtr & input,const CNodePtr & node,bool *)927 AnfNodePtrList InsertTypeTransformOp::ProcessTensorToScalar(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
928                                                             const CNodePtr &node, bool *) {
929   MS_EXCEPTION_IF_NULL(func_graph);
930   MS_EXCEPTION_IF_NULL(input);
931   MS_EXCEPTION_IF_NULL(node);
932   if (IsNewKernel(node) && IsNewKernel(input)) {
933     MS_LOG(EXCEPTION) << "Insert TensorToScalar op for input:" << input->fullname_with_scope()
934                       << " of node:" << node->fullname_with_scope() << " in graph:" << func_graph->ToString();
935   }
936   // Create TensorToScalar op.
937   auto prim = NewValueNode(std::make_shared<Primitive>(prim::kPrimTensorToScalar->name()));
938   MS_EXCEPTION_IF_NULL(prim);
939   AnfNodePtrList inputs = {prim, input};
940   CNodePtr tensor_to_scalar = func_graph->NewCNode(inputs);
941   MS_EXCEPTION_IF_NULL(tensor_to_scalar);
942   tensor_to_scalar->set_scope(input->scope());
943   auto abs = GenerateAbsByOpInfer(GetCNodePrimitive(tensor_to_scalar), {input});
944   MS_EXCEPTION_IF_NULL(abs);
945   MS_LOG(DEBUG) << "Abstract for TensorToScalar op is " << abs->ToString();
946   tensor_to_scalar->set_abstract(abs);
947 
948   SetKernelInfoForNewCNode(tensor_to_scalar);
949   return {tensor_to_scalar};
950 }
951 }  // namespace opt
952 }  // namespace mindspore
953