• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "backend/common/graph_kernel/axis_normalizer.h"

#include <algorithm>
#include <memory>
#include <unordered_map>
#include <vector>

#include "backend/common/graph_kernel/adapter/callback_impl.h"
#include "backend/common/graph_kernel/graph_kernel_helper.h"
#include "include/common/utils/anfalgo.h"
#include "ir/scalar.h"
#include "ir/tensor.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "utils/anf_utils.h"
30 
31 namespace mindspore::graphkernel {
NormAxis(int64_t x,size_t rank) const32 int64_t AxisNormalizer::NormAxis(int64_t x, size_t rank) const { return x >= 0 ? x : x + static_cast<int64_t>(rank); }
33 
IsReduce(const AnfNodePtr & node) const34 bool AxisNormalizer::IsReduce(const AnfNodePtr &node) const {
35   std::vector<PrimitivePtr> node_with_axis = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin,
36                                               prim::kPrimArgMax, prim::kPrimArgmin};
37   return std::any_of(node_with_axis.begin(), node_with_axis.end(),
38                      [&node](const PrimitivePtr &p) { return IsPrimitiveCNode(node, p); });
39 }
40 
AxisProcess(ValuePtr axis,const size_t rank,ShapeVector * axis_vec) const41 bool AxisNormalizer::AxisProcess(ValuePtr axis, const size_t rank, ShapeVector *axis_vec) const {
42   bool diff = false;
43   if (axis->isa<Int32Imm>() || axis->isa<Int64Imm>()) {
44     auto v1 = AnfUtils::GetIntValue(axis);
45     auto v2 = NormAxis(v1, rank);
46     axis_vec->push_back(v2);
47   } else if (axis->isa<ValueSequence>()) {
48     auto vec = axis->cast<ValueSequencePtr>()->value();
49     if (vec.empty()) {
50       diff = true;
51       for (size_t i = 0; i < rank; i++) {
52         axis_vec->push_back(i);
53       }
54     } else if (vec[0]->isa<Int32Imm>() || vec[0]->isa<Int64Imm>()) {
55       for (auto v : vec) {
56         auto v1 = AnfUtils::GetIntValue(v);
57         auto v2 = NormAxis(v1, rank);
58         axis_vec->push_back(v2);
59         diff = diff || (v1 != v2);
60       }
61     }
62   } else if (axis->isa<tensor::Tensor>()) {
63     auto raw_axis_vec = CheckAndConvertUtils::CheckTensorIntValue("axis", axis, "ReduceOp");
64     if (raw_axis_vec.empty()) {
65       diff = true;
66       for (size_t i = 0; i < rank; i++) {
67         axis_vec->push_back(i);
68       }
69     } else {
70       for (auto v1 : raw_axis_vec) {
71         auto v2 = NormAxis(v1, rank);
72         axis_vec->push_back(v2);
73       }
74       // if tensor shape is empty, create a new 1-d tensor
75       auto axis_tensor = axis->cast<tensor::TensorPtr>();
76       diff = axis_tensor->shape_c().empty() || raw_axis_vec != *axis_vec;
77     }
78   }
79 
80   return diff;
81 }
82 
// Rewrites the axis input of every reduction node inside the sub-graph of one
// GraphKernel node so the axes are non-negative and stored as a 1-D int64
// tensor ValueNode. Returns true if anything was modified.
bool AxisNormalizer::Process(const AnfNodePtr &graph_kernel_node) const {
  auto sub_func_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(graph_kernel_node);
  auto parameters = sub_func_graph->parameters();
  auto inputs = graph_kernel_node->cast<CNodePtr>()->inputs();
  // Map each sub-graph Parameter to its position so an axis that arrives as a
  // Parameter can be traced back to the corresponding real input of the outer
  // kernel node (inputs[idx + 1]; +1 skips the primitive at inputs[0]).
  std::unordered_map<AnfNodePtr, size_t> param_idx_map;
  for (size_t i = 0; i < parameters.size(); ++i) {
    param_idx_map[parameters[i]] = i;
  }
  bool changed = false;
  auto todos = TopoSort(sub_func_graph->get_return());
  for (auto node : todos) {
    if (!IsReduce(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // For the reduce ops matched by IsReduce, the axis is input 2
    // (input 0 is the primitive, input 1 the data tensor).
    const size_t axis_idx = 2;
    auto axis_node = cnode->input(axis_idx);
    ValuePtr axis = nullptr;
    if (axis_node->isa<ValueNode>()) {
      auto axis_value_node = axis_node->cast<ValueNodePtr>();
      axis = axis_value_node->value();
    } else {  // Parameter
      axis = axis_node->abstract()->BuildValue();
    }
    size_t rank = Callback::Instance()->GetInputShape(node, 0).size();
    if (rank == 0) {
      // scalar tensor
      rank = 1;
    }
    ShapeVector axis_vec;
    auto diff = AxisProcess(axis, rank, &axis_vec);
    if (diff) {
      changed = true;
      // Canonical form: sorted, non-negative axes in a 1-D int64 tensor.
      std::sort(axis_vec.begin(), axis_vec.end());
      ValuePtr new_axis_value = nullptr;
      new_axis_value = std::make_shared<tensor::Tensor>(axis_vec);
      auto new_axis_node = std::make_shared<ValueNode>(new_axis_value);
      new_axis_node->set_abstract(new_axis_value->ToAbstract());
      Callback::Instance()->SetBasicNodeKernelInfo(
        new_axis_node, {{ShapeVector{SizeToLong(axis_vec.size())}, kNumberTypeInt64, "DefaultFormat"}});
      if (axis_node->isa<ValueNode>()) {
        // Axis lives directly inside the sub-graph: swap in the new ValueNode.
        cnode->set_input(axis_idx, new_axis_node);
      } else {
        // Axis is a sub-graph Parameter: the actual value lives on the outer
        // kernel node's input, so rewrite that ValueNode in place and keep the
        // Parameter's abstract in sync with it.
        auto idx = param_idx_map[axis_node];
        auto &input_node = inputs[idx + 1];
        auto input_value_node = input_node->cast<ValueNodePtr>();
        MS_EXCEPTION_IF_NULL(input_value_node);
        input_value_node->set_abstract(new_axis_node->abstract());
        input_value_node->set_value(new_axis_value);
        axis_node->set_abstract(new_axis_node->abstract());
      }
    }
  }
  return changed;
}
139 
Run(const FuncGraphPtr & func_graph)140 bool AxisNormalizer::Run(const FuncGraphPtr &func_graph) {
141   MS_EXCEPTION_IF_NULL(func_graph);
142   auto mng = func_graph->manager();
143   MS_EXCEPTION_IF_NULL(mng);
144   bool changed = false;
145   auto todos = TopoSort(func_graph->get_return());
146   for (auto node : todos) {
147     if (common::AnfAlgo::IsGraphKernel(node)) {
148       changed = Process(node) || changed;
149     }
150   }
151   if (changed) {
152     mng->RemoveRoots();
153     mng->KeepRoots({func_graph});
154   }
155   return changed;
156 }
157 }  // namespace mindspore::graphkernel
158