1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/optimizer/graph_kernel/axis_normalizer.h"
17
18 #include "ir/scalar.h"
19 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
20 #include "backend/session/anf_runtime_algorithm.h"
21
22 namespace mindspore {
23 namespace opt {
NormAxis(int64_t x,size_t rank) const24 int64_t AxisNormalizer::NormAxis(int64_t x, size_t rank) const { return x >= 0 ? x : x + static_cast<int64_t>(rank); }
25
IsReduce(const AnfNodePtr & node) const26 bool AxisNormalizer::IsReduce(const AnfNodePtr &node) const {
27 std::vector<PrimitivePtr> node_with_axis = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin,
28 prim::kPrimArgMax, prim::kPrimArgMin};
29 return std::any_of(node_with_axis.begin(), node_with_axis.end(),
30 [&node](PrimitivePtr &p) { return IsPrimitiveCNode(node, p); });
31 }
32
Process(const FuncGraphPtr & func_graph) const33 bool AxisNormalizer::Process(const FuncGraphPtr &func_graph) const {
34 bool changed = false;
35 auto todos = TopoSort(func_graph->get_return());
36 for (auto node : todos) {
37 if (!IsReduce(node)) {
38 continue;
39 }
40 if (auto primitive = GetCNodePrimitive(node); primitive != nullptr && primitive->HasAttr(kAttrAxis)) {
41 auto axis = primitive->GetAttr(kAttrAxis);
42 size_t rank = AnfAlgo::GetInputDeviceShape(node, 0).size();
43 if (rank == 0) {
44 // scalar tensor
45 rank = 1;
46 }
47 bool diff = false;
48 ShapeVector axis_vec;
49 if (axis->isa<Int32Imm>() || axis->isa<Int64Imm>()) {
50 auto v1 = GetValue<int64_t>(axis);
51 auto v2 = NormAxis(v1, rank);
52 axis_vec.push_back(v2);
53 diff = true;
54 } else if (axis->isa<ValueSequeue>()) {
55 auto vec = axis->cast<ValueSequeuePtr>()->value();
56 if (vec.empty()) {
57 diff = true;
58 for (size_t i = 0; i < rank; i++) {
59 axis_vec.push_back(i);
60 }
61 } else if (vec[0]->isa<Int32Imm>() || vec[0]->isa<Int64Imm>()) {
62 for (auto v : vec) {
63 auto v1 = GetValue<int64_t>(v);
64 auto v2 = NormAxis(v1, rank);
65 axis_vec.push_back(v2);
66 diff = diff || (v1 != v2);
67 }
68 }
69 }
70 if (diff) {
71 changed = true;
72 std::sort(axis_vec.begin(), axis_vec.end());
73 SetNodeAttrSafely(kAttrAxis, MakeValue(axis_vec), node);
74 }
75 }
76 }
77 return changed;
78 }
79
Run(const FuncGraphPtr & func_graph)80 bool AxisNormalizer::Run(const FuncGraphPtr &func_graph) {
81 MS_EXCEPTION_IF_NULL(func_graph);
82 bool changed = false;
83 auto todos = TopoSort(func_graph->get_return());
84 for (auto node : todos) {
85 if (AnfAlgo::IsGraphKernel(node)) {
86 auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
87 changed = Process(sub_func_graph) || changed;
88 }
89 }
90 return changed;
91 }
92 } // namespace opt
93 } // namespace mindspore
94