• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_STACK_UNSTACK_ELIMINATE_H
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_STACK_UNSTACK_ELIMINATE_H
19 
20 #include <algorithm>
21 #include <memory>
22 #include <vector>
23 #include <string>
24 
25 #include "frontend/optimizer/optimizer_caller.h"
26 #include "mindspore/core/ops/array_ops.h"
27 #include "frontend/optimizer/anf_visitor.h"
28 #include "frontend/operator/ops.h"
29 #include "frontend/optimizer/irpass.h"
30 #include "frontend/optimizer/optimizer.h"
31 #include "include/common/utils/utils.h"
32 
33 namespace mindspore {
34 namespace opt {
35 namespace irpass {
36 // {prim::kPrimStack, {prim::kPrimUnstack, X}} => X
37 // prim::kPrimUnstack and prim::kPrimStack should have same attribute value of kAttrNum and kAttrAxis.
38 class StackUnstackEliminator : public AnfVisitor {
39  public:
operator()40   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
41     Reset();
42 
43     if (!IsPrimitiveCNode(node, prim::kPrimUnstack)) {
44       return nullptr;
45     }
46 
47     if (!FetchUnstackAttrs(node)) {
48       return nullptr;
49     }
50     AnfVisitor::Match(prim::kPrimUnstack, {IsCNode})(node);
51 
52     if (is_match_) {
53       return stack_->input(1);
54     }
55     return nullptr;
56   }
57 
Visit(const CNodePtr & cnode)58   void Visit(const CNodePtr &cnode) override {
59     if (IsPrimitiveCNode(cnode, prim::kPrimStack)) {
60       auto prim = GetCNodePrimitive(cnode);
61       auto num_val = prim->GetAttr(kAttrNum);
62       // Stack may not be inferred and do not have attribute axis.
63       if (num_val == nullptr) {
64         return;
65       }
66       auto axis_val = prim->GetAttr(kAttrAxis);
67       MS_EXCEPTION_IF_NULL(axis_val);
68       auto num = dyn_cast<Int64Imm>(num_val)->value();
69       auto axis = dyn_cast<Int64Imm>(axis_val)->value();
70       if (num == num_ && axis == axis_) {
71         is_match_ = true;
72         stack_ = cnode;
73       }
74     }
75   }
76 
FetchUnstackAttrs(const AnfNodePtr & node)77   bool FetchUnstackAttrs(const AnfNodePtr &node) {
78     auto prim = GetCNodePrimitive(node);
79     auto num_val = prim->GetAttr(kAttrNum);
80     // UnStack may not be inferred and do not have attribute axis.
81     if (num_val == nullptr || num_val->isa<None>()) {
82       return false;
83     }
84     auto axis_val = prim->GetAttr(kAttrAxis);
85     MS_EXCEPTION_IF_NULL(axis_val);
86     num_ = dyn_cast<Int64Imm>(num_val)->value();
87     axis_ = dyn_cast<Int64Imm>(axis_val)->value();
88     return true;
89   }
90 
Reset()91   void Reset() {
92     is_match_ = false;
93     num_ = 0;
94     axis_ = 0;
95     stack_ = nullptr;
96   }
97 
98  private:
99   bool is_match_{false};
100   int64_t num_{0};
101   int64_t axis_{0};
102   CNodePtr stack_{nullptr};
103 };
104 }  // namespace irpass
105 }  // namespace opt
106 }  // namespace mindspore
107 
108 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_STACK_UNSTACK_ELIMINATE_H
109