1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <memory>
17 #include <vector>
18 #include <string>
19 #include "ops/auto_generate/gen_lite_ops.h"
20 #include "ops/auto_generate/gen_ops_primitive.h"
21 #include "ops/array_ops.h"
22 #include "ops/lite_ops.h"
23 #include "tools/optimizer/graph/scalar_op_pass.h"
24 #include "tools/optimizer/common/gllo_utils.h"
25 #include "mindspore/core/ops/arithmetic_ops.h"
26 #include "tools/optimizer/graph/lite_tensor_extractor.h"
27 #include "mindspore/core/abstract/ops/primitive_infer_map.h"
28 #include "mindspore/core/utils/anf_utils.h"
29 #include "mindspore/core/ops/math_ops.h"
30 #include "mindspore/core/ops/sequence_ops.h"
31
32 /* This pass changes the following pattern(s).
33
34 1. Getting shape from a dynamic tensor, and get item from the shape tuple.
35 ###############
36 Pattern:
37 Shape -> TupleGetItem -> Scalar
38
39 Replace:
40 TensorShape -> Cast(int32) -> StridedSlice -> TensorToScalar -> Scalar
41 ###############
42
43 The Shape op will be replaced by TensorShape op in order to get a Tensor, and then casted to int32 dtype for Ascend
44 dynamic shape mbatch calculation. Followed by StridedSlice, which is a Tensor equivalent to tuple's TupleGetItem.
And finally cast back to a Scalar by using the TensorToScalar op, which makes sure the old and new patterns'
output types agree.
47
48 2. Do scalar arithmetic using scalar from TupleGetItem op.
49 ###############
50 Pattern:
51 Scalar -> ScalarMul/ScalarDiv/…. -> Scalar
52
53 Replace:
54 ScalarToTensor -> Mul/Div/… -> Tensor -> TensorToScalar
55 ###############
56
The ScalarXXX arithmetic ops will be replaced by their Tensor-equivalent arithmetic ops. ScalarToTensor and
TensorToScalar conversion ops are inserted before and after, in order to make sure the old and new patterns'
input/output types agree.
60
61 3. MakeTuple and Reshape operations using Scalars.
62 ###############
63 Pattern:
64 Scalar -> MakeTuple -> Tuple(Scalar) -> Reshape -> Tensor
65
66 Replace:
67 ScalarToTensor -> MakeTuple -> Tuple(Tensor) -> Concat -> Reshape -> Tensor
68 ###############
69
MakeTuple's Scalar inputs will be converted to Tensors, followed by a Concat op to allow the reshape by a Tensor.
A ScalarToTensor conversion op is inserted beforehand, to make sure the old and new patterns' input/output types agree.
72
73 4. ScalarToTensor and TensorToScalar ops are temporary placeholders. The last step is to remove them.
74 ###############
75 Pattern:
76 TensorToScalar -> ScalarToTensor
77 TensorToScalar -> Cast -> Tensor
78
79 Replace:
80 remove both, the Tensor are connected.
81 ###############
82 */
83 namespace mindspore::opt {
84 /*
85 This function returns the index of the input node, which is used by the user node.
86 */
GetInputNodeIndex(const AnfNodePtr & input,const CNodePtr & user_node)87 size_t ScalarOpPass::GetInputNodeIndex(const AnfNodePtr &input, const CNodePtr &user_node) {
88 MS_EXCEPTION_IF_NULL(input);
89 MS_EXCEPTION_IF_NULL(user_node);
90
91 AnfNodePtrList input_list = user_node->inputs();
92 auto pos = std::find(input_list.begin(), input_list.end(), input);
93 if (pos == input_list.end()) {
94 MS_LOG(EXCEPTION) << input->fullname_with_scope() << " is not the input of " << user_node->fullname_with_scope();
95 }
96
97 // The first input is Primitive and needs to be skipped.
98 return std::distance(input_list.begin() + kSizeOne, pos);
99 }
100
101 /*
102 Create a Tensor with type scalar. This pass assumes that the scalar is from TensorShape, which will be integers.
103 */
GenerateScalarValueTensor(const FuncGraphPtr & func_graph,const AnfNodePtr & anf_node,int input_index)104 ValueNodePtr ScalarOpPass::GenerateScalarValueTensor(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node,
105 int input_index) {
106 lite::DataInfo data_info;
107 auto ret = lite::FetchConstData(anf_node->cast<CNodePtr>(), input_index, converter::kFmkTypeMs, &data_info, false);
108 MS_CHECK_TRUE_RET(ret == lite::RET_OK, nullptr);
109 if (data_info.data_type_ != kNumberTypeInt32 && data_info.data_type_ != kNumberTypeInt64) {
110 MS_LOG(ERROR) << "Unsupported scalar data type: " << data_info.data_type_ << ", need to add support.";
111 return nullptr;
112 }
113 int32_t scalar_value = *reinterpret_cast<int32_t *>(data_info.data_.data());
114 ShapeVector const_data_shape = {1};
115 tensor::TensorPtr const_data_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, const_data_shape);
116 auto *val = static_cast<int32_t *>(const_data_tensor->data_c());
117 *val = scalar_value;
118 auto const_value_node = NewValueNode(const_data_tensor);
119 const_value_node->set_abstract(const_data_tensor->ToAbstract());
120 func_graph->AddValueNode(const_value_node);
121 return const_value_node;
122 }
123
GenerateScalarToTensor(const FuncGraphPtr & func_graph,const AnfNodePtr & anf_node,int input_index)124 CNodePtr ScalarOpPass::GenerateScalarToTensor(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node,
125 int input_index) {
126 auto scalar_cnode = anf_node->cast<CNodePtr>();
127 auto scalar_input = scalar_cnode->input(input_index);
128 // Data type of the tensor should be set as an attr of ScalarToTensor op.
129 TypeId data_type;
130 if (opt::GetDataTypeFromAnfNode(scalar_cnode->input(input_index), &data_type) != RET_OK) {
131 MS_LOG(ERROR) << "Failed to get " << anf_node->fullname_with_scope() << " output tensor data type.";
132 return nullptr;
133 }
134 auto type_id_value_node = NewValueNode(MakeValue(static_cast<int64_t>(data_type)));
135 auto type_id_value = std::make_shared<Int64Imm>(static_cast<int64_t>(data_type));
136 type_id_value_node->set_abstract(type_id_value->ToAbstract());
137 auto prim = NewValueNode(std::make_shared<Primitive>(kScalarToTensorOpName));
138 MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
139 AnfNodePtrList inputs = {prim, scalar_input, type_id_value_node};
140 CNodePtr scalar_to_tensor = func_graph->NewCNode(inputs);
141 MS_CHECK_TRUE_RET(scalar_to_tensor != nullptr, nullptr);
142 auto primitive = GetCNodePrimitive(scalar_to_tensor);
143 MS_CHECK_TRUE_RET(primitive != nullptr, nullptr);
144 // set abstract
145 ShapeVector tensor_shape = {1};
146 auto tensor_shape_ptr = std::make_shared<abstract::Shape>(tensor_shape);
147 MS_CHECK_TRUE_MSG(tensor_shape_ptr != nullptr, nullptr, "tensor_shape_ptr is nullptr.");
148 auto tmp_abstract = abstract::MakeAbstract(std::make_shared<abstract::Shape>(tensor_shape), TypeIdToType(data_type));
149 MS_CHECK_TRUE_MSG(tmp_abstract != nullptr, nullptr, "make AbstractTensor failed");
150 scalar_to_tensor->set_abstract(tmp_abstract);
151 return scalar_to_tensor;
152 }
153
GenerateTensorToScalar(const FuncGraphPtr & func_graph,const AnfNodePtr & anf_node,bool is_curr_node)154 CNodePtr ScalarOpPass::GenerateTensorToScalar(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node,
155 bool is_curr_node) {
156 auto prim = NewValueNode(std::make_shared<Primitive>(kTensorToScalarOpName));
157 MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
158 auto input_cnode = anf_node->cast<CNodePtr>();
159 AnfNodePtrList inputs = {prim, input_cnode->input(kIndexOne)};
160 if (is_curr_node) { // insert TensorToScalar after current anf_node
161 inputs = {prim, anf_node};
162 }
163 CNodePtr tensor_to_scalar = func_graph->NewCNode(inputs);
164 MS_CHECK_TRUE_RET(tensor_to_scalar != nullptr, nullptr);
165
166 // set abstract
167 TypeId type_id;
168 (void)GetDataTypeFromAnfNode(anf_node, &type_id);
169
170 auto tmp_abstract = std::make_shared<abstract::AbstractScalar>(kValueAny, TypeIdToType(type_id));
171 MS_CHECK_TRUE_MSG(tmp_abstract != nullptr, nullptr, "make AbstractScalar failed");
172 tensor_to_scalar->set_abstract(tmp_abstract);
173 return tensor_to_scalar;
174 }
175
CNodePtr ScalarOpPass::GenerateTensorShape(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node) {
  // Build a TensorShape cnode that consumes the same input tensor as the original Shape cnode
  // (anf_node), so the shape is produced as a 1-D int64 Tensor rather than a tuple of scalars.
  auto prim = NewValueNode(std::make_shared<Primitive>(kTensorShapeOpName));
  MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
  auto shape_cnode = anf_node->cast<CNodePtr>();
  AnfNodePtrList inputs = {prim, shape_cnode->input(kIndexOne)};

  CNodePtr tensor_shape_node = func_graph->NewCNode(inputs);
  MS_CHECK_TRUE_MSG(tensor_shape_node != nullptr, nullptr, "tensor_shape_node is nullptr.");

  abstract::AbstractBasePtr tmp_abstract;
  auto shape_input_abs = shape_cnode->input(kIndexOne)->abstract()->cast<abstract::AbstractTensorPtr>();
  MS_CHECK_TRUE_MSG(shape_input_abs != nullptr, nullptr, "shape input abstract is not AbstractTensor.");
  auto shape = shape_input_abs->shape()->shape();
  // Output is a 1-D tensor whose single dimension equals the input tensor's rank.
  ShapeVector tensor_shp({static_cast<int64_t>(shape.size())});
  if (IsDynamic(shape)) {
    if (IsDynamicRank(shape)) {
      // Rank itself is unknown: the output tensor's length is also unknown.
      tmp_abstract = abstract::MakeAbstract(
        std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeDimAny}), kInt64);
    } else {
      // Rank is known but some dims are dynamic: length is the rank, element values unknown (int64).
      auto elem = std::make_shared<abstract::AbstractScalar>(std::make_shared<ValueAny>(), std::make_shared<Int>(64));
      auto abs_tensor = std::make_shared<abstract::AbstractTensor>(elem, std::make_shared<abstract::Shape>(tensor_shp));
      tmp_abstract = abs_tensor;
    }
  } else {
    // Fully static shape: bake the concrete shape values into a constant tensor abstract.
    auto shp_buf_size = sizeof(int64_t) * shape.size();
    auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, tensor_shp, shape.data(), shp_buf_size);
    tmp_abstract = tensor->ToAbstract();
  }

  // set abstract
  tensor_shape_node->set_fullname_with_scope(anf_node->fullname_with_scope() + "_tensorshape");
  tensor_shape_node->set_abstract(tmp_abstract);
  return tensor_shape_node;
}
210
211 /*
212 Create a ValueNode with single scalar as input.
213 */
GenerateScalarValueTuple(const FuncGraphPtr & func_graph,int64_t value)214 ValueNodePtr ScalarOpPass::GenerateScalarValueTuple(const FuncGraphPtr &func_graph, int64_t value) {
215 std::vector<int64_t> vec({value});
216 auto tuple_value = MakeValue(vec);
217 auto tuple_node = NewValueNode(tuple_value);
218 tuple_node->set_abstract(tuple_value->ToAbstract());
219 func_graph->AddValueNode(tuple_node);
220 return tuple_node;
221 }
222
GenerateScalarValue(const FuncGraphPtr & func_graph,int64_t value)223 ValueNodePtr ScalarOpPass::GenerateScalarValue(const FuncGraphPtr &func_graph, int64_t value) {
224 auto scalar_value = MakeValue(value);
225 auto scalar_node = NewValueNode(scalar_value);
226 scalar_node->set_abstract(scalar_value->ToAbstract());
227 func_graph->AddValueNode(scalar_node);
228 return scalar_node;
229 }
230
CNodePtr ScalarOpPass::GenerateStridedSlice(const FuncGraphPtr &func_graph, const AnfNodePtr &shape_node,
                                            const AnfNodePtr &tuple_get_node, const FuncGraphManagerPtr &manager) {
  // Build a StridedSlice that extracts, from the 1-D shape tensor (shape_node), the single
  // element that tuple_get_node picked out of the former shape tuple. The output is a
  // 1-element tensor with the same dtype as shape_node.
  // NOTE(review): `manager` is unused here — presumably kept for signature symmetry; confirm.
  auto begin_index = GetTupleGetItemOutIndex(tuple_get_node->cast<CNodePtr>());
  MS_CHECK_TRUE_MSG(begin_index >= 0, nullptr, "begin index is less than zero.");

  // set inputs: slice the half-open range [begin, begin + 1) with stride 1.
  auto begin_node = GenerateScalarValueTuple(func_graph, begin_index);
  MS_CHECK_TRUE_MSG(begin_node != nullptr, nullptr, "generate StridedSlice begin node failed.");
  auto end_node = GenerateScalarValueTuple(func_graph, begin_index + kSizeOne);
  MS_CHECK_TRUE_MSG(end_node != nullptr, nullptr, "generate StridedSlice end node failed.");
  auto strides_node = GenerateScalarValueTuple(func_graph, kSizeOne);
  MS_CHECK_TRUE_MSG(strides_node != nullptr, nullptr, "generate StridedSlice strides node failed.");

  // set abstract: a {1}-shaped tensor of the shape tensor's dtype.
  ShapeVector tensor_shape = {1};
  auto tensor_shape_ptr = std::make_shared<abstract::Shape>(tensor_shape);
  MS_CHECK_TRUE_MSG(tensor_shape_ptr != nullptr, nullptr, "tensor_shape_ptr is nullptr.");
  TypeId infer_type;
  auto ret = GetDataTypeFromAnfNode(shape_node, &infer_type);
  MS_CHECK_TRUE_MSG(ret == RET_OK, nullptr, "get data_type from node failed.");

  auto tmp_abstract = abstract::MakeAbstract(std::make_shared<abstract::Shape>(tensor_shape), TypeIdToType(infer_type));
  MS_CHECK_TRUE_MSG(tmp_abstract != nullptr, nullptr, "make AbstractTensor failed");

  // All five mask inputs are zero: plain slicing, no ellipsis/new-axis/shrink behavior.
  auto begin_mask = GenerateScalarValue(func_graph, 0);
  MS_CHECK_TRUE_MSG(begin_mask != nullptr, nullptr, "generate StridedSlice begin_mask node failed.");
  auto end_mask = GenerateScalarValue(func_graph, 0);
  MS_CHECK_TRUE_MSG(end_mask != nullptr, nullptr, "generate StridedSlice end_mask node failed.");
  auto ellipsis_mask = GenerateScalarValue(func_graph, 0);
  MS_CHECK_TRUE_MSG(ellipsis_mask != nullptr, nullptr, "generate StridedSlice ellipsis_mask node failed.");
  auto new_axis_mask = GenerateScalarValue(func_graph, 0);
  MS_CHECK_TRUE_MSG(new_axis_mask != nullptr, nullptr, "generate StridedSlice new_axis_mask node failed.");
  auto shrink_axis_mask = GenerateScalarValue(func_graph, 0);
  MS_CHECK_TRUE_MSG(shrink_axis_mask != nullptr, nullptr, "generate StridedSlice shrink_axis_mask node failed.");

  auto prim = NewValueNode(std::make_shared<Primitive>(kStridedSliceOpName));
  MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
  AnfNodePtrList inputs = {prim, shape_node, begin_node, end_node, strides_node,
                           begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask};
  CNodePtr strided_slice = func_graph->NewCNode(inputs);
  MS_CHECK_TRUE_RET(strided_slice != nullptr, nullptr);
  strided_slice->set_fullname_with_scope(tuple_get_node->fullname_with_scope() + "_strided_slice");
  strided_slice->set_abstract(tmp_abstract);

  // set attrs, all defaults to zero
  auto primitive = GetCNodePrimitive(strided_slice);
  MS_CHECK_TRUE_RET(primitive != nullptr, nullptr);
  return strided_slice;
}
280
ReplaceScalarOp(const FuncGraphPtr & func_graph,const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager,const PrimitivePtr & replace_op_prim)281 STATUS ScalarOpPass::ReplaceScalarOp(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node,
282 const FuncGraphManagerPtr &manager, const PrimitivePtr &replace_op_prim) {
283 auto replace_op_prim_node = NewValueNode(replace_op_prim);
284 MS_CHECK_TRUE_RET(replace_op_prim_node != nullptr, lite::RET_ERROR);
285 AnfNodePtrList replace_op_inputs = {replace_op_prim_node};
286
287 auto scalar_cnode = anf_node->cast<CNodePtr>();
288 std::vector<mindspore::AnfNodePtr> scalar_inputs = {};
289 scalar_inputs.push_back(scalar_cnode->input(kIndexOne));
290 scalar_inputs.push_back(scalar_cnode->input(kIndexTwo));
291 for (size_t i = 0; i < scalar_inputs.size(); i++) {
292 if (!scalar_inputs[i]->isa<ValueNode>()) {
293 auto node = GenerateScalarToTensor(func_graph, anf_node, i + kSizeOne);
294 MS_CHECK_TRUE_RET(node != nullptr, lite::RET_ERROR);
295 replace_op_inputs.push_back(node);
296 } else {
297 auto node = GenerateScalarValueTensor(func_graph, anf_node, i + kSizeOne);
298 MS_CHECK_TRUE_RET(node != nullptr, lite::RET_ERROR);
299 replace_op_inputs.push_back(node);
300 }
301 }
302 CNodePtr replace_op = func_graph->NewCNode(replace_op_inputs);
303 MS_CHECK_TRUE_RET(replace_op != nullptr, lite::RET_ERROR);
304
305 ShapeVector replace_op_shape = {1};
306 auto replace_op_shape_ptr = std::make_shared<abstract::Shape>(replace_op_shape);
307 MS_CHECK_TRUE_MSG(replace_op_shape_ptr != nullptr, RET_ERROR, "replace op is nullptr.");
308
309 // Replace op has the same type as the first input
310 auto abstract = replace_op->input(kIndexOne)->abstract();
311 auto tmp_abstract = abstract->Clone();
312 tmp_abstract->set_shape(replace_op_shape_ptr);
313 replace_op->set_abstract(tmp_abstract);
314
315 CNodePtr tensor_to_scalar = GenerateTensorToScalar(func_graph, replace_op, true);
316
317 // Set input of the Scalar op users to tensor_to_scalar
318 auto orig_scalar_op_cnode = anf_node->cast<CNodePtr>();
319 auto node_users = manager->node_users()[orig_scalar_op_cnode];
320 for (auto &node_user : node_users) {
321 auto post_cnode = node_user.first->cast<CNodePtr>();
322 MS_CHECK_TRUE_RET(post_cnode != nullptr, lite::RET_ERROR);
323 manager->SetEdge(post_cnode, GetInputNodeIndex(anf_node, post_cnode) + kSizeOne, tensor_to_scalar);
324 }
325
326 return lite::RET_OK;
327 }
328
STATUS ScalarOpPass::ReplaceMakeTuple(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node,
                                      const FuncGraphManagerPtr &manager) {
  // Convert MakeTuple(scalar, ...) into MakeTuple(tensor, ...) followed by a Concat node, so
  // downstream ops (e.g. Reshape) receive a single 1-D tensor. Users of the MakeTuple are
  // re-wired to the Concat output. Returns RET_NO_CHANGE when the first input is not a scalar.
  auto make_tuple_cnode = anf_node->cast<CNodePtr>();
  if (!utils::isa<mindspore::abstract::AbstractScalarPtr>(make_tuple_cnode->input(kIndexOne)->abstract())) {
    return lite::RET_NO_CHANGE;
  }

  abstract::BaseShapePtrList tuple_shape_list;
  TypePtrList tuple_type_list;
  for (size_t i = kIndexOne; i < make_tuple_cnode->size(); i++) {
    auto make_tuple_input = make_tuple_cnode->input(i);

    // Parse abstract shape for modified MakeTuple: every scalar becomes a {1}-shaped tensor.
    ShapeVector scalar_shape = {1};
    tuple_shape_list.push_back(std::make_shared<abstract::Shape>(scalar_shape));

    // Parse abstract type for modified MakeTuple
    TypeId orig_type_id;
    auto ret = GetDataTypeFromAnfNode(make_tuple_input, &orig_type_id);
    MS_CHECK_TRUE_MSG(ret != RET_ERROR, lite::RET_ERROR, "get datatype from MakeTuple input failed.");

    // Insert ScalarToTensor before MakeTuple
    if (!make_tuple_input->isa<ValueNode>()) {
      // Dynamic scalar input: converted at runtime.
      auto node = GenerateScalarToTensor(func_graph, anf_node, i);
      MS_CHECK_TRUE_MSG(node != nullptr, lite::RET_ERROR, "generate ScalarToTensor node failed.");

      // Re-read the type from the generated node (the conversion determines the tensor dtype).
      ret = GetDataTypeFromAnfNode(node, &orig_type_id);
      MS_CHECK_TRUE_MSG(ret != RET_ERROR, lite::RET_ERROR, "get datatype from MakeTuple input failed.");
      tuple_type_list.push_back(TypeIdToType(orig_type_id));

      manager->SetEdge(anf_node, i, node);
    } else {  // For ValueNode the input type is int32
      // Constant scalar input: folded into a const tensor (int32, see GenerateScalarValueTensor).
      auto node = GenerateScalarValueTensor(func_graph, anf_node, i);
      MS_CHECK_TRUE_MSG(node != nullptr, lite::RET_ERROR, "generate ScalarValueTensor node failed.");

      // Re-read the type from the generated const tensor node.
      ret = GetDataTypeFromAnfNode(node, &orig_type_id);
      MS_CHECK_TRUE_MSG(ret != RET_ERROR, lite::RET_ERROR, "get datatype from MakeTuple input failed.");
      tuple_type_list.push_back(TypeIdToType(orig_type_id));

      manager->SetEdge(anf_node, i, node);
    }
  }

  // Apply modified abstract to MakeTuple
  auto tmp_abstract = abstract::MakeAbstract(std::make_shared<abstract::TupleShape>(tuple_shape_list),
                                             std::make_shared<Tuple>(tuple_type_list));
  anf_node->set_abstract(tmp_abstract);

  // Insert concat after MakeTuple
  std::vector<AnfNodePtr> concat_input_vec({anf_node});
  auto concat_node = GenConcatNode(func_graph, concat_input_vec,
                                   anf_node->cast<CNodePtr>()->fullname_with_scope() + "_concat_make_tuple");
  auto primitive = GetCNodePrimitive(concat_node);
  MS_CHECK_TRUE_RET(primitive != nullptr, lite::RET_ERROR);
  // N / inputNums record how many tensors the concat consumes (MakeTuple inputs minus the primitive).
  int64_t num_of_inputs = SizeToInt(anf_node->cast<CNodePtr>()->size() - kSizeOne);
  primitive->set_attr("N", MakeValue<int64_t>(num_of_inputs));
  primitive->set_attr("inputNums", MakeValue<int64_t>(num_of_inputs));

  // The first input type is used as the type for concat (need to add type check)
  TypeId make_tuple_type;
  if (opt::GetDataTypeFromAnfNode(anf_node, &make_tuple_type) != RET_OK) {
    MS_LOG(ERROR) << "Failed to get " << anf_node->fullname_with_scope() << " output tensor data type.";
    return lite::RET_ERROR;
  }
  auto concat_abstract = abstract::MakeAbstract(std::make_shared<abstract::Shape>(ShapeVector({num_of_inputs})),
                                                TypeIdToType(make_tuple_type));
  concat_node->set_abstract(concat_abstract);

  // set MakeTuple users' input to concat
  auto make_tuple_users = manager->node_users()[anf_node];
  for (auto &make_tuple_user : make_tuple_users) {
    auto post_cnode = make_tuple_user.first->cast<CNodePtr>();
    MS_CHECK_TRUE_MSG(post_cnode != nullptr, lite::RET_ERROR, "MakeTuple user is null.");
    manager->SetEdge(post_cnode, GetInputNodeIndex(anf_node, post_cnode) + kSizeOne, concat_node);
  }

  return lite::RET_OK;
}
409
ReplaceShapeTupleGet(const FuncGraphPtr & func_graph,const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)410 STATUS ScalarOpPass::ReplaceShapeTupleGet(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node,
411 const FuncGraphManagerPtr &manager) {
412 auto shape_cnode = anf_node->cast<CNodePtr>();
413
414 // Replace Shape by TensorShape
415 auto tensor_shape_node = GenerateTensorShape(func_graph, anf_node);
416 MS_CHECK_TRUE_MSG(tensor_shape_node != nullptr, lite::RET_ERROR, "generate TensorShape node failed.");
417 ShapeVector tensor_shape_shape;
418 auto ret = FetchShapeFromAbstract(tensor_shape_node->abstract(), &tensor_shape_shape);
419 MS_CHECK_TRUE_MSG(ret != RET_ERROR, lite::RET_ERROR, "fetch shape from TensorShape node failed.");
420 auto cast_abstract =
421 abstract::MakeAbstract(std::make_shared<abstract::Shape>(tensor_shape_shape), TypeIdToType(kNumberTypeInt32));
422
423 CNodePtr cast_node = nullptr;
424 auto shape_users = manager->node_users()[shape_cnode];
425 for (auto &shape_user : shape_users) {
426 auto tuple_get_node = shape_user.first->cast<CNodePtr>();
427 if (CheckPrimitiveType(tuple_get_node, prim::kPrimTupleGetItem)) {
428 if (cast_node == nullptr) {
429 cast_node =
430 GenCastNode(func_graph, tensor_shape_node, tensor_shape_node->fullname_with_scope() + "_cast_tensorshape",
431 kNumberTypeInt32, cast_abstract);
432 }
433 auto strided_slice_node = GenerateStridedSlice(func_graph, cast_node, tuple_get_node, manager);
434 MS_CHECK_TRUE_MSG(strided_slice_node != nullptr, lite::RET_ERROR, "generate StridedSlice node failed.");
435
436 CNodePtr tensor_to_scalar = GenerateTensorToScalar(func_graph, strided_slice_node, true);
437 MS_CHECK_TRUE_MSG(tensor_to_scalar != nullptr, lite::RET_ERROR, "generate TensorToScalar node failed.");
438
439 auto tuple_get_users = manager->node_users()[tuple_get_node];
440 for (auto &tuple_get_user : tuple_get_users) {
441 auto post_cnode = tuple_get_user.first->cast<CNodePtr>();
442 MS_CHECK_TRUE_MSG(post_cnode != nullptr, lite::RET_ERROR, "TupleGetItem user is null.");
443 manager->SetEdge(post_cnode, GetInputNodeIndex(tuple_get_node, post_cnode) + kSizeOne, tensor_to_scalar);
444 }
445 }
446 }
447
448 return lite::RET_OK;
449 }
450
RemoveTensorToScalar(const FuncGraphPtr & func_graph,const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)451 STATUS ScalarOpPass::RemoveTensorToScalar(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node,
452 const FuncGraphManagerPtr &manager) {
453 auto tensor_to_scalar_cnode = anf_node->cast<CNodePtr>();
454 auto tensor_to_scalar_users = manager->node_users()[tensor_to_scalar_cnode];
455 auto parent_node = tensor_to_scalar_cnode->input(kIndexOne);
456 for (auto &tensor_to_scalar_user : tensor_to_scalar_users) {
457 auto user_node = tensor_to_scalar_user.first->cast<CNodePtr>();
458 if (CheckPrimitiveType(user_node, prim::kPrimScalarToTensor) || CheckPrimitiveType(user_node, prim::kPrimCast)) {
459 auto child_node_users = manager->node_users()[user_node];
460 for (auto &child_node_user : child_node_users) {
461 auto child_node = child_node_user.first->cast<CNodePtr>();
462 manager->SetEdge(child_node, GetInputNodeIndex(user_node, child_node) + kSizeOne, parent_node);
463 }
464 } else {
465 std::string prim_name = "";
466 (void)GetPrimitiveType(user_node, &prim_name);
467 MS_LOG(ERROR) << "Cannot handle primitive " << prim_name << " after TensorToScalar, please check graph.";
468 return lite::RET_ERROR;
469 }
470 }
471 return lite::RET_OK;
472 }
473
RunScalarOpPass(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager)474 STATUS ScalarOpPass::RunScalarOpPass(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) {
475 auto node_list = TopoSort(func_graph->get_return());
476 STATUS status = lite::RET_NO_CHANGE;
477 for (auto &node : node_list) {
478 if (!utils::isa<CNodePtr>(node)) {
479 continue;
480 }
481 // First replace all Scalar ops to tensor equivalents
482 if (CheckPrimitiveType(node, prim::kPrimScalarMul)) {
483 status = this->ReplaceScalarOp(func_graph, node, manager, prim::kPrimMul);
484 } else if (CheckPrimitiveType(node, prim::kPrimScalarDiv)) {
485 status = this->ReplaceScalarOp(func_graph, node, manager, prim::kPrimRealDiv);
486 } else if (CheckPrimitiveType(node, prim::kPrimScalarFloorDiv)) {
487 status = this->ReplaceScalarOp(func_graph, node, manager, prim::kPrimFloorDiv);
488 } else if (CheckPrimitiveType(node, prim::kPrimScalarSub)) {
489 status = this->ReplaceScalarOp(func_graph, node, manager, prim::kPrimSub);
490 } else if (CheckPrimitiveType(node, prim::kPrimScalarAdd)) {
491 status = this->ReplaceScalarOp(func_graph, node, manager, prim::kPrimAdd);
492 } else if (CheckPrimitiveType(node, prim::kPrimScalarCast)) {
493 MS_LOG(ERROR) << "For models with dynamic input shapes, ScalarCast node conversion has not been supported yet, "
494 "please check cast operations such as \"int(some_var)\" in the front-end code and remove them.";
495 status = lite::RET_NOT_SUPPORT;
496 }
497
498 if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
499 MS_LOG(ERROR) << "Failed to run scalar op pass at cnode: " << node->fullname_with_scope();
500 return lite::RET_ERROR;
501 }
502 }
503 return status;
504 }
505
RunMakeTuplePass(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager)506 STATUS ScalarOpPass::RunMakeTuplePass(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) {
507 auto node_list = TopoSort(func_graph->get_return());
508 auto status = lite::RET_OK;
509 for (auto &node : node_list) {
510 if (!utils::isa<CNodePtr>(node)) {
511 continue;
512 }
513 // Then change MakeTuple's input from a tuple of scalars to a tuple of tensors
514 if (CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
515 status = this->ReplaceMakeTuple(func_graph, node, manager);
516 }
517
518 if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
519 MS_LOG(ERROR) << "Failed to run make tuple pass at cnode: " << node->fullname_with_scope();
520 return lite::RET_ERROR;
521 }
522 }
523 return lite::RET_OK;
524 }
525
RunShapeTupleGetPass(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager)526 STATUS ScalarOpPass::RunShapeTupleGetPass(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) {
527 auto node_list = TopoSort(func_graph->get_return());
528 auto status = lite::RET_OK;
529 for (auto &node : node_list) {
530 if (!utils::isa<CNodePtr>(node)) {
531 continue;
532 }
533 // Replace Shape+TupleGetItem to TensorShape+StridedSlice
534 if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
535 auto tuple_get_input = node->cast<CNodePtr>()->input(kIndexOne);
536 if (CheckPrimitiveType(tuple_get_input, prim::kPrimShape)) {
537 MS_LOG(INFO) << "Start processing Shape + TupleGetItem pass.";
538 status = this->ReplaceShapeTupleGet(func_graph, tuple_get_input, manager);
539 }
540 }
541
542 if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
543 MS_LOG(ERROR) << "Failed to run shape tuple get pass at cnode: " << node->fullname_with_scope();
544 return lite::RET_ERROR;
545 }
546 }
547 return lite::RET_OK;
548 }
549
RunRemoveTensorToScalarPass(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager)550 STATUS ScalarOpPass::RunRemoveTensorToScalarPass(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) {
551 auto node_list = TopoSort(func_graph->get_return());
552 auto status = lite::RET_OK;
553 for (auto &node : node_list) {
554 if (!utils::isa<CNodePtr>(node)) {
555 continue;
556 }
557 // Remove TensorToScalar + ScalarToTensor
558 // Remove TensorToScalar + Cast
559 if (CheckPrimitiveType(node, prim::kPrimTensorToScalar)) {
560 MS_LOG(DEBUG) << "Found TensorToScalar, start removing...";
561 status = this->RemoveTensorToScalar(func_graph, node, manager);
562 }
563
564 if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
565 MS_LOG(ERROR) << "Failed to run remove TensorToScalar pass at cnode: " << node->fullname_with_scope();
566 return lite::RET_ERROR;
567 }
568 }
569 return lite::RET_OK;
570 }
571
572 /*
573 This pass checks the arithmetic ops have correct infer types when all TensorToScalar/ScalarToTensor ops are removed. If
574 datatypes do not agree, insert cast op.
575 */
RunArithmeticCheckPass(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager)576 STATUS ScalarOpPass::RunArithmeticCheckPass(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) {
577 auto node_list = TopoSort(func_graph->get_return());
578 for (auto &node : node_list) {
579 if (!utils::isa<CNodePtr>(node)) {
580 continue;
581 }
582 // Check arithmetic op infer type, insert cast op if two inputs do not agree
583 if (CheckPrimitiveType(node, prim::kPrimMul) || CheckPrimitiveType(node, prim::kPrimDiv) ||
584 CheckPrimitiveType(node, prim::kPrimFloorDiv) || CheckPrimitiveType(node, prim::kPrimRealDiv) ||
585 CheckPrimitiveType(node, prim::kPrimSub)) {
586 auto first_input = node->cast<CNodePtr>()->input(kIndexOne);
587 auto second_input = node->cast<CNodePtr>()->input(kIndexTwo);
588
589 TypeId first_data_type;
590 if (opt::GetDataTypeFromAnfNode(first_input, &first_data_type) != RET_OK) {
591 MS_LOG(ERROR) << "Failed to get arithmetic op first input tensor data type.";
592 return lite::RET_ERROR;
593 }
594 TypeId second_data_type;
595 if (opt::GetDataTypeFromAnfNode(second_input, &second_data_type) != RET_OK) {
596 MS_LOG(ERROR) << "Failed to get arithmetic op second input tensor data type.";
597 return lite::RET_ERROR;
598 }
599 if (first_data_type == second_data_type) {
600 continue;
601 }
602
603 // Insert cast node before second input, and set infer type the same as the first input
604 auto cast_data_type = first_data_type;
605 ShapeVector cast_shape;
606 if (FetchShapeFromAbstract(second_input->abstract(), &cast_shape) != lite::RET_OK) {
607 MS_LOG(ERROR) << "Fetch shape from second input abstract failed!";
608 return lite::RET_ERROR;
609 }
610 auto new_cast_abstract =
611 abstract::MakeAbstract(std::make_shared<abstract::Shape>(cast_shape), TypeIdToType(cast_data_type));
612 auto new_cast_node =
613 GenCastNode(func_graph, second_input, second_input->fullname_with_scope() + "cast_after_second_in",
614 cast_data_type, new_cast_abstract);
615 new_cast_node->set_abstract(new_cast_abstract);
616 manager->SetEdge(node, kIndexTwo, new_cast_node);
617 }
618 }
619 return lite::RET_OK;
620 }
621
Run(const FuncGraphPtr & func_graph)622 bool ScalarOpPass::Run(const FuncGraphPtr &func_graph) {
623 MS_ASSERT(func_graph != nullptr);
624 auto manager = func_graph->manager();
625 MS_CHECK_TRUE_RET(manager != nullptr, false);
626 auto status = RunShapeTupleGetPass(func_graph, manager);
627 MS_CHECK_TRUE_RET(status != lite::RET_ERROR, false);
628 auto scalar_replace_status = RunScalarOpPass(func_graph, manager);
629 MS_CHECK_TRUE_RET(status != lite::RET_ERROR, false);
630 status = RunMakeTuplePass(func_graph, manager);
631 MS_CHECK_TRUE_RET(status != lite::RET_ERROR, false);
632 status = RunRemoveTensorToScalarPass(func_graph, manager);
633 MS_CHECK_TRUE_RET(status != lite::RET_ERROR, false);
634 if (scalar_replace_status != lite::RET_NO_CHANGE) {
635 status = RunArithmeticCheckPass(func_graph, manager);
636 MS_CHECK_TRUE_RET(status != lite::RET_ERROR, false);
637 }
638
639 return true;
640 }
641 } // namespace mindspore::opt
642