1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPTIMIZER_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPTIMIZER_H_ 19 20 #include <algorithm> 21 #include <functional> 22 #include <iterator> 23 #include <memory> 24 #include <string> 25 #include <vector> 26 #include <map> 27 #include <utility> 28 #include <initializer_list> 29 30 #include "debug/draw.h" 31 #include "debug/anf_ir_dump.h" 32 #include "debug/anf_ir_utils.h" 33 #include "debug/trace.h" 34 #include "frontend/optimizer/opt.h" 35 #include "pipeline/jit/resource.h" 36 #include "pipeline/jit/action.h" 37 #include "utils/ms_context.h" 38 39 namespace mindspore { 40 namespace opt { 41 using OptimizeGraphFunc = std::function<bool(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer)>; 42 43 class OptPassConfig { 44 public: OptPassConfig(const OptimizeGraphFunc & func)45 explicit OptPassConfig(const OptimizeGraphFunc &func) : func_(func) {} 46 explicit OptPassConfig(const std::vector<SubstitutionPtr> &list, bool is_once = false, bool global_sensitive = false) list_(list)47 : list_(list), is_once_(is_once), global_sensitive_(global_sensitive) {} 48 OptPassConfig(const std::initializer_list<SubstitutionPtr> &list, bool is_once = false, bool global_sensitive = false) list_(list)49 : list_(list), is_once_(is_once), global_sensitive_(global_sensitive) {} 50 ~OptPassConfig() = default; 51 list()52 const std::vector<SubstitutionPtr> &list() const { return list_; } func()53 const OptimizeGraphFunc &func() const { return func_; } 54 Renormalize()55 static OptPassConfig Renormalize() { return OptPassConfig(); } is_renormalize()56 const bool is_renormalize() const { return is_renormalize_; } 57 is_once()58 const bool is_once() const { return is_once_; } 59 global_sensitive()60 const bool global_sensitive() const { return global_sensitive_; } 61 62 private: OptPassConfig()63 OptPassConfig() : is_renormalize_(true) {} 64 65 OptimizeGraphFunc func_; 66 std::vector<SubstitutionPtr> list_; 67 bool is_renormalize_{false}; 68 bool is_once_{false}; 69 bool global_sensitive_{false}; 70 }; 71 72 class OptPass { 73 public: OptPass(const OptimizeGraphFunc & func)74 explicit OptPass(const OptimizeGraphFunc &func) : pass_func_(func) {} 75 ~OptPass() = default; 76 operator()77 bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const { 78 return pass_func_(func_graph, optimizer); 79 } 80 Renormalize()81 static OptPass Renormalize() { return OptPass(); } is_renormalize()82 const bool is_renormalize() const { return is_renormalize_; } 83 84 private: OptPass()85 OptPass() : is_renormalize_(true) {} 86 87 OptimizeGraphFunc pass_func_; 88 bool is_renormalize_{false}; 89 }; 90 using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>; 91 92 class Optimizer : public std::enable_shared_from_this<Optimizer> { 93 public: 94 Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr, bool traverse_nodes_first = true) name_(name)95 : name_(name), 96 resource_(resource_ptr), 97 run_only_once_(false), 98 is_watch_renormalize_(false), 99 is_enable_(true), 100 is_untyped_generated_(false), 101 traverse_nodes_first_(traverse_nodes_first) {} 102 virtual ~Optimizer() = default; 103 Init(const OptPassGroupMap & passes,bool run_only_once)104 void Init(const OptPassGroupMap &passes, bool run_only_once) { 105 run_only_once_ = run_only_once; 106 is_watch_renormalize_ = false; 107 is_untyped_generated_ = false; 108 is_on_debug_ = IS_OUTPUT_ON(mindspore::DEBUG); 109 110 for (auto &iter : passes) { 111 const std::string &name = iter.first; 112 pass_names_.push_back(name); 113 114 const OptPassConfig &config = iter.second; 115 if (config.is_renormalize()) { 116 passes_.push_back(OptPass::Renormalize()); 117 continue; 118 } 119 120 if (config.list().size() > 0) { 121 OptimizeGraphFunc func = SubstitutionList(config.list(), config.is_once(), config.global_sensitive()); 122 passes_.push_back(OptPass(func)); 123 continue; 124 } 125 126 passes_.push_back(OptPass(config.func())); 127 } 128 129 if (passes_.size() == 1) { 130 run_only_once_ = true; 131 } 132 } 133 134 static std::shared_ptr<Optimizer> MakeOptimizer(const std::string &name, const pipeline::ResourceBasePtr resource_ptr, 135 const OptPassGroupMap &passes, bool run_only_once = false, 136 bool watch_renormalize = false, bool traverse_nodes_first = true) { 137 OptimizerPtr optimizer = std::make_shared<Optimizer>(name, resource_ptr, traverse_nodes_first); 138 optimizer->Init(passes, run_only_once); 139 if (watch_renormalize) { 140 optimizer->enable_watch_renormalize(); 141 } 142 return optimizer; 143 } 144 145 FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) { 146 if (!is_enable_) { 147 return func_graph; 148 } 149 // Optimizer step counter; 150 int64_t counter = 1; 151 bool changes = true; 152 // If no changes since last renormalization, then no need to do the renormalization again. 153 // Set the initial value to true, so the renormalization can be executed once if it's the 154 // only pass. 155 bool changes_since_last_renorm = true; 156 157 while (changes) { 158 changes = false; 159 auto run_runc = [&counter, &func_graph, &changes, &changes_since_last_renorm, use_profile, this]() { 160 for (size_t i = 0; i < passes_.size(); ++i) { 161 const OptPass &opt = passes_[i]; 162 CurPass_ = {counter, pass_names_[i]}; 163 auto opt_func = [&func_graph, &changes, &opt, &changes_since_last_renorm, this]() { 164 if (opt.is_renormalize()) { 165 if (!changes_since_last_renorm) { 166 return; 167 } 168 auto resource_ptr = std::dynamic_pointer_cast<pipeline::Resource>(resource_); 169 if (resource_ptr != nullptr) { 170 // StepParallel may replace the AbstractValue of the parameters of func_graph, 171 // So generate the args_spec from parameters. 172 abstract::AbstractBasePtrList maybe_new_args_spec; 173 if (is_watch_renormalize_) { 174 if (is_untyped_generated_) { 175 std::transform(func_graph->parameters().begin(), func_graph->parameters().end(), 176 std::back_inserter(maybe_new_args_spec), 177 [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); }); 178 func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec); 179 clear_is_untyped_generated(); 180 } else { 181 MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because is_untyped_generated_ is False."; 182 } 183 } else { 184 std::transform(func_graph->parameters().begin(), func_graph->parameters().end(), 185 std::back_inserter(maybe_new_args_spec), 186 [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); }); 187 func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec); 188 } 189 } 190 changes_since_last_renorm = false; 191 } else if (opt(func_graph, shared_from_this())) { 192 changes = true; 193 changes_since_last_renorm = true; 194 } 195 }; 196 use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func(); 197 #ifdef ENABLE_DUMP_IR 198 static const auto enable_dump_pass_ir = (common::GetEnv("ENV_DUMP_PASS_IR") == "1"); 199 if (enable_dump_pass_ir && MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { 200 auto fg_name = 201 "opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i]; 202 MS_LOG(DEBUG) << "The opt " << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end."; 203 DumpIR(fg_name + ".ir", func_graph); 204 if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { 205 func_graph->DumpFuncGraph(fg_name); 206 ExportIR(fg_name + ".dat", func_graph); 207 } 208 MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph."; 209 } 210 #endif 211 } 212 }; 213 use_profile ? (WITH(MsProfile::GetProfile()->Lap(counter)) run_runc) : run_runc(); 214 counter++; 215 216 if (run_only_once_) { 217 break; 218 } 219 } 220 return func_graph; 221 } 222 resource()223 pipeline::ResourceBasePtr resource() const { return resource_; } manager()224 FuncGraphManagerPtr manager() const { 225 if (resource_ != nullptr) { 226 return resource_->manager(); 227 } 228 MS_LOG(EXCEPTION) << "No ResourceBase exists."; 229 } 230 name()231 const std::string name() const { return name_; } 232 set_is_untyped_generated()233 void set_is_untyped_generated() { is_untyped_generated_ = true; } clear_is_untyped_generated()234 void clear_is_untyped_generated() { is_untyped_generated_ = false; } 235 enable_watch_renormalize()236 void enable_watch_renormalize() { is_watch_renormalize_ = true; } disable_watch_renormalize()237 void disable_watch_renormalize() { is_watch_renormalize_ = false; } is_watch_renormalize()238 bool is_watch_renormalize() { return is_watch_renormalize_; } set_enable(bool enable)239 void set_enable(bool enable) { is_enable_ = enable; } 240 traverse_nodes_first()241 bool traverse_nodes_first() { return traverse_nodes_first_; } 242 243 struct { 244 int64_t counter; 245 std::string name; 246 } CurPass_; 247 248 bool is_on_debug_{false}; 249 250 private: 251 const std::string name_; 252 pipeline::ResourceBasePtr resource_; 253 std::vector<OptPass> passes_; 254 std::vector<std::string> pass_names_; 255 bool run_only_once_; 256 bool is_watch_renormalize_; 257 bool is_enable_; 258 bool is_untyped_generated_; 259 bool traverse_nodes_first_; 260 }; 261 } // namespace opt 262 } // namespace mindspore 263 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPTIMIZER_H_ 264