1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "include/registry/pass_registry.h"
#include <map>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include "src/common/log_adapter.h"
#include "nnacl/op_base.h"
24
25 namespace mindspore {
26 namespace registry {
27 namespace {
28 constexpr size_t kPassNumLimit = 10000;
29 std::map<std::string, PassBasePtr> outer_pass_storage;
30 std::map<registry::PassPosition, std::vector<std::string>> external_assigned_passes;
31 std::mutex pass_mutex;
RegPass(const std::string & pass_name,const PassBasePtr & pass)32 void RegPass(const std::string &pass_name, const PassBasePtr &pass) {
33 if (pass == nullptr) {
34 MS_LOG(ERROR) << "pass is nullptr.";
35 return;
36 }
37 std::unique_lock<std::mutex> lock(pass_mutex);
38 if (outer_pass_storage.size() == kPassNumLimit) {
39 MS_LOG(WARNING) << "pass's number is up to the limitation. The pass will not be registered.";
40 return;
41 }
42 outer_pass_storage[pass_name] = pass;
43 }
44 } // namespace
45
PassRegistry(const std::vector<char> & pass_name,const PassBasePtr & pass)46 PassRegistry::PassRegistry(const std::vector<char> &pass_name, const PassBasePtr &pass) {
47 RegPass(CharToString(pass_name), pass);
48 }
49
PassRegistry(PassPosition position,const std::vector<std::vector<char>> & names)50 PassRegistry::PassRegistry(PassPosition position, const std::vector<std::vector<char>> &names) {
51 if (position < POSITION_BEGIN || position > POSITION_END) {
52 MS_LOG(ERROR) << "ILLEGAL position: position must be POSITION_BEGIN or POSITION_END.";
53 return;
54 }
55 std::unique_lock<std::mutex> lock(pass_mutex);
56 external_assigned_passes[position] = VectorCharToString(names);
57 }
58
GetOuterScheduleTaskInner(PassPosition position)59 std::vector<std::vector<char>> PassRegistry::GetOuterScheduleTaskInner(PassPosition position) {
60 MS_CHECK_TRUE_MSG(position == POSITION_END || position == POSITION_BEGIN, {},
61 "position must be POSITION_END or POSITION_BEGIN.");
62 return VectorStringToChar(external_assigned_passes[position]);
63 }
64
GetPassFromStoreRoom(const std::vector<char> & pass_name_char)65 PassBasePtr PassRegistry::GetPassFromStoreRoom(const std::vector<char> &pass_name_char) {
66 std::string pass_name = CharToString(pass_name_char);
67 return outer_pass_storage.find(pass_name) == outer_pass_storage.end() ? nullptr : outer_pass_storage[pass_name];
68 }
69 } // namespace registry
70 } // namespace mindspore
71