• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPTIMIZER_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPTIMIZER_H_
19 
20 #include <algorithm>
21 #include <functional>
22 #include <iterator>
23 #include <memory>
24 #include <string>
25 #include <vector>
26 #include <map>
27 #include <utility>
28 #include <initializer_list>
29 
30 #include "debug/draw.h"
31 #include "debug/anf_ir_dump.h"
32 #include "debug/anf_ir_utils.h"
33 #include "debug/trace.h"
34 #include "frontend/optimizer/opt.h"
35 #include "pipeline/jit/resource.h"
36 #include "pipeline/jit/action.h"
37 #include "utils/ms_context.h"
38 
39 namespace mindspore {
40 namespace opt {
41 using OptimizeGraphFunc = std::function<bool(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer)>;
42 
43 class OptPassConfig {
44  public:
OptPassConfig(const OptimizeGraphFunc & func)45   explicit OptPassConfig(const OptimizeGraphFunc &func) : func_(func) {}
46   explicit OptPassConfig(const std::vector<SubstitutionPtr> &list, bool is_once = false, bool global_sensitive = false)
list_(list)47       : list_(list), is_once_(is_once), global_sensitive_(global_sensitive) {}
48   OptPassConfig(const std::initializer_list<SubstitutionPtr> &list, bool is_once = false, bool global_sensitive = false)
list_(list)49       : list_(list), is_once_(is_once), global_sensitive_(global_sensitive) {}
50   ~OptPassConfig() = default;
51 
list()52   const std::vector<SubstitutionPtr> &list() const { return list_; }
func()53   const OptimizeGraphFunc &func() const { return func_; }
54 
Renormalize()55   static OptPassConfig Renormalize() { return OptPassConfig(); }
is_renormalize()56   const bool is_renormalize() const { return is_renormalize_; }
57 
is_once()58   const bool is_once() const { return is_once_; }
59 
global_sensitive()60   const bool global_sensitive() const { return global_sensitive_; }
61 
62  private:
OptPassConfig()63   OptPassConfig() : is_renormalize_(true) {}
64 
65   OptimizeGraphFunc func_;
66   std::vector<SubstitutionPtr> list_;
67   bool is_renormalize_{false};
68   bool is_once_{false};
69   bool global_sensitive_{false};
70 };
71 
72 class OptPass {
73  public:
OptPass(const OptimizeGraphFunc & func)74   explicit OptPass(const OptimizeGraphFunc &func) : pass_func_(func) {}
75   ~OptPass() = default;
76 
operator()77   bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const {
78     return pass_func_(func_graph, optimizer);
79   }
80 
Renormalize()81   static OptPass Renormalize() { return OptPass(); }
is_renormalize()82   const bool is_renormalize() const { return is_renormalize_; }
83 
84  private:
OptPass()85   OptPass() : is_renormalize_(true) {}
86 
87   OptimizeGraphFunc pass_func_;
88   bool is_renormalize_{false};
89 };
90 using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>;
91 
// Drives a named group of optimization passes over a FuncGraph, repeating the
// whole group until no pass reports a change (fixed point), unless configured
// to run only once. Derives from enable_shared_from_this so passes can be
// handed a shared OptimizerPtr to `this` during step().
class Optimizer : public std::enable_shared_from_this<Optimizer> {
 public:
  Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr, bool traverse_nodes_first = true)
      : name_(name),
        resource_(resource_ptr),
        run_only_once_(false),
        is_watch_renormalize_(false),
        is_enable_(true),
        is_untyped_generated_(false),
        traverse_nodes_first_(traverse_nodes_first) {}
  virtual ~Optimizer() = default;

  // Compiles the (name, config) pairs into executable OptPass objects.
  // A group containing a single pass is forced to run only once.
  void Init(const OptPassGroupMap &passes, bool run_only_once) {
    run_only_once_ = run_only_once;
    is_watch_renormalize_ = false;
    is_untyped_generated_ = false;
    is_on_debug_ = IS_OUTPUT_ON(mindspore::DEBUG);

    for (auto &iter : passes) {
      const std::string &name = iter.first;
      pass_names_.push_back(name);

      const OptPassConfig &config = iter.second;
      if (config.is_renormalize()) {
        passes_.push_back(OptPass::Renormalize());
        continue;
      }

      if (config.list().size() > 0) {
        // A substitution list is wrapped into a single graph-transform functor.
        OptimizeGraphFunc func = SubstitutionList(config.list(), config.is_once(), config.global_sensitive());
        passes_.push_back(OptPass(func));
        continue;
      }

      passes_.push_back(OptPass(config.func()));
    }

    if (passes_.size() == 1) {
      run_only_once_ = true;
    }
  }

  // Convenience factory: construct + Init, optionally enabling
  // watch-renormalize mode (see enable_watch_renormalize below).
  static std::shared_ptr<Optimizer> MakeOptimizer(const std::string &name, const pipeline::ResourceBasePtr resource_ptr,
                                                  const OptPassGroupMap &passes, bool run_only_once = false,
                                                  bool watch_renormalize = false, bool traverse_nodes_first = true) {
    OptimizerPtr optimizer = std::make_shared<Optimizer>(name, resource_ptr, traverse_nodes_first);
    optimizer->Init(passes, run_only_once);
    if (watch_renormalize) {
      optimizer->enable_watch_renormalize();
    }
    return optimizer;
  }

  // Runs all passes over func_graph, looping until a full round produces no
  // change (or once if run_only_once_). Returns the (possibly replaced)
  // func graph; renormalization may produce a new FuncGraphPtr.
  FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) {
    if (!is_enable_) {
      return func_graph;
    }
    // Optimizer step counter;
    int64_t counter = 1;
    bool changes = true;
    // If no changes since last renormalization, then no need to do the renormalization again.
    // Set the initial value to true, so the renormalization can be executed once if it's the
    // only pass.
    bool changes_since_last_renorm = true;

    while (changes) {
      changes = false;
      // One full round over all passes; wrapped in a lambda so it can be timed
      // by the profiler when use_profile is set.
      auto run_runc = [&counter, &func_graph, &changes, &changes_since_last_renorm, use_profile, this]() {
        for (size_t i = 0; i < passes_.size(); ++i) {
          const OptPass &opt = passes_[i];
          CurPass_ = {counter, pass_names_[i]};
          auto opt_func = [&func_graph, &changes, &opt, &changes_since_last_renorm, this]() {
            if (opt.is_renormalize()) {
              // Skip renormalization when nothing changed since the last one.
              if (!changes_since_last_renorm) {
                return;
              }
              auto resource_ptr = std::dynamic_pointer_cast<pipeline::Resource>(resource_);
              if (resource_ptr != nullptr) {
                // StepParallel may replace the AbstractValue of the parameters of func_graph,
                // So generate the args_spec from parameters.
                abstract::AbstractBasePtrList maybe_new_args_spec;
                if (is_watch_renormalize_) {
                  // Watch mode: only renormalize when untyped nodes were generated.
                  if (is_untyped_generated_) {
                    std::transform(func_graph->parameters().begin(), func_graph->parameters().end(),
                                   std::back_inserter(maybe_new_args_spec),
                                   [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); });
                    func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec);
                    clear_is_untyped_generated();
                  } else {
                    MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because is_untyped_generated_ is False.";
                  }
                } else {
                  std::transform(func_graph->parameters().begin(), func_graph->parameters().end(),
                                 std::back_inserter(maybe_new_args_spec),
                                 [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); });
                  func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec);
                }
              }
              changes_since_last_renorm = false;
            } else if (opt(func_graph, shared_from_this())) {
              // A regular pass reported a change: keep iterating and allow the
              // next renormalization pass to run.
              changes = true;
              changes_since_last_renorm = true;
            }
          };
          use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func();
#ifdef ENABLE_DUMP_IR
          // Optional per-pass IR dump, gated by env var and context flag.
          static const auto enable_dump_pass_ir = (common::GetEnv("ENV_DUMP_PASS_IR") == "1");
          if (enable_dump_pass_ir && MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
            auto fg_name =
              "opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i];
            MS_LOG(DEBUG) << "The opt " << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end.";
            DumpIR(fg_name + ".ir", func_graph);
            if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
              func_graph->DumpFuncGraph(fg_name);
              ExportIR(fg_name + ".dat", func_graph);
            }
            MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph.";
          }
#endif
        }
      };
      use_profile ? (WITH(MsProfile::GetProfile()->Lap(counter)) run_runc) : run_runc();
      counter++;

      if (run_only_once_) {
        break;
      }
    }
    return func_graph;
  }

  pipeline::ResourceBasePtr resource() const { return resource_; }
  // Returns the graph manager from the resource; throws if no resource is set.
  FuncGraphManagerPtr manager() const {
    if (resource_ != nullptr) {
      return resource_->manager();
    }
    MS_LOG(EXCEPTION) << "No ResourceBase exists.";
  }

  const std::string name() const { return name_; }

  // Flag set by passes when untyped nodes were generated; consumed by the
  // renormalization step in watch mode.
  void set_is_untyped_generated() { is_untyped_generated_ = true; }
  void clear_is_untyped_generated() { is_untyped_generated_ = false; }

  // In watch mode, renormalization only runs when untyped nodes were generated.
  void enable_watch_renormalize() { is_watch_renormalize_ = true; }
  void disable_watch_renormalize() { is_watch_renormalize_ = false; }
  bool is_watch_renormalize() { return is_watch_renormalize_; }
  void set_enable(bool enable) { is_enable_ = enable; }

  bool traverse_nodes_first() { return traverse_nodes_first_; }

  // Round counter and name of the pass currently executing (for diagnostics).
  struct {
    int64_t counter;
    std::string name;
  } CurPass_;

  bool is_on_debug_{false};

 private:
  const std::string name_;
  pipeline::ResourceBasePtr resource_;
  std::vector<OptPass> passes_;
  std::vector<std::string> pass_names_;
  bool run_only_once_;
  bool is_watch_renormalize_;
  bool is_enable_;
  bool is_untyped_generated_;
  bool traverse_nodes_first_;
};
261 }  // namespace opt
262 }  // namespace mindspore
263 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPTIMIZER_H_
264