1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_STACK_UNSTACK_ELIMINATE_H 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_STACK_UNSTACK_ELIMINATE_H 19 20 #include <algorithm> 21 #include <memory> 22 #include <vector> 23 #include <string> 24 25 #include "frontend/optimizer/optimizer_caller.h" 26 #include "mindspore/core/ops/array_ops.h" 27 #include "frontend/optimizer/anf_visitor.h" 28 #include "frontend/operator/ops.h" 29 #include "frontend/optimizer/irpass.h" 30 #include "frontend/optimizer/optimizer.h" 31 #include "include/common/utils/utils.h" 32 33 namespace mindspore { 34 namespace opt { 35 namespace irpass { 36 // {prim::kPrimStack, {prim::kPrimUnstack, X}} => X 37 // prim::kPrimUnstack and prim::kPrimStack should have same attribute value of kAttrNum and kAttrAxis. 38 class StackUnstackEliminator : public AnfVisitor { 39 public: operator()40 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 41 Reset(); 42 43 if (!IsPrimitiveCNode(node, prim::kPrimUnstack)) { 44 return nullptr; 45 } 46 47 if (!FetchUnstackAttrs(node)) { 48 return nullptr; 49 } 50 AnfVisitor::Match(prim::kPrimUnstack, {IsCNode})(node); 51 52 if (is_match_) { 53 return stack_->input(1); 54 } 55 return nullptr; 56 } 57 Visit(const CNodePtr & cnode)58 void Visit(const CNodePtr &cnode) override { 59 if (IsPrimitiveCNode(cnode, prim::kPrimStack)) { 60 auto prim = GetCNodePrimitive(cnode); 61 auto num_val = prim->GetAttr(kAttrNum); 62 // Stack may not be inferred and do not have attribute axis. 63 if (num_val == nullptr) { 64 return; 65 } 66 auto axis_val = prim->GetAttr(kAttrAxis); 67 MS_EXCEPTION_IF_NULL(axis_val); 68 auto num = dyn_cast<Int64Imm>(num_val)->value(); 69 auto axis = dyn_cast<Int64Imm>(axis_val)->value(); 70 if (num == num_ && axis == axis_) { 71 is_match_ = true; 72 stack_ = cnode; 73 } 74 } 75 } 76 FetchUnstackAttrs(const AnfNodePtr & node)77 bool FetchUnstackAttrs(const AnfNodePtr &node) { 78 auto prim = GetCNodePrimitive(node); 79 auto num_val = prim->GetAttr(kAttrNum); 80 // UnStack may not be inferred and do not have attribute axis. 81 if (num_val == nullptr || num_val->isa<None>()) { 82 return false; 83 } 84 auto axis_val = prim->GetAttr(kAttrAxis); 85 MS_EXCEPTION_IF_NULL(axis_val); 86 num_ = dyn_cast<Int64Imm>(num_val)->value(); 87 axis_ = dyn_cast<Int64Imm>(axis_val)->value(); 88 return true; 89 } 90 Reset()91 void Reset() { 92 is_match_ = false; 93 num_ = 0; 94 axis_ = 0; 95 stack_ = nullptr; 96 } 97 98 private: 99 bool is_match_{false}; 100 int64_t num_{0}; 101 int64_t axis_{0}; 102 CNodePtr stack_{nullptr}; 103 }; 104 } // namespace irpass 105 } // namespace opt 106 } // namespace mindspore 107 108 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_STACK_UNSTACK_ELIMINATE_H 109