1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PY_PASS_MANAGER_H_ 17 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PY_PASS_MANAGER_H_ 18 19 #include <memory> 20 #include <string> 21 #include <vector> 22 #include <unordered_map> 23 24 #include "ir/anf.h" 25 #include "ir/func_graph.h" 26 #include "pybind_api/ir/primitive_py.h" 27 #include "ir/graph_utils.h" 28 #include "utils/ms_utils.h" 29 30 #include "pipeline/jit/resource.h" 31 #include "frontend/optimizer/pattern.h" 32 #include "frontend/optimizer/py_pass.h" 33 #include "frontend/optimizer/pass_group.h" 34 35 namespace mindspore { 36 namespace opt { 37 namespace python_pass { 38 class PyPassManager; 39 using PyPassManagerPtr = std::shared_ptr<PyPassManager>; 40 41 enum Phase { PREAD, OPT }; 42 43 class PyPassManager { 44 protected: 45 PyPassManager(); 46 static PyPassManagerPtr global_instance; 47 48 public: 49 // Singletons should not be cloneable and assignable 50 PyPassManager(const PyPassManager &other) = delete; 51 void operator=(const PyPassManager &) = delete; 52 // Access the only global instance 53 static PyPassManagerPtr GetInstance(); 54 virtual ~PyPassManager() = default; 55 void Register(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, bool requires_grad, 56 bool run_only_once); 57 void Unregister(const std::string &pass_name); 58 void GenNewParameter(const PatternPtr ¶meter); 59 PassGroupPtr GetPassGroup(Phase phase); GetMatchResult()60 MatchResultPtr GetMatchResult() { return res_; } SetRenorm(bool should_renorm)61 void SetRenorm(bool should_renorm) { should_renorm_ = should_renorm; } ShouldRenorm()62 bool ShouldRenorm() { return should_renorm_; } SetReOpt(bool should_reopt)63 void SetReOpt(bool should_reopt) { should_reopt_ = should_reopt; } ShouldReOpt()64 bool ShouldReOpt() { return should_reopt_; } SetResource(pipeline::ResourcePtr resource)65 void SetResource(pipeline::ResourcePtr resource) { resource_ = resource; } GetResource()66 pipeline::ResourcePtr GetResource() { return resource_; } 67 void ClearRes(); ClearPipelineRes()68 void ClearPipelineRes() { 69 resource_ = nullptr; 70 Pattern::reset_gid(); 71 } 72 73 private: 74 bool should_renorm_ = true; 75 bool should_reopt_ = true; 76 MatchResultPtr res_; 77 pipeline::ResourcePtr resource_; 78 static std::unordered_map<Phase, PassGroupPtr> phase_to_group_; 79 }; 80 } // namespace python_pass 81 } // namespace opt 82 } // namespace mindspore 83 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PY_PASS_MANAGER_H_ 84