• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/converter/optimizer_manager.h"
19 #include <map>
20 #include <set>
21 #include <string>
22 #include <vector>
23 #include "include/backend/optimizer/pass.h"
24 #include "src/common/log_util.h"
25 #include "tools/converter/parser/parser_utils.h"
26 #include "include/registry/pass_base.h"
27 #include "nnacl/op_base.h"
28 
29 namespace mindspore {
30 namespace lite {
31 std::map<std::string, opt::PassPtr> PassStorage::pass_storage_;
32 std::set<std::string> PassStorage::inaccessible_for_outer_;
RunOptimizerPass(const FuncGraphPtr & func_graph,const std::vector<std::string> & pass_names)33 bool RunOptimizerPass(const FuncGraphPtr &func_graph, const std::vector<std::string> &pass_names) {
34   if (func_graph == nullptr) {
35     MS_LOG(ERROR) << "func graph is nullptr.";
36     return false;
37   }
38   auto manager = func_graph->manager();
39   if (manager == nullptr) {
40     manager = Manage(func_graph, true);
41     MS_CHECK_TRUE_RET(manager != nullptr, false);
42     std::set<FuncGraphPtr> all_func_graphs;
43     GetAllFuncGraph(func_graph, &all_func_graphs);
44     for (auto &graph : all_func_graphs) {
45       manager->AddFuncGraph(graph);
46     }
47   }
48   for (auto &pass_name : pass_names) {
49     auto pass_outer = registry::PassRegistry::GetPassFromStoreRoom(pass_name);
50     if (pass_outer != nullptr) {
51       auto api_graph = api::MakeShared<api::FuncGraph>(func_graph);
52       MS_CHECK_TRUE_RET(api_graph != nullptr, false);
53       if (!pass_outer->Execute(api_graph)) {
54         MS_LOG(INFO) << "Execute this pass without modifying the graph, pass name: " << pass_name;
55       }
56       continue;
57     }
58     auto pass_builtin = PassStorage::GetPassFromStorage(pass_name);
59     if (pass_builtin == nullptr) {
60       MS_LOG(ERROR) << "exited pass cannot be obtained, pass name is " << pass_name;
61       return false;
62     }
63     if (!pass_builtin->Run(func_graph)) {
64       MS_LOG(INFO) << "Execute this pass without modifying the graph, pass name: " << pass_name;
65     }
66   }
67   return true;
68 }
69 
RunExternalPass(const FuncGraphPtr & func_graph,registry::PassPosition position)70 bool RunExternalPass(const FuncGraphPtr &func_graph, registry::PassPosition position) {
71   if (func_graph == nullptr) {
72     MS_LOG(ERROR) << "func graph is nullptr.";
73     return false;
74   }
75   auto schedule_task = registry::PassRegistry::GetOuterScheduleTask(position);
76   for (const auto &pass_name : schedule_task) {
77     if (!PassStorage::IsAccessibleForOuter(pass_name)) {
78       MS_LOG(ERROR) << pass_name << " is an inaccessible pass for outer calling.";
79       return false;
80     }
81   }
82   if (!RunOptimizerPass(func_graph, schedule_task)) {
83     MS_LOG(WARNING) << "run external scheduled task failed.";
84     return false;
85   }
86   return true;
87 }
88 }  // namespace lite
89 }  // namespace mindspore
90