1 /**
2 * Copyright 2020-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/graph_kernel/axis_normalizer.h"
17
#include <algorithm>
#include <memory>
#include <unordered_map>
#include <vector>
21
22 #include "backend/common/graph_kernel/adapter/callback_impl.h"
23 #include "backend/common/graph_kernel/graph_kernel_helper.h"
24 #include "include/common/utils/anfalgo.h"
25 #include "ir/scalar.h"
26 #include "ir/tensor.h"
27 #include "mindspore/core/ops/array_ops.h"
28 #include "mindspore/core/ops/math_ops.h"
29 #include "utils/anf_utils.h"
30
31 namespace mindspore::graphkernel {
NormAxis(int64_t x,size_t rank) const32 int64_t AxisNormalizer::NormAxis(int64_t x, size_t rank) const { return x >= 0 ? x : x + static_cast<int64_t>(rank); }
33
IsReduce(const AnfNodePtr & node) const34 bool AxisNormalizer::IsReduce(const AnfNodePtr &node) const {
35 std::vector<PrimitivePtr> node_with_axis = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin,
36 prim::kPrimArgMax, prim::kPrimArgmin};
37 return std::any_of(node_with_axis.begin(), node_with_axis.end(),
38 [&node](const PrimitivePtr &p) { return IsPrimitiveCNode(node, p); });
39 }
40
AxisProcess(ValuePtr axis,const size_t rank,ShapeVector * axis_vec) const41 bool AxisNormalizer::AxisProcess(ValuePtr axis, const size_t rank, ShapeVector *axis_vec) const {
42 bool diff = false;
43 if (axis->isa<Int32Imm>() || axis->isa<Int64Imm>()) {
44 auto v1 = AnfUtils::GetIntValue(axis);
45 auto v2 = NormAxis(v1, rank);
46 axis_vec->push_back(v2);
47 } else if (axis->isa<ValueSequence>()) {
48 auto vec = axis->cast<ValueSequencePtr>()->value();
49 if (vec.empty()) {
50 diff = true;
51 for (size_t i = 0; i < rank; i++) {
52 axis_vec->push_back(i);
53 }
54 } else if (vec[0]->isa<Int32Imm>() || vec[0]->isa<Int64Imm>()) {
55 for (auto v : vec) {
56 auto v1 = AnfUtils::GetIntValue(v);
57 auto v2 = NormAxis(v1, rank);
58 axis_vec->push_back(v2);
59 diff = diff || (v1 != v2);
60 }
61 }
62 } else if (axis->isa<tensor::Tensor>()) {
63 auto raw_axis_vec = CheckAndConvertUtils::CheckTensorIntValue("axis", axis, "ReduceOp");
64 if (raw_axis_vec.empty()) {
65 diff = true;
66 for (size_t i = 0; i < rank; i++) {
67 axis_vec->push_back(i);
68 }
69 } else {
70 for (auto v1 : raw_axis_vec) {
71 auto v2 = NormAxis(v1, rank);
72 axis_vec->push_back(v2);
73 }
74 // if tensor shape is empty, create a new 1-d tensor
75 auto axis_tensor = axis->cast<tensor::TensorPtr>();
76 diff = axis_tensor->shape_c().empty() || raw_axis_vec != *axis_vec;
77 }
78 }
79
80 return diff;
81 }
82
Process(const AnfNodePtr & graph_kernel_node) const83 bool AxisNormalizer::Process(const AnfNodePtr &graph_kernel_node) const {
84 auto sub_func_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(graph_kernel_node);
85 auto parameters = sub_func_graph->parameters();
86 auto inputs = graph_kernel_node->cast<CNodePtr>()->inputs();
87 std::unordered_map<AnfNodePtr, size_t> param_idx_map;
88 for (size_t i = 0; i < parameters.size(); ++i) {
89 param_idx_map[parameters[i]] = i;
90 }
91 bool changed = false;
92 auto todos = TopoSort(sub_func_graph->get_return());
93 for (auto node : todos) {
94 if (!IsReduce(node)) {
95 continue;
96 }
97 auto cnode = node->cast<CNodePtr>();
98 MS_EXCEPTION_IF_NULL(cnode);
99 const size_t axis_idx = 2;
100 auto axis_node = cnode->input(axis_idx);
101 ValuePtr axis = nullptr;
102 if (axis_node->isa<ValueNode>()) {
103 auto axis_value_node = axis_node->cast<ValueNodePtr>();
104 axis = axis_value_node->value();
105 } else { // Parameter
106 axis = axis_node->abstract()->BuildValue();
107 }
108 size_t rank = Callback::Instance()->GetInputShape(node, 0).size();
109 if (rank == 0) {
110 // scalar tensor
111 rank = 1;
112 }
113 ShapeVector axis_vec;
114 auto diff = AxisProcess(axis, rank, &axis_vec);
115 if (diff) {
116 changed = true;
117 std::sort(axis_vec.begin(), axis_vec.end());
118 ValuePtr new_axis_value = nullptr;
119 new_axis_value = std::make_shared<tensor::Tensor>(axis_vec);
120 auto new_axis_node = std::make_shared<ValueNode>(new_axis_value);
121 new_axis_node->set_abstract(new_axis_value->ToAbstract());
122 Callback::Instance()->SetBasicNodeKernelInfo(
123 new_axis_node, {{ShapeVector{SizeToLong(axis_vec.size())}, kNumberTypeInt64, "DefaultFormat"}});
124 if (axis_node->isa<ValueNode>()) {
125 cnode->set_input(axis_idx, new_axis_node);
126 } else {
127 auto idx = param_idx_map[axis_node];
128 auto &input_node = inputs[idx + 1];
129 auto input_value_node = input_node->cast<ValueNodePtr>();
130 MS_EXCEPTION_IF_NULL(input_value_node);
131 input_value_node->set_abstract(new_axis_node->abstract());
132 input_value_node->set_value(new_axis_value);
133 axis_node->set_abstract(new_axis_node->abstract());
134 }
135 }
136 }
137 return changed;
138 }
139
Run(const FuncGraphPtr & func_graph)140 bool AxisNormalizer::Run(const FuncGraphPtr &func_graph) {
141 MS_EXCEPTION_IF_NULL(func_graph);
142 auto mng = func_graph->manager();
143 MS_EXCEPTION_IF_NULL(mng);
144 bool changed = false;
145 auto todos = TopoSort(func_graph->get_return());
146 for (auto node : todos) {
147 if (common::AnfAlgo::IsGraphKernel(node)) {
148 changed = Process(node) || changed;
149 }
150 }
151 if (changed) {
152 mng->RemoveRoots();
153 mng->KeepRoots({func_graph});
154 }
155 return changed;
156 }
157 } // namespace mindspore::graphkernel
158