• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PY_PASS_MANAGER_H_
17 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PY_PASS_MANAGER_H_
18 
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
23 
24 #include "ir/anf.h"
25 #include "ir/func_graph.h"
26 #include "pybind_api/ir/primitive_py.h"
27 #include "ir/graph_utils.h"
28 #include "utils/ms_utils.h"
29 
30 #include "pipeline/jit/resource.h"
31 #include "frontend/optimizer/pattern.h"
32 #include "frontend/optimizer/py_pass.h"
33 #include "frontend/optimizer/pass_group.h"
34 
35 namespace mindspore {
36 namespace opt {
37 namespace python_pass {
38 class PyPassManager;
39 using PyPassManagerPtr = std::shared_ptr<PyPassManager>;
40 
41 enum Phase { PREAD, OPT };
42 
43 class PyPassManager {
44  protected:
45   PyPassManager();
46   static PyPassManagerPtr global_instance;
47 
48  public:
49   // Singletons should not be cloneable and assignable
50   PyPassManager(const PyPassManager &other) = delete;
51   void operator=(const PyPassManager &) = delete;
52   // Access the only global instance
53   static PyPassManagerPtr GetInstance();
54   virtual ~PyPassManager() = default;
55   void Register(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, bool requires_grad,
56                 bool run_only_once);
57   void Unregister(const std::string &pass_name);
58   void GenNewParameter(const PatternPtr &parameter);
59   PassGroupPtr GetPassGroup(Phase phase);
GetMatchResult()60   MatchResultPtr GetMatchResult() { return res_; }
SetRenorm(bool should_renorm)61   void SetRenorm(bool should_renorm) { should_renorm_ = should_renorm; }
ShouldRenorm()62   bool ShouldRenorm() { return should_renorm_; }
SetReOpt(bool should_reopt)63   void SetReOpt(bool should_reopt) { should_reopt_ = should_reopt; }
ShouldReOpt()64   bool ShouldReOpt() { return should_reopt_; }
SetResource(pipeline::ResourcePtr resource)65   void SetResource(pipeline::ResourcePtr resource) { resource_ = resource; }
GetResource()66   pipeline::ResourcePtr GetResource() { return resource_; }
67   void ClearRes();
ClearPipelineRes()68   void ClearPipelineRes() {
69     resource_ = nullptr;
70     Pattern::reset_gid();
71   }
72 
73  private:
74   bool should_renorm_ = true;
75   bool should_reopt_ = true;
76   MatchResultPtr res_;
77   pipeline::ResourcePtr resource_;
78   static std::unordered_map<Phase, PassGroupPtr> phase_to_group_;
79 };
80 }  // namespace python_pass
81 }  // namespace opt
82 }  // namespace mindspore
83 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PY_PASS_MANAGER_H_
84