1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPT_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPT_H_ 19 20 #include <deque> 21 #include <memory> 22 #include <string> 23 #include <vector> 24 #include <unordered_map> 25 26 #include "ir/anf.h" 27 #include "ir/func_graph.h" 28 #include "frontend/optimizer/optimizer_caller.h" 29 #include "frontend/operator/ops.h" 30 31 namespace mindspore { 32 /* namespace to support opt */ 33 namespace opt { 34 35 // Define the interaction mode between an Optimize pass and Renormalize pass 36 // FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed 37 // CHECK_RENORM: check if the new node is un-typed to decide if the next Renormalize will be executted 38 enum RenormAction : int64_t { FORCE_RENORM = 0, CHECK_RENORM }; 39 40 class Substitution { 41 public: 42 OptimizerCallerPtr transform_; 43 std::string name_; 44 PredicateFuncType predicate_{nullptr}; 45 // An enum to mark this Substitution relation to renormalize pass. 46 RenormAction renorm_action_; 47 // Determine whether it is a priority substitution, that is, some patterns need to be matched prior to others. 48 bool has_priority_pattern_{false}; 49 Substitution(const OptimizerCallerPtr & transform,const std::string & name,const PredicateFuncType & predicate,const RenormAction & renorm_action,bool has_priority_pattern)50 Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate, 51 const RenormAction &renorm_action, bool has_priority_pattern) 52 : transform_(transform), 53 name_(name), 54 predicate_(predicate), 55 renorm_action_(renorm_action), 56 has_priority_pattern_(has_priority_pattern) {} 57 ~Substitution() = default; 58 AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node); 59 }; 60 61 using SubstitutionPtr = std::shared_ptr<Substitution>; 62 63 SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim, 64 const RenormAction &action_renorm = CHECK_RENORM, bool has_priority_pattern = false); 65 SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, 66 const std::vector<PrimitivePtr> &prims, 67 const RenormAction &action_renorm = CHECK_RENORM, bool has_priority_pattern = false); 68 SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, 69 const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM, 70 bool has_priority_pattern = false); 71 72 enum OptTraverseSubstitutionsMode { kOptTraverseFromIRToSubstitutions = 0, kOptTraverseFromSubstitutionsToIR }; 73 74 class SubstitutionList { 75 public: 76 explicit SubstitutionList(const std::vector<SubstitutionPtr> &patterns, bool is_once = false, 77 bool global_sensitive = false) list_(patterns)78 : list_(patterns), is_once_(is_once), global_sensitive_(global_sensitive) {} 79 ~SubstitutionList() = default; 80 81 bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const; 82 83 private: 84 bool ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const; 85 bool ApplySubstitutionToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph, 86 const SubstitutionPtr &sub) const; 87 bool ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const; 88 void DisplayStatusOfSubstitution(const std::unordered_map<std::string, std::vector<bool>> &status, 89 const OptimizerPtr &optimizer, size_t space) const; 90 91 std::vector<SubstitutionPtr> list_; 92 // a flag to mark this list of Substitution can only be executed only once 93 bool is_once_{false}; 94 bool global_sensitive_{false}; 95 }; 96 } // namespace opt 97 } // namespace mindspore 98 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPT_H_ 99