1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/common/pass/insert_type_transform_op.h"
18
19 #include <algorithm>
20 #include <memory>
21 #include <vector>
22 #include "abstract/ops/primitive_infer_map.h"
23 #include "include/backend/anf_runtime_algorithm.h"
24 #include "include/common/utils/anfalgo.h"
25 #include "include/common/utils/convert_utils.h"
26 #include "include/common/utils/utils.h"
27 #include "ir/anf.h"
28 #include "kernel/common_utils.h"
29 #include "kernel/framework_utils.h"
30 #include "ops/arithmetic_ops.h"
31 #include "ops/nn_ops.h"
32 #include "ops/sequence_ops.h"
33 #include "ops/framework_ops.h"
34 #include "ops/op_def.h"
35 #include "ops/op_utils.h"
36
37 namespace mindspore {
38 namespace opt {
39 namespace {
IsNewKernel(const AnfNodePtr & node)40 bool IsNewKernel(const AnfNodePtr &node) {
41 MS_EXCEPTION_IF_NULL(node);
42 if (!node->isa<CNode>() || common::AnfAlgo::IsCallNode(node)) {
43 return false;
44 }
45 const auto &primitive = GetCNodePrimitive(node);
46 MS_EXCEPTION_IF_NULL(primitive);
47 return mindspore::ops::GetOpDef(primitive->name()) != nullptr;
48 }
49 } // namespace
// Recursively flattens a tuple-typed input into its element nodes and appends
// them to *plant_inputs. For a MakeTuple producer its inputs are appended
// directly; for any other tuple output a TupleGetItem node is created per
// element. Returns the number of flattened elements, or -1 when tuple_input
// is not a tuple output (a WARNING is logged in that case).
int64_t SplitTupleInputsForInsertType(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
                                      std::vector<AnfNodePtr> *plant_inputs) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(tuple_input);
  MS_EXCEPTION_IF_NULL(plant_inputs);

  if (!common::AnfAlgo::IsTupleOutput(tuple_input)) {
    auto abs = tuple_input->abstract();
    MS_EXCEPTION_IF_NULL(abs);
    MS_LOG(WARNING) << "The Function only split the output type is tuple type but got" << abs->ToString();
    return -1;
  }

  // Start from the top-level element count; nested tuples add to it below.
  auto input_size = AnfAlgo::GetOutputElementNum(tuple_input);
  if (tuple_input->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(tuple_input, prim::kPrimMakeTuple)) {
    auto make_tuple = tuple_input->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(make_tuple);
    size_t tuple_input_num = common::AnfAlgo::GetInputTensorNum(make_tuple);
    for (size_t j = 0; j < tuple_input_num; ++j) {
      // using for graph kernel
      auto dyn_input_node = common::AnfAlgo::GetInputNode(make_tuple, j);
      MS_EXCEPTION_IF_NULL(dyn_input_node);
      // Handle tuple nested scenes.
      if (dyn_input_node->isa<CNode>() && (common::AnfAlgo::CheckPrimitiveType(dyn_input_node, prim::kPrimMakeTuple) ||
                                           common::AnfAlgo::IsTupleOutput(dyn_input_node))) {
        // NOTE(review): if the recursive call returned -1 here,
        // LongToSize(-1) would wrap to a huge value — presumably unreachable
        // because the recursion only happens for tuple outputs; confirm.
        int64_t dyn_input_size = SplitTupleInputsForInsertType(graph, dyn_input_node, plant_inputs);
        input_size += LongToSize(dyn_input_size);
        continue;
      }
      (void)plant_inputs->emplace_back(dyn_input_node);
    }
    return SizeToLong(input_size);
  }
  // Not a MakeTuple: materialize each element via a TupleGetItem node.
  for (size_t index = 0; index < input_size; ++index) {
    auto dynamic_input_node = CreatTupleGetItemNode(graph, tuple_input, index);
    MS_LOG(DEBUG) << "Create TupleGetItem node " << dynamic_input_node->fullname_with_scope() << " for tuple node "
                  << tuple_input->fullname_with_scope();
    // The virtual node's object types should be set.
    SetKernelInfoForNewCNode(dynamic_input_node, false);
    (void)plant_inputs->emplace_back(dynamic_input_node);
  }
  return SizeToLong(input_size);
}
93
// Creates a replacement CNode from input_list for origin_node.
// The new node inherits the origin node's abstract, scope, attrs, primal
// attrs and kernel build info; the kernel graph's front-backend map is
// updated so front-end nodes keep pointing at the replacement. The input part
// of the build info is regenerated because processing may have changed the
// number or types of inputs.
AnfNodePtr CreateNewNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &input_list,
                         const CNodePtr &origin_node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(origin_node);

  auto new_cnode = NewCNode(input_list, func_graph, {origin_node});
  MS_EXCEPTION_IF_NULL(new_cnode);
  // This pass should not have new node whose abstract differs from the original node. So set the original node's
  // abstract.
  new_cnode->set_abstract(origin_node->abstract());
  new_cnode->set_scope(origin_node->scope());
  new_cnode->set_primal_attrs(origin_node->primal_attrs());
  new_cnode->set_attrs(origin_node->attrs());
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  if (kernel_graph != nullptr) {
    // Flag the corresponding front-end node so later stages know its backend
    // real kernel was replaced, then remap front node -> new_cnode.
    const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(origin_node);
    if (front_node != nullptr && front_node->isa<CNode>()) {
      const auto front_cnode = front_node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(front_cnode);
      MS_LOG(INFO) << "Add replace real kernel flag for front node:" << front_node->DebugString();
      front_cnode->AddAttr(kAttrReplaceRealKernelInBackend, MakeValue(true));
    }
    kernel_graph->FrontBackendlMapUpdate(origin_node, new_cnode);
  }

  // Inherit from origin kernel build info.
  KernelBuildInfoPtr origin_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(origin_node);
  MS_EXCEPTION_IF_NULL(origin_kernel_build_info);
  auto new_kernel_builder = std::make_shared<KernelBuildInfoBuilder>(origin_kernel_build_info);
  MS_EXCEPTION_IF_NULL(new_kernel_builder);

  auto kernel_info = std::make_shared<device::KernelInfo>();
  MS_EXCEPTION_IF_NULL(kernel_info);
  new_cnode->set_kernel_info(kernel_info);
  AnfAlgo::SetSelectKernelBuildInfo(new_kernel_builder->Build(), new_cnode.get());

  // Need to reset new cnode's kernel build info because the inputs type and number could be changed after processing
  // methods. Only reset input types.
  auto new_prim = GetValueNode<PrimitivePtr>(new_cnode->input(kIndex0));
  auto origin_prim = GetValueNode<PrimitivePtr>(origin_node->input(kIndex0));
  // Regenerate formats/device types from scratch (set_format_type == true)
  // only when the primitive changed, the kernel takes dynamic params, or the
  // origin kernel is of SKIP type; otherwise keep the inherited format info.
  if (IsPrimitiveEquals(new_prim, origin_prim) && !kernel::IsDynamicParamKernel(origin_prim->name()) &&
      (origin_kernel_build_info->op_type() != kernel::OpType::SKIP)) {
    SetKernelInfoForNewCNode(new_cnode, false);
  } else {
    SetKernelInfoForNewCNode(new_cnode, true);
  }

  // If the primitive is not changed, this means only inputs are updated. So inherit output from origin node.
  if (IsPrimitiveEquals(new_prim, origin_prim)) {
    KernelBuildInfoPtr new_node_build_info = AnfAlgo::GetSelectKernelBuildInfo(new_cnode);
    KernelBuildInfoPtr origin_node_build_info = AnfAlgo::GetSelectKernelBuildInfo(origin_node);
    new_node_build_info->SetOutputsFormat(origin_node_build_info->GetAllOutputFormats());
    new_node_build_info->SetOutputsDeviceType(origin_node_build_info->GetAllOutputDeviceTypes());
    new_node_build_info->SetOutputsKernelObjectType(origin_node_build_info->GetAllOutputKernelObjectTypes());
  }

  return new_cnode;
}
152
CreateRealMakeTupleByMakeTuple(const FuncGraphPtr & func_graph,const CNodePtr & make_tuple_node)153 AnfNodePtr CreateRealMakeTupleByMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &make_tuple_node) {
154 MS_EXCEPTION_IF_NULL(func_graph);
155 MS_EXCEPTION_IF_NULL(make_tuple_node);
156
157 // Create RealMakeTuple node and inherit inputs and abstract from MakeTuple node.
158 AnfNodePtrList inputs = make_tuple_node->inputs();
159 auto prim = NewValueNode(prim::kPrimRealMakeTuple);
160 MS_EXCEPTION_IF_NULL(prim);
161 if (inputs.empty()) {
162 MS_LOG(EXCEPTION) << "inputs is empty.";
163 }
164 inputs[kIndex0] = prim;
165 CNodePtr real_make_tuple = func_graph->NewCNode(inputs);
166 MS_EXCEPTION_IF_NULL(real_make_tuple);
167 real_make_tuple->set_scope(make_tuple_node->scope());
168 real_make_tuple->set_abstract(make_tuple_node->abstract());
169
170 SetKernelInfoForNewCNode(real_make_tuple);
171
172 // RealMakeTuple's inputs must be scalar or tensor. To avoid failing to select kernel, we must override
173 // RealMakeTuple's KernelObjectTypes to TENSOR, which is created from MakeTuple.
174 KernelBuildInfoPtr real_make_tuple_build_info = AnfAlgo::GetSelectKernelBuildInfo(real_make_tuple);
175 MS_EXCEPTION_IF_NULL(real_make_tuple_build_info);
176 auto inputs_obj_types = real_make_tuple_build_info->GetAllInputKernelObjectTypes();
177 if (!std::all_of(inputs_obj_types.begin(), inputs_obj_types.end(),
178 [](const auto &obj_type) { return obj_type == KernelObjectType::TENSOR; }) &&
179 !std::all_of(inputs_obj_types.begin(), inputs_obj_types.end(),
180 [](const auto &obj_type) { return obj_type == KernelObjectType::SCALAR; })) {
181 auto new_obj_types = inputs_obj_types;
182 (void)std::transform(new_obj_types.begin(), new_obj_types.end(), new_obj_types.begin(),
183 [](const auto &) { return KernelObjectType::TENSOR; });
184 real_make_tuple_build_info->SetInputsKernelObjectType(new_obj_types);
185 MS_LOG(DEBUG) << "Override RealMakeTuple input kernel object types from " << inputs_obj_types << " "
186 << new_obj_types;
187 }
188 return real_make_tuple;
189 }
190
// Wraps a node whose output type is TupleUnfold with a new RealMakeTuple
// node. The RealMakeTuple inherits the wrapped node's scope and abstract; its
// input kernel object type is forced to TUPLE_UNFOLD (so the
// TupleUnfoldToTupleUnfold pattern matches later), and the first input's
// format/device type is replicated across all tuple elements.
AnfNodePtr CreateRealMakeTupleByTupleUnfoldInput(const FuncGraphPtr &func_graph,
                                                 const AnfNodePtr &node_with_tuple_unfold_output) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node_with_tuple_unfold_output);

  auto prim = NewValueNode(prim::kPrimRealMakeTuple);
  MS_EXCEPTION_IF_NULL(prim);
  AnfNodePtrList inputs = {prim, node_with_tuple_unfold_output};
  CNodePtr real_make_tuple = func_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(real_make_tuple);
  real_make_tuple->set_scope(node_with_tuple_unfold_output->scope());
  // Inherit abstract from TupleUnfold output node.
  real_make_tuple->set_abstract(node_with_tuple_unfold_output->abstract());

  SetKernelInfoForNewCNode(real_make_tuple);

  // Set object type to TupleUnfold so TupleUnfoldToTupleUnfold pattern will be matched.
  KernelBuildInfoPtr real_make_tuple_build_info = AnfAlgo::GetSelectKernelBuildInfo(real_make_tuple);
  MS_EXCEPTION_IF_NULL(real_make_tuple_build_info);
  real_make_tuple_build_info->SetInputsKernelObjectType({KernelObjectType::TUPLE_UNFOLD});

  // Extend tuple_unfold inputs: one format/device-type entry per element.
  abstract::AbstractTuplePtr tuple_unfold_abs =
    node_with_tuple_unfold_output->abstract()->cast<abstract::AbstractTuplePtr>();
  MS_EXCEPTION_IF_NULL(tuple_unfold_abs);
  // Reuse the build info fetched above (a second GetSelectKernelBuildInfo
  // lookup would be redundant). Parenthesized (count, value) construction is
  // used deliberately to avoid the initializer-list constructor pitfall.
  std::vector<std::string> inputs_format(tuple_unfold_abs->size(),
                                         real_make_tuple_build_info->GetInputFormat(kIndex0));
  std::vector<TypeId> inputs_type(tuple_unfold_abs->size(), real_make_tuple_build_info->GetInputDeviceType(kIndex0));
  real_make_tuple_build_info->SetInputsFormat(inputs_format);
  real_make_tuple_build_info->SetInputsDeviceType(inputs_type);

  return real_make_tuple;
}
225
IsBackOffOp(const CNodePtr & cnode)226 bool IsBackOffOp(const CNodePtr &cnode) {
227 std::vector<std::string> back_off_op_list = {prim::kPrimTupleToTensor->name(), prim::kPrimScalarToTensor->name(),
228 prim::kPrimTensorToTuple->name(), prim::kPrimTensorToScalar->name(),
229 prim::kPrimRealMakeTuple->name(), prim::kPrimRealTupleGetItem->name(),
230 prim::kPrimTupleSetItem->name()};
231 if (std::find(back_off_op_list.begin(), back_off_op_list.end(), common::AnfAlgo::GetCNodeName(cnode)) !=
232 back_off_op_list.end()) {
233 return true;
234 }
235 return false;
236 }
237
SetBackOffFlag(const KernelBuildInfoPtr & build_info,const CNodePtr & cnode)238 void SetBackOffFlag(const KernelBuildInfoPtr &build_info, const CNodePtr &cnode) {
239 if (IsBackOffOp(cnode)) {
240 build_info->set_valid(false);
241 }
242 }
243
// Sets (creating if absent) the selected kernel build info of a newly created
// CNode. Input/output kernel object types are always regenerated; when
// set_format_type is true, input/output formats and device types are
// regenerated too (inputs take their real producer's output format, outputs
// use the node's inferred data types). Finally the back-off flag is applied.
void SetKernelInfoForNewCNode(const CNodePtr &cnode, bool set_format_type) {
  MS_EXCEPTION_IF_NULL(cnode);
  // In some cases cnode is newly created and has no kernel info.
  if (cnode->kernel_info() == nullptr ||
      (!dynamic_cast<device::KernelInfo *>(cnode->kernel_info())->has_build_info())) {
    auto kernel_info = std::make_shared<device::KernelInfo>();
    MS_EXCEPTION_IF_NULL(kernel_info);
    cnode->set_kernel_info(kernel_info);
    auto builder = std::make_shared<KernelBuildInfoBuilder>();
    MS_EXCEPTION_IF_NULL(builder);
    AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), cnode.get());
  }
  KernelBuildInfoPtr build_info = AnfAlgo::GetSelectKernelBuildInfo(cnode);
  MS_EXCEPTION_IF_NULL(build_info);
  MS_LOG(DEBUG) << "Start setting kernel info for cnode " << cnode->DebugString() << " " << cnode->fullname_with_scope()
                << ",set_format_type: " << set_format_type;
  // Set input and output object type for subsequent type matching process.
  std::vector<KernelObjectType> input_obj_type;
  std::vector<KernelObjectType> output_obj_type;
  GenerateKernelObjectTypeForNewCNode(cnode, &input_obj_type, &output_obj_type);
  build_info->SetInputsKernelObjectType(input_obj_type);
  build_info->SetOutputsKernelObjectType(output_obj_type);

  if (set_format_type) {
    // Set input and output format.
    std::vector<std::string> inputs_format;
    std::vector<TypeId> inputs_type;
    size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
    for (size_t input_index = 0; input_index < input_num; ++input_index) {
      auto input_node = common::AnfAlgo::GetInputNode(cnode, input_index);
      MS_EXCEPTION_IF_NULL(input_node);

      // Trace through virtual nodes to the real producer of this input.
      auto real_input = common::AnfAlgo::VisitKernelWithReturnType(input_node, kIndex0);
      auto real_input_node = real_input.first;
      MS_EXCEPTION_IF_NULL(real_input_node);
      auto output_index = real_input.second;
      if (real_input_node->kernel_info() == nullptr) {
        // Producer has no kernel info yet: fall back to the default format.
        (void)inputs_format.emplace_back(kOpFormat_DEFAULT);
      } else {
        (void)inputs_format.emplace_back(AnfAlgo::GetOutputFormat(real_input_node, output_index));
      }
      inputs_type.push_back(common::AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
    }

    std::vector<std::string> outputs_format;
    std::vector<TypeId> outputs_type;
    size_t output_num;
    // TUPLE_UNFOLD outputs are expanded element-wise; other object types
    // count as a single output. GenerateKernelObjectTypeForNewCNode pushes
    // exactly one output entry in every branch, so [kIndex0] is safe.
    if (output_obj_type[kIndex0] == KernelObjectType::TUPLE_UNFOLD) {
      output_num = AnfAlgo::GetOutputElementNum(cnode);
    } else {
      output_num = kSizeOne;
    }
    for (size_t output_index = 0; output_index < output_num; ++output_index) {
      (void)outputs_format.emplace_back(GenerateOutputFormatForNewCNode(cnode));
      outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(cnode, output_index));
    }

    build_info->SetInputsFormat(inputs_format);
    build_info->SetInputsDeviceType(inputs_type);
    build_info->SetOutputsFormat(outputs_format);
    build_info->SetOutputsDeviceType(outputs_type);
  }

  // The node may not be supported in the current device.
  SetBackOffFlag(build_info, cnode);
  MS_LOG(INFO) << "Set kernel info for cnode " << cnode->DebugString() << " " << cnode->fullname_with_scope() << " "
               << build_info->ToString();
}
312
SetKernelInfoForValueNode(const ValueNodePtr & value_node)313 void SetKernelInfoForValueNode(const ValueNodePtr &value_node) {
314 MS_EXCEPTION_IF_NULL(value_node);
315 auto kernel_info = std::make_shared<device::KernelInfo>();
316 MS_EXCEPTION_IF_NULL(kernel_info);
317 value_node->set_kernel_info(kernel_info);
318 auto builder = std::make_shared<KernelBuildInfoBuilder>();
319 MS_EXCEPTION_IF_NULL(builder);
320
321 auto type_id = value_node->value()->type()->type_id();
322 std::vector<std::string> inputs_format = {kOpFormat_DEFAULT};
323 std::vector<TypeId> inputs_type = {type_id};
324 std::vector<std::string> outputs_format = {kOpFormat_DEFAULT};
325 std::vector<TypeId> outputs_type = {type_id};
326
327 auto abs_type = AnfAlgo::GetAbstractObjectType(value_node->abstract());
328 std::vector<KernelObjectType> input_obj_type = {kernel::TypeIdToKernelObjectType(abs_type)};
329 std::vector<KernelObjectType> output_obj_type = {kernel::TypeIdToKernelObjectType(abs_type)};
330
331 builder->SetInputsFormat(inputs_format);
332 builder->SetInputsDeviceType(inputs_type);
333 builder->SetOutputsFormat(outputs_format);
334 builder->SetOutputsDeviceType(outputs_type);
335 builder->SetInputsKernelObjectType(input_obj_type);
336 builder->SetOutputsKernelObjectType(output_obj_type);
337 AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), value_node.get());
338 }
339
GenerateAbsByOpInfer(const PrimitivePtr & primitive,const AnfNodePtrList & input_list)340 abstract::AbstractBasePtr GenerateAbsByOpInfer(const PrimitivePtr &primitive, const AnfNodePtrList &input_list) {
341 MS_EXCEPTION_IF_NULL(primitive);
342 std::vector<AbstractBasePtr> input_args;
343 (void)std::for_each(input_list.begin(), input_list.end(),
344 [&input_args](const auto &input) { (void)input_args.emplace_back(input->abstract()); });
345 auto abs_opt = abstract::TryInferAbstract(primitive, input_args);
346 if (!abs_opt.has_value()) {
347 MS_LOG(EXCEPTION) << primitive->name() << " infer is not registered.";
348 }
349 auto abs = abs_opt.value();
350 MS_EXCEPTION_IF_NULL(abs);
351 MS_LOG(DEBUG) << "Abstract for " << primitive->name() << " is " << abs->ToString();
352 return abs;
353 }
354
GenerateOutputFormatForNewCNode(const CNodePtr & cnode)355 std::string GenerateOutputFormatForNewCNode(const CNodePtr &cnode) {
356 MS_EXCEPTION_IF_NULL(cnode);
357 if (IsPrimitiveCNode(cnode, prim::kPrimRealMakeTuple) || IsPrimitiveCNode(cnode, prim::kPrimTupleToTensor)) {
358 // We take first input format as the output format because multiple types and formats of
359 // RealMakeTuple/TupleToTensor are not supported.
360 std::string represent_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, kIndex0);
361 return represent_format;
362 }
363 return kOpFormat_DEFAULT;
364 }
365
GenerateKernelObjectTypeForNewCNode(const CNodePtr & cnode,std::vector<KernelObjectType> * input_obj_type,std::vector<KernelObjectType> * output_obj_type)366 void GenerateKernelObjectTypeForNewCNode(const CNodePtr &cnode, std::vector<KernelObjectType> *input_obj_type,
367 std::vector<KernelObjectType> *output_obj_type) {
368 MS_EXCEPTION_IF_NULL(cnode);
369 MS_EXCEPTION_IF_NULL(input_obj_type);
370 MS_EXCEPTION_IF_NULL(output_obj_type);
371
372 // Simply trasverse all inputs and get their object types.
373 // But if the input's object type is not set, this will throw exception so must pay attention when using this
374 // function.
375 auto general_input_obj_type_func = [&]() {
376 for (size_t i = kIndex1; i < cnode->size(); i++) {
377 auto input_node = cnode->input(i);
378 MS_EXCEPTION_IF_NULL(input_node);
379 // Set input kernel object type as input node's output kernel object type.
380 if (input_node->kernel_info() == nullptr ||
381 (!dynamic_cast<device::KernelInfo *>(input_node->kernel_info())->has_build_info())) {
382 auto abs_type = AnfAlgo::GetAbstractObjectType(input_node->abstract());
383 input_obj_type->push_back(kernel::TypeIdToKernelObjectType(abs_type));
384 } else {
385 input_obj_type->push_back(AnfAlgo::GetOutputKernelObjectType(input_node, kIndex0));
386 }
387 }
388 };
389
390 if (IsPrimitiveCNode(cnode, prim::kPrimRealMakeTuple)) {
391 general_input_obj_type_func();
392 output_obj_type->push_back(KernelObjectType::TUPLE);
393 } else if (IsPrimitiveCNode(cnode, prim::kPrimTupleToTensor)) {
394 general_input_obj_type_func();
395 output_obj_type->push_back(KernelObjectType::TENSOR);
396 } else if (IsPrimitiveCNode(cnode, prim::kPrimTensorToTuple)) {
397 general_input_obj_type_func();
398 output_obj_type->push_back(KernelObjectType::TUPLE);
399 } else if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
400 // First input of TupleGetItem must be TUPLE_UNFOLD.
401 // Second is the index.
402 *input_obj_type = {KernelObjectType::TUPLE_UNFOLD, KernelObjectType::TENSOR};
403 // Get actual output type of TupleGetItem node.
404 auto abs_type = AnfAlgo::GetAbstractObjectType(cnode->abstract());
405 output_obj_type->push_back(kernel::TypeIdToKernelObjectType(abs_type));
406 } else if (IsPrimitiveCNode(cnode, prim::kPrimRealTupleGetItem)) {
407 general_input_obj_type_func();
408 // Get actual output type of RealTupleGetItem node.
409 auto abs_type = AnfAlgo::GetAbstractObjectType(cnode->abstract());
410 output_obj_type->push_back(kernel::TypeIdToKernelObjectType(abs_type));
411 } else if (IsPrimitiveCNode(cnode, prim::kPrimTensorToScalar)) {
412 general_input_obj_type_func();
413 output_obj_type->push_back(KernelObjectType::SCALAR);
414 } else {
415 // For other ops, set TENSOR as output object type by default.
416 general_input_obj_type_func();
417 output_obj_type->push_back(KernelObjectType::TENSOR);
418 }
419
420 MS_LOG(DEBUG) << "Generate input and output object types for new node " << cnode->fullname_with_scope() << " "
421 << cnode->DebugString() << ". Input object types: " << *input_obj_type
422 << ". Output object types: " << *output_obj_type;
423 }
424
// Converts a ValueNode holding a scalar, or a non-empty sequence of scalars
// that all share one type, into a tensor input of the kernel graph. Returns
// nullptr when no conversion applies.
AnfNodePtr ConstructInputByValueNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input) {
  auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph);
  if (kernel_graph == nullptr || (!input->isa<ValueNode>())) {
    return nullptr;
  }
  const auto &value_node = input->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  const auto &value = value_node->value();
  MS_EXCEPTION_IF_NULL(value);
  // A single scalar converts directly.
  if (value->isa<Scalar>()) {
    return CreateTensorInput(kernel_graph, input);
  }
  if (!value->isa<ValueSequence>()) {
    return nullptr;
  }
  const auto &value_sequence = value->cast<ValueSequencePtr>();
  MS_EXCEPTION_IF_NULL(value_sequence);
  if (value_sequence->size() == 0) {
    return nullptr;
  }
  // Determine the element type from the first element; every remaining
  // element must be a scalar of the same type.
  const auto &elements = value_sequence->value();
  const auto &first = elements[0];
  if (first == nullptr || (!first->isa<Scalar>())) {
    return nullptr;
  }
  const auto &first_scalar = first->cast<ScalarPtr>();
  MS_EXCEPTION_IF_NULL(first_scalar);
  const auto &first_type = first_scalar->type();
  MS_EXCEPTION_IF_NULL(first_type);
  const auto expected_type_id = first_type->type_id();
  const bool homogeneous =
    std::all_of(elements.begin() + 1, elements.end(), [expected_type_id](const ValuePtr &element) {
      if (element == nullptr || (!element->isa<Scalar>())) {
        return false;
      }
      const auto &element_type = element->cast<ScalarPtr>()->type();
      return element_type != nullptr && element_type->type_id() == expected_type_id;
    });
  return homogeneous ? CreateTensorInput(kernel_graph, input) : nullptr;
}
463
// A map of kernel object type pairs to processing functions.
// Keyed by {current input object type, needed input object type}; populated
// by the InsertTypeTransformOp constructor.
static std::map<ObjectTypePair, ProcessTypeTransformFunc> kTypePairToProcessFunc;

// The nodes of which object types should be handled.
// Passed as the "return types" list when visiting real input nodes, so the
// visit stops at MakeTuple/TupleGetItem instead of looking through them.
const std::vector<PrimitivePtr> need_handled_types = {prim::kPrimMakeTuple, prim::kPrimTupleGetItem};
469
InsertTypeTransformOp(bool multigraph)470 InsertTypeTransformOp::InsertTypeTransformOp(bool multigraph)
471 : PatternProcessPass("insert_type_transform_op", multigraph) {
472 kTypePairToProcessFunc[{KernelObjectType::TUPLE_UNFOLD, KernelObjectType::TUPLE_UNFOLD}] =
473 std::bind(&InsertTypeTransformOp::ProcessTupleUnfoldToTupleUnfold, this, std::placeholders::_1,
474 std::placeholders::_2, std::placeholders::_3, std::placeholders::_4);
475 kTypePairToProcessFunc[{KernelObjectType::TUPLE_UNFOLD, KernelObjectType::TUPLE}] =
476 std::bind(&InsertTypeTransformOp::ProcessTupleUnfoldToTuple, this, std::placeholders::_1, std::placeholders::_2,
477 std::placeholders::_3, std::placeholders::_4);
478 kTypePairToProcessFunc[{KernelObjectType::TUPLE_UNFOLD, KernelObjectType::TENSOR}] =
479 std::bind(&InsertTypeTransformOp::ProcessTupleUnfoldToTensor, this, std::placeholders::_1, std::placeholders::_2,
480 std::placeholders::_3, std::placeholders::_4);
481 kTypePairToProcessFunc[{KernelObjectType::TUPLE, KernelObjectType::TUPLE_UNFOLD}] =
482 std::bind(&InsertTypeTransformOp::ProcessTupleToTupleUnfold, this, std::placeholders::_1, std::placeholders::_2,
483 std::placeholders::_3, std::placeholders::_4);
484 kTypePairToProcessFunc[{KernelObjectType::TUPLE, KernelObjectType::TENSOR}] =
485 std::bind(&InsertTypeTransformOp::ProcessTupleToTensor, this, std::placeholders::_1, std::placeholders::_2,
486 std::placeholders::_3, std::placeholders::_4);
487 kTypePairToProcessFunc[{KernelObjectType::SCALAR, KernelObjectType::TENSOR}] =
488 std::bind(&InsertTypeTransformOp::ProcessScalarToTensor, this, std::placeholders::_1, std::placeholders::_2,
489 std::placeholders::_3, std::placeholders::_4);
490 kTypePairToProcessFunc[{KernelObjectType::TENSOR, KernelObjectType::TUPLE}] =
491 std::bind(&InsertTypeTransformOp::ProcessTensorToTuple, this, std::placeholders::_1, std::placeholders::_2,
492 std::placeholders::_3, std::placeholders::_4);
493 kTypePairToProcessFunc[{KernelObjectType::TENSOR, KernelObjectType::SCALAR}] =
494 std::bind(&InsertTypeTransformOp::ProcessTensorToScalar, this, std::placeholders::_1, std::placeholders::_2,
495 std::placeholders::_3, std::placeholders::_4);
496 }
497
// Pattern-pass entry point. For each input of the CNode, compares the input's
// current kernel object type against the type this node needs; when a
// registered type pair matches, the corresponding process function rewrites
// the input (and possibly the primitive). If any input was actually updated,
// a replacement node is created and returned; otherwise returns nullptr.
const AnfNodePtr InsertTypeTransformOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                               const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return nullptr;
  }
  if (IsPrimitiveCNode(node, prim::kPrimSwitch)) {
    return nullptr;
  }
  // Skip nodes without a selected kernel build info, and MakeTuple nodes.
  if ((node->kernel_info() == nullptr) ||
      (!dynamic_cast<device::KernelInfo *>(node->kernel_info())->has_build_info()) ||
      (common::AnfAlgo::GetCNodeName(node) == "MakeTuple")) {
    return nullptr;
  }

  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // new_input_list starts with the primitive node; processed inputs are
  // appended below.
  AnfNodePtrList new_input_list = {common::AnfAlgo::GetCNodePrimitiveNode(cnode)};
  // If kernel object types are matched, set this flag to true and new node will be created to replace original
  // node.
  bool matched = false;
  for (size_t i = 0; i < common::AnfAlgo::GetInputNum(cnode); ++i) {
    const auto &input_node = common::AnfAlgo::GetInputNode(cnode, i);
    // Skip for monad input.
    if (HasAbstractMonad(input_node) || (node->kernel_info() == nullptr) ||
        !dynamic_cast<device::KernelInfo *>(node->kernel_info())) {
      new_input_list.push_back(input_node);
      continue;
    }

    // Look through virtual nodes (but stop at MakeTuple/TupleGetItem) to find
    // the node that actually carries the kernel build info.
    const auto &real_input_node =
      common::AnfAlgo::VisitKernelWithReturnType(input_node, kIndex0, false, need_handled_types).first;
    MS_EXCEPTION_IF_NULL(real_input_node);
    if ((real_input_node->kernel_info() == nullptr) ||
        (!dynamic_cast<device::KernelInfo *>(real_input_node->kernel_info())->has_build_info())) {
      MS_LOG(DEBUG) << node->fullname_with_scope() << " input index:" << i
                    << ", input node:" << real_input_node->fullname_with_scope() << " doesn't have build info.";
      new_input_list.push_back(input_node);
      continue;
    }

    auto needed_input_type = AnfAlgo::GetInputKernelObjectType(node, i);
    auto current_input_type = AnfAlgo::GetOutputKernelObjectType(real_input_node, kIndex0);
    if ((kObjectTypeToString.count(needed_input_type) == 0) || (kObjectTypeToString.count(current_input_type) == 0)) {
      MS_LOG(EXCEPTION) << "The current input object type " << current_input_type << " or needed input object type "
                        << needed_input_type << " is not valid for node " << node->fullname_with_scope()
                        << " input index:" << i << ", input node:" << real_input_node->fullname_with_scope();
    }
    MS_LOG(DEBUG) << "The current input object type:" << kObjectTypeToString[current_input_type]
                  << ", needed input object type:" << kObjectTypeToString[needed_input_type]
                  << " for node:" << node->fullname_with_scope() << " input index:" << i
                  << ", input node:" << real_input_node->fullname_with_scope();

    ObjectTypePair type_pair = {current_input_type, needed_input_type};
    if (kTypePairToProcessFunc.count(type_pair) != 0) {
      MS_LOG(INFO) << "Kernel object type pair of input index " << i << " for node pair "
                   << input_node->fullname_with_scope() << " to " << cnode->fullname_with_scope() << " is "
                   << type_pair.to_string();
      // The process function may replace the node's primitive; new_prim is
      // set to true in that case.
      bool new_prim = false;
      AnfNodePtrList processed_input_list = kTypePairToProcessFunc[type_pair](func_graph, input_node, cnode, &new_prim);
      if (IsInputUpdated(input_node, processed_input_list)) {
        matched = true;
      }
      if (new_prim) {
        MS_LOG(DEBUG) << "New primtive is " << processed_input_list[kIndex0]->fullname_with_scope() << " to replace "
                      << new_input_list[kIndex0]->fullname_with_scope();
        // If new primitive is created, replace the old one, which is the first element of the input list.
        new_input_list[kIndex0] = processed_input_list[kIndex0];
        // Jump the primitive node the first one, and the rest is the new inputs.
        (void)new_input_list.insert(new_input_list.end(), std::begin(processed_input_list) + kIndex1,
                                    processed_input_list.end());
      } else {
        (void)new_input_list.insert(new_input_list.end(), processed_input_list.begin(), processed_input_list.end());
      }
    } else {
      // If this input type is valid, just push back the origin input.
      new_input_list.push_back(input_node);
    }
  }

  if (matched) {
    // Create replacing node, update front-end node map, set kernel build info, inherit attributes, etc. These
    // operations could rely on the origin CNode.
    auto new_node = CreateNewNode(func_graph, new_input_list, cnode);
    MS_LOG(INFO) << "Create new node " << new_node->fullname_with_scope() << " " << new_node->DebugString()
                 << " to replace " << cnode->fullname_with_scope() << " " << cnode->DebugString();
    return new_node;
  }
  return nullptr;
}
589
IsInputUpdated(const AnfNodePtr & origin_input,const AnfNodePtrList & new_input_list) const590 bool InsertTypeTransformOp::IsInputUpdated(const AnfNodePtr &origin_input, const AnfNodePtrList &new_input_list) const {
591 MS_EXCEPTION_IF_NULL(origin_input);
592 if (new_input_list.empty()) {
593 MS_LOG(INFO) << "The new input list size should be at least 1, but got 0.";
594 return false;
595 }
596
597 if (new_input_list.size() == kSizeOne && new_input_list[kIndex0] == origin_input) {
598 MS_LOG(DEBUG) << "Input node " << origin_input->fullname_with_scope() << " " << origin_input->DebugString()
599 << " should not be updated.";
600 return false;
601 }
602 MS_LOG(DEBUG) << "Input node " << origin_input->fullname_with_scope() << " " << origin_input->DebugString()
603 << " will be replaced.";
604 return true;
605 }
606
ProcessTupleUnfoldToTupleUnfold(const FuncGraphPtr & func_graph,const AnfNodePtr & input,const CNodePtr & node,bool *)607 AnfNodePtrList InsertTypeTransformOp::ProcessTupleUnfoldToTupleUnfold(const FuncGraphPtr &func_graph,
608 const AnfNodePtr &input, const CNodePtr &node,
609 bool *) {
610 MS_EXCEPTION_IF_NULL(func_graph);
611 MS_EXCEPTION_IF_NULL(input);
612 MS_EXCEPTION_IF_NULL(node);
613
614 // If the input needs to be skipped as ConvertTupleInputToDynamicInput does, return the input node itself for
615 // caller to construct input list.
616 bool is_bprop_cut = common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimBpropCut);
617 bool skip = (is_bprop_cut && input->abstract()->isa<abstract::AbstractSparseTensor>()) ||
618 IsPrimitiveCNode(node, prim::kPrimTupleGetItem);
619 if (skip) {
620 return {input};
621 }
622
623 AnfNodePtrList plant_inputs;
624 int64_t unfold_num = SplitTupleInputsForInsertType(func_graph, input, &plant_inputs);
625 MS_LOG(DEBUG) << "Transform tuple unfold input: " << input->fullname_with_scope() << " to " << unfold_num
626 << " inputs.";
627 return plant_inputs;
628 }
629
ProcessTupleUnfoldToTuple(const FuncGraphPtr & func_graph,const AnfNodePtr & input,const CNodePtr & node,bool *)630 AnfNodePtrList InsertTypeTransformOp::ProcessTupleUnfoldToTuple(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
631 const CNodePtr &node, bool *) {
632 MS_EXCEPTION_IF_NULL(func_graph);
633 MS_EXCEPTION_IF_NULL(input);
634 MS_EXCEPTION_IF_NULL(node);
635
636 AnfNodePtrList result;
637 AnfNodePtr real_make_tuple_node = nullptr;
638 // If TupleUnfold input is a MakeTuple node, replace it with RealMakeTuple node.
639 if (IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
640 real_make_tuple_node = CreateRealMakeTupleByMakeTuple(func_graph, input->cast<CNodePtr>());
641 } else {
642 real_make_tuple_node = CreateRealMakeTupleByTupleUnfoldInput(func_graph, input);
643 }
644 result.push_back(real_make_tuple_node);
645 return result;
646 }
647
// Bridges a TupleUnfold output to a node that expects a Tensor input: inserts a TupleToTensor op between
// 'input' and 'node'. The inserted op's first-input object type is forced to TUPLE so that a subsequent
// TupleUnfoldToTuple pattern match will convert the producer side as well.
// NOTE(review): unlike ProcessTupleToTensor below, this variant has no kTypeUnknown/empty-tuple fallback —
// presumably a TupleUnfold input always has a known device data type here; confirm if empty tuples can reach
// this path.
AnfNodePtrList InsertTypeTransformOp::ProcessTupleUnfoldToTensor(const FuncGraphPtr &func_graph,
                                                                 const AnfNodePtr &input, const CNodePtr &node,
                                                                 bool *) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(input);
  MS_EXCEPTION_IF_NULL(node);

  // Data type of the tensor should be set as an attr of TupleToTensor op; take it from the user node's
  // expected device data type at the position where 'input' is consumed.
  size_t input_index = GetInputNodeIndex(input, node);
  auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
  // There might be nested tuples, we need to find one step further to get element's data type.
  if (data_type == kObjectTypeTuple) {
    auto seq_abs = input->abstract();
    MS_EXCEPTION_IF_NULL(seq_abs);
    if (!seq_abs->isa<abstract::AbstractSequence>()) {
      MS_LOG(EXCEPTION) << "Input " << input->DebugString() << " is not tuple output";
    }
    // Use the first element's type as the representative element data type.
    data_type = seq_abs->cast<abstract::AbstractSequencePtr>()->ElementsType()[kIndex0]->type_id();
    MS_LOG(DEBUG) << "Input " << input->DebugString() << " real data type is " << data_type;
  }
  // The target dtype is passed as a second (scalar) input rather than an attribute.
  auto type_id_value_node = AnfAlgo::CreateTypeIdValueNodeToKernelGraph(func_graph, data_type);
  // Use TupleToTensor op as the input of this node. Then TupleUnfoldToTuple pattern will be matched.
  auto prim = NewValueNode(std::make_shared<Primitive>(prim::kPrimTupleToTensor->name()));
  MS_EXCEPTION_IF_NULL(prim);
  AnfNodePtrList inputs = {prim, input, type_id_value_node};
  CNodePtr tuple_to_tensor = func_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(tuple_to_tensor);
  tuple_to_tensor->set_scope(input->scope());
  // Set abstract for TupleToTensor op according to user node's input shape and type.
  auto abs = GenerateAbsByOpInfer(GetCNodePrimitive(tuple_to_tensor), {input, type_id_value_node});
  MS_EXCEPTION_IF_NULL(abs);
  MS_LOG(DEBUG) << "Abstract for TupleToTensor op is " << abs->ToString();
  tuple_to_tensor->set_abstract(abs);
  // Kernel info must exist before the build info below can be fetched and adjusted.
  SetKernelInfoForNewCNode(tuple_to_tensor);
  // Set object type to TUPLE for TupleUnfoldToTuple pattern to be matched.
  KernelBuildInfoPtr tuple_to_tensor_build_info = AnfAlgo::GetSelectKernelBuildInfo(tuple_to_tensor);
  MS_EXCEPTION_IF_NULL(tuple_to_tensor_build_info);
  tuple_to_tensor_build_info->SetInputsKernelObjectType({KernelObjectType::TUPLE, KernelObjectType::SCALAR});
  return {tuple_to_tensor};
}
688
ProcessTupleToTupleUnfold(const FuncGraphPtr & func_graph,const AnfNodePtr & input,const CNodePtr & node,bool * new_prim)689 AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTupleUnfold(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
690 const CNodePtr &node, bool *new_prim) {
691 MS_EXCEPTION_IF_NULL(func_graph);
692 MS_EXCEPTION_IF_NULL(input);
693 MS_EXCEPTION_IF_NULL(node);
694
695 // This pattern only supports user node is a TupleGetItem node.
696 // If this pattern is matched but the user node is not TupleGetItem, throw exception.
697 if (!IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
698 // If this node supports any input types, do not process it.
699 KernelBuildInfoPtr build_info = AnfAlgo::GetSelectKernelBuildInfo(node);
700 MS_EXCEPTION_IF_NULL(build_info);
701 if (build_info->op_type() == kernel::OpType::SKIP) {
702 return ProcessTupleToTupleUnfoldForSkipOp(func_graph, input, node, new_prim);
703 }
704 MS_LOG(EXCEPTION) << "Tuple to TupleUnfold pattern should have TupleGetItem as user node, but got "
705 << node->fullname_with_scope() << ", " << node->DebugString();
706 }
707 return ProcessTupleToTupleUnfoldForTupleGetItem(func_graph, input, node, new_prim);
708 }
709
// Handles the Tuple->TupleUnfold pattern when the user node is a SKIP-type op (accepts any input type)
// rather than a TupleGetItem. For a statically-sized tuple input, expands the tuple into one
// RealTupleGetItem-based node per element and returns those as the new inputs; otherwise (dynamic-length
// tuple or non-sequence abstract) leaves the input untouched and clears *new_prim.
AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTupleUnfoldForSkipOp(const FuncGraphPtr &func_graph,
                                                                         const AnfNodePtr &input, const CNodePtr &node,
                                                                         bool *new_prim) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(input);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(new_prim);
  if (input->abstract() != nullptr && input->abstract()->isa<abstract::AbstractSequence>()) {
    const auto &seq_abs = input->abstract()->cast<abstract::AbstractSequencePtr>();
    MS_EXCEPTION_IF_NULL(seq_abs);
    // Only a statically-sized tuple can be unfolded element by element.
    if (!seq_abs->dynamic_len()) {
      AnfNodePtrList new_inputs;
      // Find the position(s) of 'input' among the user node's inputs; other inputs are untouched.
      for (const auto &node_input : node->inputs()) {
        if (node_input != input) {
          continue;
        }
        // One TupleGetItem per tuple element, each then rewritten into a RealTupleGetItem node.
        for (size_t i = 0; i < seq_abs->size(); ++i) {
          CNodePtr get_item = CreatTupleGetItemNode(func_graph, input, i);
          MS_EXCEPTION_IF_NULL(get_item);
          // The freshly created node has no kernel info yet; build one by hand so the build info
          // below can be attached.
          auto kernel_info = std::make_shared<device::KernelInfo>();
          MS_EXCEPTION_IF_NULL(kernel_info);
          get_item->set_kernel_info(kernel_info);
          auto builder = std::make_shared<KernelBuildInfoBuilder>();
          MS_EXCEPTION_IF_NULL(builder);
          AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), get_item.get());
          KernelBuildInfoPtr build_info = AnfAlgo::GetSelectKernelBuildInfo(get_item);
          MS_EXCEPTION_IF_NULL(build_info);
          // Inputs: the tuple (mirrors the producer's format/type) and the scalar int64 index.
          build_info->SetInputsFormat({AnfAlgo::GetOutputFormat(input, 0), kOpFormat_DEFAULT});
          build_info->SetInputsDeviceType({AnfAlgo::GetOutputDeviceDataType(input, 0), TypeId::kNumberTypeInt64});
          build_info->SetOutputsFormat({AnfAlgo::GetOutputFormat(input, 0)});
          build_info->SetOutputsDeviceType({AnfAlgo::GetOutputDeviceDataType(input, 0)});
          build_info->SetInputsKernelObjectType({KernelObjectType::TUPLE_UNFOLD, KernelObjectType::SCALAR});
          build_info->SetOutputsKernelObjectType({KernelObjectType::TENSOR});
          // Rewrite the temporary TupleGetItem into RealTupleGetItem inputs, re-append the index input
          // (input(2) of TupleGetItem), and materialize the replacement node.
          bool new_get_item_prim = false;
          auto new_get_item_inputs =
            ProcessTupleToTupleUnfoldForTupleGetItem(func_graph, input, get_item, &new_get_item_prim);
          new_get_item_inputs.emplace_back(get_item->input(2));
          auto new_get_item = CreateNewNode(func_graph, new_get_item_inputs, get_item);
          MS_LOG(DEBUG) << "Create new node " << new_get_item->fullname_with_scope() << " "
                        << new_get_item->DebugString(2) << " to replace " << get_item->fullname_with_scope() << " "
                        << get_item->DebugString(2)
                        << " build info:" << AnfAlgo::GetSelectKernelBuildInfo(new_get_item)->ToString();
          new_inputs.emplace_back(new_get_item);
        }
      }
      return new_inputs;
    }
  } else {
    MS_LOG(WARNING) << "Invalid input:" << input->DebugString() << " for node:" << node->DebugString();
  }
  // Fallback: keep the original input and tell the caller the user primitive is unchanged.
  MS_LOG(INFO) << "Node " << node->fullname_with_scope() << " skip TupleToTupleUnfold type matching.";
  *new_prim = false;
  return {input};
}
764
// Handles the Tuple->TupleUnfold pattern for a TupleGetItem user node by turning it into a real kernel:
// the returned inputs replace the node's primitive with RealTupleGetItem, *new_prim is set so the caller
// knows the primitive changed, and the user node's abstract is updated in place from the op infer.
// Note the returned list is {RealTupleGetItem prim, tuple input}; the index input is expected to be kept
// by the caller (see ProcessTupleToTupleUnfoldForSkipOp, which re-appends it).
AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTupleUnfoldForTupleGetItem(const FuncGraphPtr &func_graph,
                                                                               const AnfNodePtr &input,
                                                                               const CNodePtr &node, bool *new_prim) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(input);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(new_prim);
  auto prim = NewValueNode(prim::kPrimRealTupleGetItem);
  MS_EXCEPTION_IF_NULL(prim);
  // Use original inputs except the primitive.
  AnfNodePtrList new_inputs = {prim, input};

  // For TupleGetItem node, the second input value node's kernel info must be in case of nullptr.
  if (common::AnfAlgo::GetInputTensorNum(node) != kSizeTwo) {
    MS_LOG(EXCEPTION) << "Input number of TupleGetItem node " << node->DebugString() << " should be 2. But got "
                      << common::AnfAlgo::GetInputTensorNum(node);
  }
  auto index_input = node->input(kIndex2);
  MS_EXCEPTION_IF_NULL(index_input);
  if (index_input->isa<ValueNode>()) {
    SetKernelInfoForValueNode(index_input->cast<ValueNodePtr>());
    // Because the index is used as real kernel RealTupleGetItem's second input, we must add TupleGetItem's index to
    // kernel graph so that its device address will be allocated.
    auto kg = func_graph->cast<KernelGraphPtr>();
    MS_EXCEPTION_IF_NULL(kg);
    MS_LOG(INFO) << "Add value:" << index_input->DebugString() << ", full name:" << index_input->fullname_with_scope()
                 << " to kernel graph.";
    kg->AddValueNodeToGraph(index_input->cast<ValueNodePtr>());
  }

  // Re-infer the output abstract for RealTupleGetItem and install it on the user node directly.
  auto abs = GenerateAbsByOpInfer(prim::kPrimRealTupleGetItem, {input, index_input});
  MS_EXCEPTION_IF_NULL(abs);
  MS_LOG(DEBUG) << "Abstract for RealTupleGetItem op is " << abs->ToString();
  node->set_abstract(abs);

  // The primitive of user is changed.
  *new_prim = true;
  return new_inputs;
}
804
// Bridges a Tuple output to a node that expects a Tensor input. Prefers folding a constant tuple into a
// tensor value node; otherwise inserts a TupleToTensor op between 'input' and 'node'. New-kernel
// (GetOpDef-registered) node pairs are not expected to need this bridge, so hitting that case raises.
AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTensor(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
                                                           const CNodePtr &node, bool *) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(input);
  MS_EXCEPTION_IF_NULL(node);
  // Constant tuple: convert it to a tensor value node directly, no runtime op needed.
  auto new_input = ConstructInputByValueNode(func_graph, input);
  if (new_input != nullptr) {
    MS_LOG(DEBUG) << "Create new value node:" << new_input->DebugString() << " by " << input->DebugString()
                  << " for cnode:" << node->DebugString() << " in graph:" << func_graph->ToString();
    return {new_input};
  }

  if (IsNewKernel(node) && IsNewKernel(input)) {
    MS_LOG(EXCEPTION) << "Insert TupleToTensor op for input:" << input->fullname_with_scope()
                      << " of node:" << node->fullname_with_scope() << " in graph:" << SOURCE_PLACEHOLDER_REMOVED
  }

  // Data type of the tensor should be set as an attr of TupleToTensor op.
  size_t input_index = GetInputNodeIndex(input, node);
  auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
  // Empty tuple with no resolvable device type: fall back to int64 as the element type.
  // NOTE(review): presumably int64 is a safe default for an empty tuple's tensor dtype — confirm.
  if (data_type == TypeId::kTypeUnknown && input->abstract() != nullptr &&
      input->abstract()->isa<abstract::AbstractSequence>() &&
      input->abstract()->cast<abstract::AbstractSequencePtr>()->elements().size() == 0) {
    data_type = TypeId::kNumberTypeInt64;
  }
  // There might be nested tuples, we need to find one step further to get element's data type.
  if (data_type == kObjectTypeTuple) {
    auto seq_abs = input->abstract();
    MS_EXCEPTION_IF_NULL(seq_abs);
    if (!seq_abs->isa<abstract::AbstractSequence>()) {
      MS_LOG(EXCEPTION) << "Input " << input->DebugString() << " is not tuple output";
    }
    // Use the first element's type as the representative element data type.
    data_type = seq_abs->cast<abstract::AbstractSequencePtr>()->ElementsType()[kIndex0]->type_id();
    MS_LOG(DEBUG) << "Input " << input->DebugString() << " real data type is " << data_type;
  }
  // The target dtype is passed as a second (scalar) input rather than an attribute.
  auto type_id_value_node = AnfAlgo::CreateTypeIdValueNodeToKernelGraph(func_graph, data_type);
  // Simply insert TupleToTensor op between 'input' and 'node'.
  auto prim = NewValueNode(std::make_shared<Primitive>(prim::kPrimTupleToTensor->name()));
  MS_EXCEPTION_IF_NULL(prim);
  AnfNodePtrList inputs = {prim, input, type_id_value_node};
  CNodePtr tuple_to_tensor = func_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(tuple_to_tensor);

  // Set abstract for TupleToTensor op according to user node's input shape and type.
  auto abs = GenerateAbsByOpInfer(GetCNodePrimitive(tuple_to_tensor), {input, type_id_value_node});
  MS_EXCEPTION_IF_NULL(abs);
  tuple_to_tensor->set_scope(input->scope());
  MS_LOG(DEBUG) << "Abstract for TupleToTensor op is " << abs->ToString();
  tuple_to_tensor->set_abstract(abs);
  // Kernel info must exist before the build info below can be fetched and adjusted.
  SetKernelInfoForNewCNode(tuple_to_tensor);
  // Set object type to TUPLE for TupleUnfoldToTuple pattern to be matched.
  KernelBuildInfoPtr tuple_to_tensor_build_info = AnfAlgo::GetSelectKernelBuildInfo(tuple_to_tensor);
  MS_EXCEPTION_IF_NULL(tuple_to_tensor_build_info);
  tuple_to_tensor_build_info->SetInputsKernelObjectType({KernelObjectType::TUPLE, KernelObjectType::SCALAR});
  return {tuple_to_tensor};
}
861
ProcessScalarToTensor(const FuncGraphPtr & func_graph,const AnfNodePtr & input,const CNodePtr & node,bool *)862 AnfNodePtrList InsertTypeTransformOp::ProcessScalarToTensor(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
863 const CNodePtr &node, bool *) {
864 MS_EXCEPTION_IF_NULL(func_graph);
865 MS_EXCEPTION_IF_NULL(input);
866 MS_EXCEPTION_IF_NULL(node);
867 if (IsNewKernel(node) && IsNewKernel(input)) {
868 MS_LOG(EXCEPTION) << "Insert ScalarToTensor op for input:" << input->fullname_with_scope()
869 << " of node:" << node->fullname_with_scope() << " in graph:" << func_graph->ToString();
870 }
871
872 auto new_input = ConstructInputByValueNode(func_graph, input);
873 if (new_input != nullptr) {
874 MS_LOG(DEBUG) << "Create new value node:" << new_input->DebugString() << " by " << input->DebugString()
875 << " for cnode:" << node->DebugString() << " in graph:" << func_graph->ToString();
876 return {new_input};
877 }
878
879 // Data type of the tensor should be set as an attr of ScalarToTensor op.
880 size_t input_index = GetInputNodeIndex(input, node);
881 auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
882 auto type_id_value_node = AnfAlgo::CreateTypeIdValueNodeToKernelGraph(func_graph, data_type);
883 // Simply insert ScalarToTensor op between 'input' and 'node'.
884 auto prim = NewValueNode(std::make_shared<Primitive>(prim::kPrimScalarToTensor->name()));
885 MS_EXCEPTION_IF_NULL(prim);
886 AnfNodePtrList inputs = {prim, input, type_id_value_node};
887 CNodePtr scalar_to_tensor = func_graph->NewCNode(inputs);
888 MS_EXCEPTION_IF_NULL(scalar_to_tensor);
889 scalar_to_tensor->set_scope(input->scope());
890 auto abs = GenerateAbsByOpInfer(GetCNodePrimitive(scalar_to_tensor), {input, type_id_value_node});
891 MS_EXCEPTION_IF_NULL(abs);
892 MS_LOG(DEBUG) << "Abstract for ScalarToTensor op is " << abs->ToString();
893 scalar_to_tensor->set_abstract(abs);
894 SetKernelInfoForNewCNode(scalar_to_tensor);
895 // Set object type info
896 KernelBuildInfoPtr scalar_to_tensor_build_info = AnfAlgo::GetSelectKernelBuildInfo(scalar_to_tensor);
897 MS_EXCEPTION_IF_NULL(scalar_to_tensor_build_info);
898 scalar_to_tensor_build_info->SetInputsKernelObjectType({KernelObjectType::SCALAR, KernelObjectType::SCALAR});
899 return {scalar_to_tensor};
900 }
901
ProcessTensorToTuple(const FuncGraphPtr & func_graph,const AnfNodePtr & input,const CNodePtr & node,bool *)902 AnfNodePtrList InsertTypeTransformOp::ProcessTensorToTuple(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
903 const CNodePtr &node, bool *) {
904 MS_EXCEPTION_IF_NULL(func_graph);
905 MS_EXCEPTION_IF_NULL(input);
906 MS_EXCEPTION_IF_NULL(node);
907 if (IsNewKernel(node) && IsNewKernel(input)) {
908 MS_LOG(EXCEPTION) << "Insert TensorToTuple op for input:" << input->fullname_with_scope()
909 << " of node:" << node->fullname_with_scope() << " in graph:" << func_graph->ToString();
910 }
911 // Create TensorToTuple op.
912 auto prim = NewValueNode(std::make_shared<Primitive>(prim::kPrimTensorToTuple->name()));
913 MS_EXCEPTION_IF_NULL(prim);
914 AnfNodePtrList inputs = {prim, input};
915 CNodePtr tensor_to_tuple = func_graph->NewCNode(inputs);
916 MS_EXCEPTION_IF_NULL(tensor_to_tuple);
917 tensor_to_tuple->set_scope(input->scope());
918 auto abs = GenerateAbsByOpInfer(GetCNodePrimitive(tensor_to_tuple), {input});
919 MS_EXCEPTION_IF_NULL(abs);
920 MS_LOG(DEBUG) << "Abstract for TensorToTuple op is " << abs->ToString();
921 tensor_to_tuple->set_abstract(abs);
922
923 SetKernelInfoForNewCNode(tensor_to_tuple);
924 return {tensor_to_tuple};
925 }
926
ProcessTensorToScalar(const FuncGraphPtr & func_graph,const AnfNodePtr & input,const CNodePtr & node,bool *)927 AnfNodePtrList InsertTypeTransformOp::ProcessTensorToScalar(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
928 const CNodePtr &node, bool *) {
929 MS_EXCEPTION_IF_NULL(func_graph);
930 MS_EXCEPTION_IF_NULL(input);
931 MS_EXCEPTION_IF_NULL(node);
932 if (IsNewKernel(node) && IsNewKernel(input)) {
933 MS_LOG(EXCEPTION) << "Insert TensorToScalar op for input:" << input->fullname_with_scope()
934 << " of node:" << node->fullname_with_scope() << " in graph:" << func_graph->ToString();
935 }
936 // Create TensorToScalar op.
937 auto prim = NewValueNode(std::make_shared<Primitive>(prim::kPrimTensorToScalar->name()));
938 MS_EXCEPTION_IF_NULL(prim);
939 AnfNodePtrList inputs = {prim, input};
940 CNodePtr tensor_to_scalar = func_graph->NewCNode(inputs);
941 MS_EXCEPTION_IF_NULL(tensor_to_scalar);
942 tensor_to_scalar->set_scope(input->scope());
943 auto abs = GenerateAbsByOpInfer(GetCNodePrimitive(tensor_to_scalar), {input});
944 MS_EXCEPTION_IF_NULL(abs);
945 MS_LOG(DEBUG) << "Abstract for TensorToScalar op is " << abs->ToString();
946 tensor_to_scalar->set_abstract(abs);
947
948 SetKernelInfoForNewCNode(tensor_to_scalar);
949 return {tensor_to_scalar};
950 }
951 } // namespace opt
952 } // namespace mindspore
953