1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MERGE_ADDN_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MERGE_ADDN_H_ 19 20 #include <vector> 21 #include <algorithm> 22 #include <memory> 23 24 #include "frontend/optimizer/irpass.h" 25 #include "frontend/optimizer/optimizer.h" 26 #include "frontend/optimizer/anf_visitor.h" 27 #include "frontend/operator/ops.h" 28 29 namespace mindspore { 30 namespace opt { 31 namespace irpass { 32 // {PrimAddN, {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys}} -> 33 // {{PrimAddNClass}, {prim::kPrimMakeTuple, Xs, Ys}} 34 // {PrimAddN, {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}}} -> 35 // {{PrimAddNClass}, {prim::kPrimMakeTuple, Ys, Xs}} 36 class MergeAddN : public AnfVisitor { 37 public: operator()38 AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { 39 Reset(); 40 mng_ = optimizer->manager(); 41 is_outer_ = true; 42 AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node); 43 // do not hold this manager 44 mng_ = nullptr; 45 if (!is_match_ || node->func_graph() == nullptr) { 46 return nullptr; 47 } 48 49 auto cnode = node->cast<CNodePtr>(); 50 auto addn = NewValueNode(GetValueNode(cnode->input(0))); 51 52 // {prim::kPrimMakeTuple, Xs, Ys}, {prim::kPrimMakeTuple, Ys, Xs} 53 (void)args_.insert(args_.begin(), NewValueNode(prim::kPrimMakeTuple)); 54 auto fg = node->func_graph(); 55 auto make_node = fg->NewCNode(args_); 56 57 return fg->NewCNode({addn, make_node}); 58 } 59 Visit(const CNodePtr & cnode)60 void Visit(const CNodePtr &cnode) override { 61 if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) { 62 return; 63 } 64 65 auto &inputs = cnode->inputs(); 66 67 if (is_outer_) { 68 (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Ys_)); 69 70 is_outer_ = false; 71 is_inner_ = true; 72 73 // {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys} 74 const auto &first_input = inputs.at(1); 75 AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(first_input); 76 if (is_match_) { 77 if (!is_unique(first_input)) { 78 is_match_ = false; 79 return; 80 } 81 82 if (!IsStateEquivalent(cnode, first_input)) { 83 is_match_ = false; 84 return; 85 } 86 87 (void)Ys_.erase(Ys_.begin()); 88 (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_)); 89 (void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_)); 90 return; 91 } 92 93 // {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}} 94 const auto &last_input = inputs.back(); 95 AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(last_input); 96 if (is_match_) { 97 if (!is_unique(last_input)) { 98 is_match_ = false; 99 return; 100 } 101 102 if (!IsStateEquivalent(cnode, last_input)) { 103 is_match_ = false; 104 return; 105 } 106 107 Ys_.pop_back(); 108 (void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_)); 109 (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_)); 110 return; 111 } 112 113 return; 114 } 115 116 if (is_inner_) { 117 is_match_ = true; 118 (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_)); 119 } 120 } 121 is_unique(const AnfNodePtr & node)122 bool is_unique(const AnfNodePtr &node) { 123 auto &node_users = mng_->node_users(); 124 if (node_users.find(node) == node_users.end()) { 125 return false; 126 } 127 128 size_t n_use = node_users[node].size(); 129 return n_use == 1; 130 } 131 Reset()132 void Reset() { 133 Xs_.clear(); 134 Ys_.clear(); 135 args_.clear(); 136 is_inner_ = false; 137 is_outer_ = false; 138 is_match_ = false; 139 } 140 141 private: 142 FuncGraphManagerPtr mng_{nullptr}; 143 std::vector<AnfNodePtr> Xs_{}, Ys_{}, args_{}; 144 bool is_inner_{false}, is_outer_{false}, is_match_{false}; 145 }; 146 147 // {PrimAddN, {kPrimMakeTuple, Xs}} 148 class AddNZeroFilter : public AnfVisitor { 149 public: operator()150 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 151 Reset(); 152 AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node); 153 154 if (filtered_Xs_.empty() || node->func_graph() == nullptr) { 155 return nullptr; 156 } 157 158 // if only two node in filtered_nodes, {make_tuple, x}. return x. 159 if (filtered_Xs_.size() == 2) { 160 return filtered_Xs_[1]; 161 } 162 163 // if only one node in filtered_nodes, all node is zerolike, return one of the input. 164 if (filtered_Xs_.size() == 1 && Xs_.size() > 0) { 165 return Xs_[0]; 166 } 167 168 if (!has_zero_like_) { 169 return nullptr; 170 } 171 172 auto cnode = node->cast<CNodePtr>(); 173 auto addn = NewValueNode(GetValueNode(cnode->input(0))); 174 auto fg = node->func_graph(); 175 auto make_tuple = fg->NewCNode(filtered_Xs_); 176 return fg->NewCNode({addn, make_tuple}); 177 } 178 Visit(const CNodePtr & cnode)179 void Visit(const CNodePtr &cnode) override { 180 if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) { 181 return; 182 } 183 184 auto &inputs = cnode->inputs(); 185 (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_)); 186 187 // {kPrimMakeTuple, X1, X2, ...} 188 filtered_Xs_.push_back(NewValueNode(prim::kPrimMakeTuple)); 189 for (auto &x : Xs_) { 190 if (!IsPrimitiveCNode(x, prim::kPrimZerosLike)) { 191 filtered_Xs_.push_back(x); 192 } else { 193 has_zero_like_ = true; 194 } 195 } 196 } 197 Reset()198 void Reset() { 199 Xs_.clear(); 200 filtered_Xs_.clear(); 201 has_zero_like_ = false; 202 } 203 204 private: 205 std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{}; 206 bool has_zero_like_{false}; 207 }; 208 } // namespace irpass 209 } // namespace opt 210 } // namespace mindspore 211 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MERGE_ADDN_H_ 212