• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/optimizer/graph_kernel/axis_normalizer.h"
17 
18 #include "ir/scalar.h"
19 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
20 #include "backend/session/anf_runtime_algorithm.h"
21 
22 namespace mindspore {
23 namespace opt {
NormAxis(int64_t x,size_t rank) const24 int64_t AxisNormalizer::NormAxis(int64_t x, size_t rank) const { return x >= 0 ? x : x + static_cast<int64_t>(rank); }
25 
IsReduce(const AnfNodePtr & node) const26 bool AxisNormalizer::IsReduce(const AnfNodePtr &node) const {
27   std::vector<PrimitivePtr> node_with_axis = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin,
28                                               prim::kPrimArgMax, prim::kPrimArgMin};
29   return std::any_of(node_with_axis.begin(), node_with_axis.end(),
30                      [&node](PrimitivePtr &p) { return IsPrimitiveCNode(node, p); });
31 }
32 
Process(const FuncGraphPtr & func_graph) const33 bool AxisNormalizer::Process(const FuncGraphPtr &func_graph) const {
34   bool changed = false;
35   auto todos = TopoSort(func_graph->get_return());
36   for (auto node : todos) {
37     if (!IsReduce(node)) {
38       continue;
39     }
40     if (auto primitive = GetCNodePrimitive(node); primitive != nullptr && primitive->HasAttr(kAttrAxis)) {
41       auto axis = primitive->GetAttr(kAttrAxis);
42       size_t rank = AnfAlgo::GetInputDeviceShape(node, 0).size();
43       if (rank == 0) {
44         // scalar tensor
45         rank = 1;
46       }
47       bool diff = false;
48       ShapeVector axis_vec;
49       if (axis->isa<Int32Imm>() || axis->isa<Int64Imm>()) {
50         auto v1 = GetValue<int64_t>(axis);
51         auto v2 = NormAxis(v1, rank);
52         axis_vec.push_back(v2);
53         diff = true;
54       } else if (axis->isa<ValueSequeue>()) {
55         auto vec = axis->cast<ValueSequeuePtr>()->value();
56         if (vec.empty()) {
57           diff = true;
58           for (size_t i = 0; i < rank; i++) {
59             axis_vec.push_back(i);
60           }
61         } else if (vec[0]->isa<Int32Imm>() || vec[0]->isa<Int64Imm>()) {
62           for (auto v : vec) {
63             auto v1 = GetValue<int64_t>(v);
64             auto v2 = NormAxis(v1, rank);
65             axis_vec.push_back(v2);
66             diff = diff || (v1 != v2);
67           }
68         }
69       }
70       if (diff) {
71         changed = true;
72         std::sort(axis_vec.begin(), axis_vec.end());
73         SetNodeAttrSafely(kAttrAxis, MakeValue(axis_vec), node);
74       }
75     }
76   }
77   return changed;
78 }
79 
Run(const FuncGraphPtr & func_graph)80 bool AxisNormalizer::Run(const FuncGraphPtr &func_graph) {
81   MS_EXCEPTION_IF_NULL(func_graph);
82   bool changed = false;
83   auto todos = TopoSort(func_graph->get_return());
84   for (auto node : todos) {
85     if (AnfAlgo::IsGraphKernel(node)) {
86       auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
87       changed = Process(sub_func_graph) || changed;
88     }
89   }
90   return changed;
91 }
92 }  // namespace opt
93 }  // namespace mindspore
94