/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MERGE_ADDN_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MERGE_ADDN_H_

#include <vector>
#include <algorithm>
#include <memory>

#include "frontend/optimizer/irpass.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
#include "utils/anf_utils.h"

namespace mindspore {
namespace opt {
namespace irpass {
// {PrimAddN, {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys}} ->
// {{PrimAddNClass}, {prim::kPrimMakeTuple, Xs, Ys}}
// {PrimAddN, {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}}} ->
// {{PrimAddNClass}, {prim::kPrimMakeTuple, Ys, Xs}}
class MergeAddN : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
    Reset();
    mng_ = optimizer->manager();
    is_outer_ = true;
    AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node);
    // Do not hold this manager.
    mng_ = nullptr;
    if (!is_match_ || node->func_graph() == nullptr) {
      return nullptr;
    }
    addn_nodes_.push_back(node);

    auto cnode = node->cast<CNodePtr>();
    auto addn = NewValueNode(GetValueNode(cnode->input(0)));

    // {prim::kPrimMakeTuple, Xs, Ys}, {prim::kPrimMakeTuple, Ys, Xs}
    (void)args_.insert(args_.cbegin(), NewValueNode(prim::kPrimMakeTuple));
    auto fg = node->func_graph();
    auto make_node = fg->NewCNode(args_);

    auto new_node = fg->NewCNode({addn, make_node});
    UpdateDumpFlag(new_node);
    new_node->AddFusedDebugInfoList(addn_nodes_);
    return new_node;
  }

  void Visit(const CNodePtr &cnode) override {
    if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
      return;
    }

    auto &inputs = cnode->inputs();

    if (is_outer_) {
      (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Ys_));

      is_outer_ = false;
      is_inner_ = true;

      // {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys}
      const auto &first_input = inputs.at(1);
      AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(first_input);
      if (is_match_) {
        if (!is_unique(first_input)) {
          is_match_ = false;
          return;
        }

        if (!IsStateEquivalent(cnode, first_input)) {
          is_match_ = false;
          return;
        }

        addn_nodes_.push_back(first_input);
        (void)Ys_.erase(Ys_.cbegin());
        (void)std::copy(Xs_.cbegin(), Xs_.cend(), std::back_inserter(args_));
        (void)std::copy(Ys_.cbegin(), Ys_.cend(), std::back_inserter(args_));
        return;
      }

      // {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}}
      const auto &last_input = inputs.back();
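      // Symmetric to the first-input case above: here the inner AddN is the last element of the
      // outer make_tuple, e.g. AddN(MakeTuple(y1, y2, AddN(MakeTuple(x1, x2)))) ->
      // AddN(MakeTuple(y1, y2, x1, x2)), provided the inner AddN has a single user and is
      // state-equivalent with the outer make_tuple.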
      AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(last_input);
      if (is_match_) {
        if (!is_unique(last_input)) {
          is_match_ = false;
          return;
        }

        if (!IsStateEquivalent(cnode, last_input)) {
          is_match_ = false;
          return;
        }

        addn_nodes_.push_back(last_input);
        Ys_.pop_back();
        (void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_));
        (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_));
        return;
      }

      return;
    }

    if (is_inner_) {
      is_match_ = true;
      (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_));
    }
  }

  bool is_unique(const AnfNodePtr &node) {
    auto &node_users = mng_->node_users();
    if (node_users.find(node) == node_users.end()) {
      return false;
    }

    size_t n_use = node_users[node].size();
    return n_use == 1;
  }

  void Reset() {
    Xs_.clear();
    Ys_.clear();
    args_.clear();
    addn_nodes_.clear();
    is_inner_ = false;
    is_outer_ = false;
    is_match_ = false;
  }

  void UpdateDumpFlag(const AnfNodePtr &node) {
    if (node == nullptr) {
      return;
    }
    for (const auto &addn : addn_nodes_) {
      if (AnfUtils::GetDumpFlag(addn)) {
        AnfUtils::SetDumpFlag(node);
        return;
      }
    }
  }

 private:
  FuncGraphManagerPtr mng_{nullptr};
  std::vector<AnfNodePtr> Xs_{}, Ys_{}, args_{}, addn_nodes_{};
  bool is_inner_{false}, is_outer_{false}, is_match_{false};
};

// {PrimAddN, {kPrimMakeTuple, Xs}}
class AddNZeroFilter : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    Reset();
    AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node);

    if (filtered_Xs_.empty() || node->func_graph() == nullptr) {
      return nullptr;
    }

    // If only two nodes are left in filtered_Xs_, i.e. {make_tuple, x}, return x.
    constexpr auto input_size = 2;
    if (filtered_Xs_.size() == input_size) {
      return filtered_Xs_[1];
    }

    // If only one node is left in filtered_Xs_, every input was a zeros_like; return one of the original inputs.
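    // (In that case filtered_Xs_ holds only the leading kPrimMakeTuple value node, so any
    //  original input already carries the required zero value.)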
    if (filtered_Xs_.size() == 1 && Xs_.size() > 0) {
      return Xs_[0];
    }

    if (!has_zero_like_) {
      return nullptr;
    }

    auto cnode = node->cast<CNodePtr>();
    auto addn = NewValueNode(GetValueNode(cnode->input(0)));
    auto fg = node->func_graph();
    auto make_tuple = fg->NewCNode(filtered_Xs_);
    return fg->NewCNode({addn, make_tuple});
  }

  void Visit(const CNodePtr &cnode) override {
    if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
      return;
    }

    auto &inputs = cnode->inputs();
    (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_));

    // {kPrimMakeTuple, X1, X2, ...}
    filtered_Xs_.push_back(NewValueNode(prim::kPrimMakeTuple));
    for (auto &x : Xs_) {
      if (!IsPrimitiveCNode(x, prim::kPrimZerosLike)) {
        filtered_Xs_.push_back(x);
      } else {
        has_zero_like_ = true;
      }
    }
  }

  void Reset() {
    Xs_.clear();
    filtered_Xs_.clear();
    has_zero_like_ = false;
  }

 private:
  std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{};
  bool has_zero_like_{false};
};

// {PrimAddN, {kPrimMakeTuple, Xs}}
class AddNCheckDump : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    Reset();
    AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node);

    // Only handle gradient AddN.
    if (node->scope()->name().find("Gradients/") != 0) {
      return nullptr;
    }

    if (set_dump_) {
      AnfUtils::SetDumpFlag(node);
    }

    return nullptr;
  }

  void Visit(const CNodePtr &cnode) override {
    if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
      return;
    }
    if (cnode->size() < kSizeThree) {
      return;
    }

    // When all inputs have the dump flag, set the dump flag for AddN as well.
    set_dump_ = true;
    for (size_t i = 1; i < cnode->size(); ++i) {
      auto input = cnode->input(i);
      MS_EXCEPTION_IF_NULL(input);
      if (IsPrimitiveCNode(input, prim::kPrimTupleGetItem) || IsPrimitiveCNode(input, prim::kPrimDepend)) {
        input = input->cast<CNodePtr>()->input(kIndexOne);
      }
      if (!input->isa<CNode>() || !AnfUtils::GetDumpFlag(input)) {
        set_dump_ = false;
        break;
      }
    }
  }

  void Reset() { set_dump_ = false; }

 private:
  bool set_dump_{false};
};
}  // namespace irpass
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MERGE_ADDN_H_