• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include <algorithm>
#include <cstring>
#include <limits>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "ops/auto_generate/gen_lite_ops.h"
#include "ops/auto_generate/gen_ops_primitive.h"
#include "ops/array_ops.h"
#include "ops/lite_ops.h"
#include "tools/optimizer/graph/scalar_op_pass.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "mindspore/core/ops/arithmetic_ops.h"
#include "tools/optimizer/graph/lite_tensor_extractor.h"
#include "mindspore/core/abstract/ops/primitive_infer_map.h"
#include "mindspore/core/utils/anf_utils.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/sequence_ops.h"
31 
32 /* This pass changes the following pattern(s).
33 
34   1. Getting shape from a dynamic tensor, and get item from the shape tuple.
35     ###############
36     Pattern:
37     Shape -> TupleGetItem -> Scalar
38 
39     Replace:
40     TensorShape -> Cast(int32) -> StridedSlice -> TensorToScalar -> Scalar
41     ###############
42 
43     The Shape op is replaced by a TensorShape op in order to get a Tensor, which is then cast to int32 dtype for Ascend
44     dynamic shape mbatch calculation. It is followed by StridedSlice, the Tensor equivalent of the tuple's TupleGetItem.
45     Finally, the result is cast back to a Scalar with a TensorToScalar op, so that the output types of the old and new
46     patterns agree.
47 
48   2. Do scalar arithmetic using scalar from TupleGetItem op.
49     ###############
50     Pattern:
51     Scalar -> ScalarMul/ScalarDiv/…. -> Scalar
52 
53     Replace:
54     ScalarToTensor -> Mul/Div/… -> Tensor -> TensorToScalar
55     ###############
56 
57     The ScalarXXX arithmetic ops are replaced by their Tensor equivalent arithmetic ops. ScalarToTensor and
58     TensorToScalar conversion ops are inserted before and after them, so that the input/output types of the old and
59     new patterns agree.
60 
61   3. MakeTuple and Reshape operations using Scalars.
62     ###############
63     Pattern:
64     Scalar -> MakeTuple -> Tuple(Scalar) -> Reshape -> Tensor
65 
66     Replace:
67     ScalarToTensor -> MakeTuple -> Tuple(Tensor) -> Concat -> Reshape -> Tensor
68     ###############
69 
70     MakeTuple's Scalar inputs are converted to Tensors, and a Concat op is appended so that Reshape receives a Tensor.
71     ScalarToTensor conversion ops are inserted before MakeTuple so that the old and new patterns' types agree.
72 
73   4. ScalarToTensor and TensorToScalar ops are temporary placeholders. The last step is to remove them.
74     ###############
75     Pattern:
76     TensorToScalar -> ScalarToTensor
77     TensorToScalar -> Cast -> Tensor
78 
79     Replace:
80     Remove both, so that the Tensors are connected directly.
81     ###############
82 */
83 namespace mindspore::opt {
84 /*
85 This function returns the index of the input node, which is used by the user node.
86 */
GetInputNodeIndex(const AnfNodePtr & input,const CNodePtr & user_node)87 size_t ScalarOpPass::GetInputNodeIndex(const AnfNodePtr &input, const CNodePtr &user_node) {
88   MS_EXCEPTION_IF_NULL(input);
89   MS_EXCEPTION_IF_NULL(user_node);
90 
91   AnfNodePtrList input_list = user_node->inputs();
92   auto pos = std::find(input_list.begin(), input_list.end(), input);
93   if (pos == input_list.end()) {
94     MS_LOG(EXCEPTION) << input->fullname_with_scope() << " is not the input of " << user_node->fullname_with_scope();
95   }
96 
97   // The first input is Primitive and needs to be skipped.
98   return std::distance(input_list.begin() + kSizeOne, pos);
99 }
100 
101 /*
102 Create a Tensor with type scalar. This pass assumes that the scalar is from TensorShape, which will be integers.
103 */
GenerateScalarValueTensor(const FuncGraphPtr & func_graph,const AnfNodePtr & anf_node,int input_index)104 ValueNodePtr ScalarOpPass::GenerateScalarValueTensor(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node,
105                                                      int input_index) {
106   lite::DataInfo data_info;
107   auto ret = lite::FetchConstData(anf_node->cast<CNodePtr>(), input_index, converter::kFmkTypeMs, &data_info, false);
108   MS_CHECK_TRUE_RET(ret == lite::RET_OK, nullptr);
109   if (data_info.data_type_ != kNumberTypeInt32 && data_info.data_type_ != kNumberTypeInt64) {
110     MS_LOG(ERROR) << "Unsupported scalar data type: " << data_info.data_type_ << ", need to add support.";
111     return nullptr;
112   }
113   int32_t scalar_value = *reinterpret_cast<int32_t *>(data_info.data_.data());
114   ShapeVector const_data_shape = {1};
115   tensor::TensorPtr const_data_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, const_data_shape);
116   auto *val = static_cast<int32_t *>(const_data_tensor->data_c());
117   *val = scalar_value;
118   auto const_value_node = NewValueNode(const_data_tensor);
119   const_value_node->set_abstract(const_data_tensor->ToAbstract());
120   func_graph->AddValueNode(const_value_node);
121   return const_value_node;
122 }
123 
GenerateScalarToTensor(const FuncGraphPtr & func_graph,const AnfNodePtr & anf_node,int input_index)124 CNodePtr ScalarOpPass::GenerateScalarToTensor(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node,
125                                               int input_index) {
126   auto scalar_cnode = anf_node->cast<CNodePtr>();
127   auto scalar_input = scalar_cnode->input(input_index);
128   // Data type of the tensor should be set as an attr of ScalarToTensor op.
129   TypeId data_type;
130   if (opt::GetDataTypeFromAnfNode(scalar_cnode->input(input_index), &data_type) != RET_OK) {
131     MS_LOG(ERROR) << "Failed to get " << anf_node->fullname_with_scope() << " output tensor data type.";
132     return nullptr;
133   }
134   auto type_id_value_node = NewValueNode(MakeValue(static_cast<int64_t>(data_type)));
135   auto type_id_value = std::make_shared<Int64Imm>(static_cast<int64_t>(data_type));
136   type_id_value_node->set_abstract(type_id_value->ToAbstract());
137   auto prim = NewValueNode(std::make_shared<Primitive>(kScalarToTensorOpName));
138   MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
139   AnfNodePtrList inputs = {prim, scalar_input, type_id_value_node};
140   CNodePtr scalar_to_tensor = func_graph->NewCNode(inputs);
141   MS_CHECK_TRUE_RET(scalar_to_tensor != nullptr, nullptr);
142   auto primitive = GetCNodePrimitive(scalar_to_tensor);
143   MS_CHECK_TRUE_RET(primitive != nullptr, nullptr);
144   // set abstract
145   ShapeVector tensor_shape = {1};
146   auto tensor_shape_ptr = std::make_shared<abstract::Shape>(tensor_shape);
147   MS_CHECK_TRUE_MSG(tensor_shape_ptr != nullptr, nullptr, "tensor_shape_ptr is nullptr.");
148   auto tmp_abstract = abstract::MakeAbstract(std::make_shared<abstract::Shape>(tensor_shape), TypeIdToType(data_type));
149   MS_CHECK_TRUE_MSG(tmp_abstract != nullptr, nullptr, "make AbstractTensor failed");
150   scalar_to_tensor->set_abstract(tmp_abstract);
151   return scalar_to_tensor;
152 }
153 
GenerateTensorToScalar(const FuncGraphPtr & func_graph,const AnfNodePtr & anf_node,bool is_curr_node)154 CNodePtr ScalarOpPass::GenerateTensorToScalar(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node,
155                                               bool is_curr_node) {
156   auto prim = NewValueNode(std::make_shared<Primitive>(kTensorToScalarOpName));
157   MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
158   auto input_cnode = anf_node->cast<CNodePtr>();
159   AnfNodePtrList inputs = {prim, input_cnode->input(kIndexOne)};
160   if (is_curr_node) {  // insert TensorToScalar after current anf_node
161     inputs = {prim, anf_node};
162   }
163   CNodePtr tensor_to_scalar = func_graph->NewCNode(inputs);
164   MS_CHECK_TRUE_RET(tensor_to_scalar != nullptr, nullptr);
165 
166   // set abstract
167   TypeId type_id;
168   (void)GetDataTypeFromAnfNode(anf_node, &type_id);
169 
170   auto tmp_abstract = std::make_shared<abstract::AbstractScalar>(kValueAny, TypeIdToType(type_id));
171   MS_CHECK_TRUE_MSG(tmp_abstract != nullptr, nullptr, "make AbstractScalar failed");
172   tensor_to_scalar->set_abstract(tmp_abstract);
173   return tensor_to_scalar;
174 }
175 
GenerateTensorShape(const FuncGraphPtr & func_graph,const AnfNodePtr & anf_node)176 CNodePtr ScalarOpPass::GenerateTensorShape(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node) {
177   auto prim = NewValueNode(std::make_shared<Primitive>(kTensorShapeOpName));
178   MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
179   auto shape_cnode = anf_node->cast<CNodePtr>();
180   AnfNodePtrList inputs = {prim, shape_cnode->input(kIndexOne)};
181 
182   CNodePtr tensor_shape_node = func_graph->NewCNode(inputs);
183   MS_CHECK_TRUE_MSG(tensor_shape_node != nullptr, nullptr, "tensor_shape_node is nullptr.");
184 
185   abstract::AbstractBasePtr tmp_abstract;
186   auto shape_input_abs = shape_cnode->input(kIndexOne)->abstract()->cast<abstract::AbstractTensorPtr>();
187   MS_CHECK_TRUE_MSG(shape_input_abs != nullptr, nullptr, "shape input abstract is not AbstractTensor.");
188   auto shape = shape_input_abs->shape()->shape();
189   ShapeVector tensor_shp({static_cast<int64_t>(shape.size())});
190   if (IsDynamic(shape)) {
191     if (IsDynamicRank(shape)) {
192       tmp_abstract = abstract::MakeAbstract(
193         std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeDimAny}), kInt64);
194     } else {
195       auto elem = std::make_shared<abstract::AbstractScalar>(std::make_shared<ValueAny>(), std::make_shared<Int>(64));
196       auto abs_tensor = std::make_shared<abstract::AbstractTensor>(elem, std::make_shared<abstract::Shape>(tensor_shp));
197       tmp_abstract = abs_tensor;
198     }
199   } else {
200     auto shp_buf_size = sizeof(int64_t) * shape.size();
201     auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, tensor_shp, shape.data(), shp_buf_size);
202     tmp_abstract = tensor->ToAbstract();
203   }
204 
205   // set abstract
206   tensor_shape_node->set_fullname_with_scope(anf_node->fullname_with_scope() + "_tensorshape");
207   tensor_shape_node->set_abstract(tmp_abstract);
208   return tensor_shape_node;
209 }
210 
211 /*
212 Create a ValueNode with single scalar as input.
213 */
GenerateScalarValueTuple(const FuncGraphPtr & func_graph,int64_t value)214 ValueNodePtr ScalarOpPass::GenerateScalarValueTuple(const FuncGraphPtr &func_graph, int64_t value) {
215   std::vector<int64_t> vec({value});
216   auto tuple_value = MakeValue(vec);
217   auto tuple_node = NewValueNode(tuple_value);
218   tuple_node->set_abstract(tuple_value->ToAbstract());
219   func_graph->AddValueNode(tuple_node);
220   return tuple_node;
221 }
222 
GenerateScalarValue(const FuncGraphPtr & func_graph,int64_t value)223 ValueNodePtr ScalarOpPass::GenerateScalarValue(const FuncGraphPtr &func_graph, int64_t value) {
224   auto scalar_value = MakeValue(value);
225   auto scalar_node = NewValueNode(scalar_value);
226   scalar_node->set_abstract(scalar_value->ToAbstract());
227   func_graph->AddValueNode(scalar_node);
228   return scalar_node;
229 }
230 
GenerateStridedSlice(const FuncGraphPtr & func_graph,const AnfNodePtr & shape_node,const AnfNodePtr & tuple_get_node,const FuncGraphManagerPtr & manager)231 CNodePtr ScalarOpPass::GenerateStridedSlice(const FuncGraphPtr &func_graph, const AnfNodePtr &shape_node,
232                                             const AnfNodePtr &tuple_get_node, const FuncGraphManagerPtr &manager) {
233   auto begin_index = GetTupleGetItemOutIndex(tuple_get_node->cast<CNodePtr>());
234   MS_CHECK_TRUE_MSG(begin_index >= 0, nullptr, "begin index is less than zero.");
235 
236   // set inputs
237   auto begin_node = GenerateScalarValueTuple(func_graph, begin_index);
238   MS_CHECK_TRUE_MSG(begin_node != nullptr, nullptr, "generate StridedSlice begin node failed.");
239   auto end_node = GenerateScalarValueTuple(func_graph, begin_index + kSizeOne);
240   MS_CHECK_TRUE_MSG(end_node != nullptr, nullptr, "generate StridedSlice end node failed.");
241   auto strides_node = GenerateScalarValueTuple(func_graph, kSizeOne);
242   MS_CHECK_TRUE_MSG(strides_node != nullptr, nullptr, "generate StridedSlice strides node failed.");
243 
244   // set abstract
245   ShapeVector tensor_shape = {1};
246   auto tensor_shape_ptr = std::make_shared<abstract::Shape>(tensor_shape);
247   MS_CHECK_TRUE_MSG(tensor_shape_ptr != nullptr, nullptr, "tensor_shape_ptr is nullptr.");
248   TypeId infer_type;
249   auto ret = GetDataTypeFromAnfNode(shape_node, &infer_type);
250   MS_CHECK_TRUE_MSG(ret == RET_OK, nullptr, "get data_type from node failed.");
251 
252   auto tmp_abstract = abstract::MakeAbstract(std::make_shared<abstract::Shape>(tensor_shape), TypeIdToType(infer_type));
253   MS_CHECK_TRUE_MSG(tmp_abstract != nullptr, nullptr, "make AbstractTensor failed");
254 
255   auto begin_mask = GenerateScalarValue(func_graph, 0);
256   MS_CHECK_TRUE_MSG(begin_mask != nullptr, nullptr, "generate StridedSlice begin_mask node failed.");
257   auto end_mask = GenerateScalarValue(func_graph, 0);
258   MS_CHECK_TRUE_MSG(end_mask != nullptr, nullptr, "generate StridedSlice end_mask node failed.");
259   auto ellipsis_mask = GenerateScalarValue(func_graph, 0);
260   MS_CHECK_TRUE_MSG(ellipsis_mask != nullptr, nullptr, "generate StridedSlice ellipsis_mask node failed.");
261   auto new_axis_mask = GenerateScalarValue(func_graph, 0);
262   MS_CHECK_TRUE_MSG(new_axis_mask != nullptr, nullptr, "generate StridedSlice new_axis_mask node failed.");
263   auto shrink_axis_mask = GenerateScalarValue(func_graph, 0);
264   MS_CHECK_TRUE_MSG(shrink_axis_mask != nullptr, nullptr, "generate StridedSlice shrink_axis_mask node failed.");
265 
266   auto prim = NewValueNode(std::make_shared<Primitive>(kStridedSliceOpName));
267   MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
268   AnfNodePtrList inputs = {prim,       shape_node, begin_node,    end_node,      strides_node,
269                            begin_mask, end_mask,   ellipsis_mask, new_axis_mask, shrink_axis_mask};
270   CNodePtr strided_slice = func_graph->NewCNode(inputs);
271   MS_CHECK_TRUE_RET(strided_slice != nullptr, nullptr);
272   strided_slice->set_fullname_with_scope(tuple_get_node->fullname_with_scope() + "_strided_slice");
273   strided_slice->set_abstract(tmp_abstract);
274 
275   // set attrs, all defaults to zero
276   auto primitive = GetCNodePrimitive(strided_slice);
277   MS_CHECK_TRUE_RET(primitive != nullptr, nullptr);
278   return strided_slice;
279 }
280 
ReplaceScalarOp(const FuncGraphPtr & func_graph,const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager,const PrimitivePtr & replace_op_prim)281 STATUS ScalarOpPass::ReplaceScalarOp(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node,
282                                      const FuncGraphManagerPtr &manager, const PrimitivePtr &replace_op_prim) {
283   auto replace_op_prim_node = NewValueNode(replace_op_prim);
284   MS_CHECK_TRUE_RET(replace_op_prim_node != nullptr, lite::RET_ERROR);
285   AnfNodePtrList replace_op_inputs = {replace_op_prim_node};
286 
287   auto scalar_cnode = anf_node->cast<CNodePtr>();
288   std::vector<mindspore::AnfNodePtr> scalar_inputs = {};
289   scalar_inputs.push_back(scalar_cnode->input(kIndexOne));
290   scalar_inputs.push_back(scalar_cnode->input(kIndexTwo));
291   for (size_t i = 0; i < scalar_inputs.size(); i++) {
292     if (!scalar_inputs[i]->isa<ValueNode>()) {
293       auto node = GenerateScalarToTensor(func_graph, anf_node, i + kSizeOne);
294       MS_CHECK_TRUE_RET(node != nullptr, lite::RET_ERROR);
295       replace_op_inputs.push_back(node);
296     } else {
297       auto node = GenerateScalarValueTensor(func_graph, anf_node, i + kSizeOne);
298       MS_CHECK_TRUE_RET(node != nullptr, lite::RET_ERROR);
299       replace_op_inputs.push_back(node);
300     }
301   }
302   CNodePtr replace_op = func_graph->NewCNode(replace_op_inputs);
303   MS_CHECK_TRUE_RET(replace_op != nullptr, lite::RET_ERROR);
304 
305   ShapeVector replace_op_shape = {1};
306   auto replace_op_shape_ptr = std::make_shared<abstract::Shape>(replace_op_shape);
307   MS_CHECK_TRUE_MSG(replace_op_shape_ptr != nullptr, RET_ERROR, "replace op is nullptr.");
308 
309   // Replace op has the same type as the first input
310   auto abstract = replace_op->input(kIndexOne)->abstract();
311   auto tmp_abstract = abstract->Clone();
312   tmp_abstract->set_shape(replace_op_shape_ptr);
313   replace_op->set_abstract(tmp_abstract);
314 
315   CNodePtr tensor_to_scalar = GenerateTensorToScalar(func_graph, replace_op, true);
316 
317   // Set input of the Scalar op users to tensor_to_scalar
318   auto orig_scalar_op_cnode = anf_node->cast<CNodePtr>();
319   auto node_users = manager->node_users()[orig_scalar_op_cnode];
320   for (auto &node_user : node_users) {
321     auto post_cnode = node_user.first->cast<CNodePtr>();
322     MS_CHECK_TRUE_RET(post_cnode != nullptr, lite::RET_ERROR);
323     manager->SetEdge(post_cnode, GetInputNodeIndex(anf_node, post_cnode) + kSizeOne, tensor_to_scalar);
324   }
325 
326   return lite::RET_OK;
327 }
328 
ReplaceMakeTuple(const FuncGraphPtr & func_graph,const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)329 STATUS ScalarOpPass::ReplaceMakeTuple(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node,
330                                       const FuncGraphManagerPtr &manager) {
331   auto make_tuple_cnode = anf_node->cast<CNodePtr>();
332   if (!utils::isa<mindspore::abstract::AbstractScalarPtr>(make_tuple_cnode->input(kIndexOne)->abstract())) {
333     return lite::RET_NO_CHANGE;
334   }
335 
336   abstract::BaseShapePtrList tuple_shape_list;
337   TypePtrList tuple_type_list;
338   for (size_t i = kIndexOne; i < make_tuple_cnode->size(); i++) {
339     auto make_tuple_input = make_tuple_cnode->input(i);
340 
341     // Parse abstract shape for modified MakeTuple
342     ShapeVector scalar_shape = {1};
343     tuple_shape_list.push_back(std::make_shared<abstract::Shape>(scalar_shape));
344 
345     // Parse abstract type for modified MakeTuple
346     TypeId orig_type_id;
347     auto ret = GetDataTypeFromAnfNode(make_tuple_input, &orig_type_id);
348     MS_CHECK_TRUE_MSG(ret != RET_ERROR, lite::RET_ERROR, "get datatype from MakeTuple input failed.");
349 
350     // Insert ScalarToTensor before MakeTuple
351     if (!make_tuple_input->isa<ValueNode>()) {
352       auto node = GenerateScalarToTensor(func_graph, anf_node, i);
353       MS_CHECK_TRUE_MSG(node != nullptr, lite::RET_ERROR, "generate ScalarToTensor node failed.");
354 
355       // Parse abstract type for modified MakeTuple
356       ret = GetDataTypeFromAnfNode(node, &orig_type_id);
357       MS_CHECK_TRUE_MSG(ret != RET_ERROR, lite::RET_ERROR, "get datatype from MakeTuple input failed.");
358       tuple_type_list.push_back(TypeIdToType(orig_type_id));
359 
360       manager->SetEdge(anf_node, i, node);
361     } else {  // For ValueNode the input type is int32
362       auto node = GenerateScalarValueTensor(func_graph, anf_node, i);
363       MS_CHECK_TRUE_MSG(node != nullptr, lite::RET_ERROR, "generate ScalarValueTensor node failed.");
364 
365       // Parse abstract type for modified MakeTuple
366       ret = GetDataTypeFromAnfNode(node, &orig_type_id);
367       MS_CHECK_TRUE_MSG(ret != RET_ERROR, lite::RET_ERROR, "get datatype from MakeTuple input failed.");
368       tuple_type_list.push_back(TypeIdToType(orig_type_id));
369 
370       manager->SetEdge(anf_node, i, node);
371     }
372   }
373 
374   // Apply modified abstract to MakeTuple
375   auto tmp_abstract = abstract::MakeAbstract(std::make_shared<abstract::TupleShape>(tuple_shape_list),
376                                              std::make_shared<Tuple>(tuple_type_list));
377   anf_node->set_abstract(tmp_abstract);
378 
379   // Insert concat after MakeTuple
380   std::vector<AnfNodePtr> concat_input_vec({anf_node});
381   auto concat_node = GenConcatNode(func_graph, concat_input_vec,
382                                    anf_node->cast<CNodePtr>()->fullname_with_scope() + "_concat_make_tuple");
383   auto primitive = GetCNodePrimitive(concat_node);
384   MS_CHECK_TRUE_RET(primitive != nullptr, lite::RET_ERROR);
385   int64_t num_of_inputs = SizeToInt(anf_node->cast<CNodePtr>()->size() - kSizeOne);
386   primitive->set_attr("N", MakeValue<int64_t>(num_of_inputs));
387   primitive->set_attr("inputNums", MakeValue<int64_t>(num_of_inputs));
388 
389   // The first input type is used as the type for concat (need to add type check)
390   TypeId make_tuple_type;
391   if (opt::GetDataTypeFromAnfNode(anf_node, &make_tuple_type) != RET_OK) {
392     MS_LOG(ERROR) << "Failed to get " << anf_node->fullname_with_scope() << " output tensor data type.";
393     return lite::RET_ERROR;
394   }
395   auto concat_abstract = abstract::MakeAbstract(std::make_shared<abstract::Shape>(ShapeVector({num_of_inputs})),
396                                                 TypeIdToType(make_tuple_type));
397   concat_node->set_abstract(concat_abstract);
398 
399   // set MakeTuple users' input to concat
400   auto make_tuple_users = manager->node_users()[anf_node];
401   for (auto &make_tuple_user : make_tuple_users) {
402     auto post_cnode = make_tuple_user.first->cast<CNodePtr>();
403     MS_CHECK_TRUE_MSG(post_cnode != nullptr, lite::RET_ERROR, "MakeTuple user is null.");
404     manager->SetEdge(post_cnode, GetInputNodeIndex(anf_node, post_cnode) + kSizeOne, concat_node);
405   }
406 
407   return lite::RET_OK;
408 }
409 
ReplaceShapeTupleGet(const FuncGraphPtr & func_graph,const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)410 STATUS ScalarOpPass::ReplaceShapeTupleGet(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node,
411                                           const FuncGraphManagerPtr &manager) {
412   auto shape_cnode = anf_node->cast<CNodePtr>();
413 
414   // Replace Shape by TensorShape
415   auto tensor_shape_node = GenerateTensorShape(func_graph, anf_node);
416   MS_CHECK_TRUE_MSG(tensor_shape_node != nullptr, lite::RET_ERROR, "generate TensorShape node failed.");
417   ShapeVector tensor_shape_shape;
418   auto ret = FetchShapeFromAbstract(tensor_shape_node->abstract(), &tensor_shape_shape);
419   MS_CHECK_TRUE_MSG(ret != RET_ERROR, lite::RET_ERROR, "fetch shape from TensorShape node failed.");
420   auto cast_abstract =
421     abstract::MakeAbstract(std::make_shared<abstract::Shape>(tensor_shape_shape), TypeIdToType(kNumberTypeInt32));
422 
423   CNodePtr cast_node = nullptr;
424   auto shape_users = manager->node_users()[shape_cnode];
425   for (auto &shape_user : shape_users) {
426     auto tuple_get_node = shape_user.first->cast<CNodePtr>();
427     if (CheckPrimitiveType(tuple_get_node, prim::kPrimTupleGetItem)) {
428       if (cast_node == nullptr) {
429         cast_node =
430           GenCastNode(func_graph, tensor_shape_node, tensor_shape_node->fullname_with_scope() + "_cast_tensorshape",
431                       kNumberTypeInt32, cast_abstract);
432       }
433       auto strided_slice_node = GenerateStridedSlice(func_graph, cast_node, tuple_get_node, manager);
434       MS_CHECK_TRUE_MSG(strided_slice_node != nullptr, lite::RET_ERROR, "generate StridedSlice node failed.");
435 
436       CNodePtr tensor_to_scalar = GenerateTensorToScalar(func_graph, strided_slice_node, true);
437       MS_CHECK_TRUE_MSG(tensor_to_scalar != nullptr, lite::RET_ERROR, "generate TensorToScalar node failed.");
438 
439       auto tuple_get_users = manager->node_users()[tuple_get_node];
440       for (auto &tuple_get_user : tuple_get_users) {
441         auto post_cnode = tuple_get_user.first->cast<CNodePtr>();
442         MS_CHECK_TRUE_MSG(post_cnode != nullptr, lite::RET_ERROR, "TupleGetItem user is null.");
443         manager->SetEdge(post_cnode, GetInputNodeIndex(tuple_get_node, post_cnode) + kSizeOne, tensor_to_scalar);
444       }
445     }
446   }
447 
448   return lite::RET_OK;
449 }
450 
RemoveTensorToScalar(const FuncGraphPtr & func_graph,const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)451 STATUS ScalarOpPass::RemoveTensorToScalar(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node,
452                                           const FuncGraphManagerPtr &manager) {
453   auto tensor_to_scalar_cnode = anf_node->cast<CNodePtr>();
454   auto tensor_to_scalar_users = manager->node_users()[tensor_to_scalar_cnode];
455   auto parent_node = tensor_to_scalar_cnode->input(kIndexOne);
456   for (auto &tensor_to_scalar_user : tensor_to_scalar_users) {
457     auto user_node = tensor_to_scalar_user.first->cast<CNodePtr>();
458     if (CheckPrimitiveType(user_node, prim::kPrimScalarToTensor) || CheckPrimitiveType(user_node, prim::kPrimCast)) {
459       auto child_node_users = manager->node_users()[user_node];
460       for (auto &child_node_user : child_node_users) {
461         auto child_node = child_node_user.first->cast<CNodePtr>();
462         manager->SetEdge(child_node, GetInputNodeIndex(user_node, child_node) + kSizeOne, parent_node);
463       }
464     } else {
465       std::string prim_name = "";
466       (void)GetPrimitiveType(user_node, &prim_name);
467       MS_LOG(ERROR) << "Cannot handle primitive " << prim_name << " after TensorToScalar, please check graph.";
468       return lite::RET_ERROR;
469     }
470   }
471   return lite::RET_OK;
472 }
473 
RunScalarOpPass(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager)474 STATUS ScalarOpPass::RunScalarOpPass(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) {
475   auto node_list = TopoSort(func_graph->get_return());
476   STATUS status = lite::RET_NO_CHANGE;
477   for (auto &node : node_list) {
478     if (!utils::isa<CNodePtr>(node)) {
479       continue;
480     }
481     // First replace all Scalar ops to tensor equivalents
482     if (CheckPrimitiveType(node, prim::kPrimScalarMul)) {
483       status = this->ReplaceScalarOp(func_graph, node, manager, prim::kPrimMul);
484     } else if (CheckPrimitiveType(node, prim::kPrimScalarDiv)) {
485       status = this->ReplaceScalarOp(func_graph, node, manager, prim::kPrimRealDiv);
486     } else if (CheckPrimitiveType(node, prim::kPrimScalarFloorDiv)) {
487       status = this->ReplaceScalarOp(func_graph, node, manager, prim::kPrimFloorDiv);
488     } else if (CheckPrimitiveType(node, prim::kPrimScalarSub)) {
489       status = this->ReplaceScalarOp(func_graph, node, manager, prim::kPrimSub);
490     } else if (CheckPrimitiveType(node, prim::kPrimScalarAdd)) {
491       status = this->ReplaceScalarOp(func_graph, node, manager, prim::kPrimAdd);
492     } else if (CheckPrimitiveType(node, prim::kPrimScalarCast)) {
493       MS_LOG(ERROR) << "For models with dynamic input shapes, ScalarCast node conversion has not been supported yet, "
494                        "please check cast operations such as \"int(some_var)\" in the front-end code and remove them.";
495       status = lite::RET_NOT_SUPPORT;
496     }
497 
498     if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
499       MS_LOG(ERROR) << "Failed to run scalar op pass at cnode: " << node->fullname_with_scope();
500       return lite::RET_ERROR;
501     }
502   }
503   return status;
504 }
505 
RunMakeTuplePass(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager)506 STATUS ScalarOpPass::RunMakeTuplePass(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) {
507   auto node_list = TopoSort(func_graph->get_return());
508   auto status = lite::RET_OK;
509   for (auto &node : node_list) {
510     if (!utils::isa<CNodePtr>(node)) {
511       continue;
512     }
513     // Then change MakeTuple's input from a tuple of scalars to a tuple of tensors
514     if (CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
515       status = this->ReplaceMakeTuple(func_graph, node, manager);
516     }
517 
518     if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
519       MS_LOG(ERROR) << "Failed to run make tuple pass at cnode: " << node->fullname_with_scope();
520       return lite::RET_ERROR;
521     }
522   }
523   return lite::RET_OK;
524 }
525 
RunShapeTupleGetPass(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager)526 STATUS ScalarOpPass::RunShapeTupleGetPass(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) {
527   auto node_list = TopoSort(func_graph->get_return());
528   auto status = lite::RET_OK;
529   for (auto &node : node_list) {
530     if (!utils::isa<CNodePtr>(node)) {
531       continue;
532     }
533     // Replace Shape+TupleGetItem to TensorShape+StridedSlice
534     if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
535       auto tuple_get_input = node->cast<CNodePtr>()->input(kIndexOne);
536       if (CheckPrimitiveType(tuple_get_input, prim::kPrimShape)) {
537         MS_LOG(INFO) << "Start processing Shape + TupleGetItem pass.";
538         status = this->ReplaceShapeTupleGet(func_graph, tuple_get_input, manager);
539       }
540     }
541 
542     if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
543       MS_LOG(ERROR) << "Failed to run shape tuple get pass at cnode: " << node->fullname_with_scope();
544       return lite::RET_ERROR;
545     }
546   }
547   return lite::RET_OK;
548 }
549 
RunRemoveTensorToScalarPass(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager)550 STATUS ScalarOpPass::RunRemoveTensorToScalarPass(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) {
551   auto node_list = TopoSort(func_graph->get_return());
552   auto status = lite::RET_OK;
553   for (auto &node : node_list) {
554     if (!utils::isa<CNodePtr>(node)) {
555       continue;
556     }
557     // Remove TensorToScalar + ScalarToTensor
558     // Remove TensorToScalar + Cast
559     if (CheckPrimitiveType(node, prim::kPrimTensorToScalar)) {
560       MS_LOG(DEBUG) << "Found TensorToScalar, start removing...";
561       status = this->RemoveTensorToScalar(func_graph, node, manager);
562     }
563 
564     if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
565       MS_LOG(ERROR) << "Failed to run remove TensorToScalar pass at cnode: " << node->fullname_with_scope();
566       return lite::RET_ERROR;
567     }
568   }
569   return lite::RET_OK;
570 }
571 
572 /*
573 This pass checks the arithmetic ops have correct infer types when all TensorToScalar/ScalarToTensor ops are removed. If
574 datatypes do not agree, insert cast op.
575 */
RunArithmeticCheckPass(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager)576 STATUS ScalarOpPass::RunArithmeticCheckPass(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) {
577   auto node_list = TopoSort(func_graph->get_return());
578   for (auto &node : node_list) {
579     if (!utils::isa<CNodePtr>(node)) {
580       continue;
581     }
582     // Check arithmetic op infer type, insert cast op if two inputs do not agree
583     if (CheckPrimitiveType(node, prim::kPrimMul) || CheckPrimitiveType(node, prim::kPrimDiv) ||
584         CheckPrimitiveType(node, prim::kPrimFloorDiv) || CheckPrimitiveType(node, prim::kPrimRealDiv) ||
585         CheckPrimitiveType(node, prim::kPrimSub)) {
586       auto first_input = node->cast<CNodePtr>()->input(kIndexOne);
587       auto second_input = node->cast<CNodePtr>()->input(kIndexTwo);
588 
589       TypeId first_data_type;
590       if (opt::GetDataTypeFromAnfNode(first_input, &first_data_type) != RET_OK) {
591         MS_LOG(ERROR) << "Failed to get arithmetic op first input tensor data type.";
592         return lite::RET_ERROR;
593       }
594       TypeId second_data_type;
595       if (opt::GetDataTypeFromAnfNode(second_input, &second_data_type) != RET_OK) {
596         MS_LOG(ERROR) << "Failed to get arithmetic op second input tensor data type.";
597         return lite::RET_ERROR;
598       }
599       if (first_data_type == second_data_type) {
600         continue;
601       }
602 
603       // Insert cast node before second input, and set infer type the same as the first input
604       auto cast_data_type = first_data_type;
605       ShapeVector cast_shape;
606       if (FetchShapeFromAbstract(second_input->abstract(), &cast_shape) != lite::RET_OK) {
607         MS_LOG(ERROR) << "Fetch shape from second input abstract failed!";
608         return lite::RET_ERROR;
609       }
610       auto new_cast_abstract =
611         abstract::MakeAbstract(std::make_shared<abstract::Shape>(cast_shape), TypeIdToType(cast_data_type));
612       auto new_cast_node =
613         GenCastNode(func_graph, second_input, second_input->fullname_with_scope() + "cast_after_second_in",
614                     cast_data_type, new_cast_abstract);
615       new_cast_node->set_abstract(new_cast_abstract);
616       manager->SetEdge(node, kIndexTwo, new_cast_node);
617     }
618   }
619   return lite::RET_OK;
620 }
621 
Run(const FuncGraphPtr & func_graph)622 bool ScalarOpPass::Run(const FuncGraphPtr &func_graph) {
623   MS_ASSERT(func_graph != nullptr);
624   auto manager = func_graph->manager();
625   MS_CHECK_TRUE_RET(manager != nullptr, false);
626   auto status = RunShapeTupleGetPass(func_graph, manager);
627   MS_CHECK_TRUE_RET(status != lite::RET_ERROR, false);
628   auto scalar_replace_status = RunScalarOpPass(func_graph, manager);
629   MS_CHECK_TRUE_RET(status != lite::RET_ERROR, false);
630   status = RunMakeTuplePass(func_graph, manager);
631   MS_CHECK_TRUE_RET(status != lite::RET_ERROR, false);
632   status = RunRemoveTensorToScalarPass(func_graph, manager);
633   MS_CHECK_TRUE_RET(status != lite::RET_ERROR, false);
634   if (scalar_replace_status != lite::RET_NO_CHANGE) {
635     status = RunArithmeticCheckPass(func_graph, manager);
636     MS_CHECK_TRUE_RET(status != lite::RET_ERROR, false);
637   }
638 
639   return true;
640 }
641 }  // namespace mindspore::opt
642