• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/registry/pass_registry.h"
18 #include <map>
19 #include <mutex>
20 #include <string>
21 #include <vector>
22 #include "src/common/log_adapter.h"
23 #include "nnacl/op_base.h"
24 
25 namespace mindspore {
26 namespace registry {
27 namespace {
28 constexpr size_t kPassNumLimit = 10000;
29 std::map<std::string, PassBasePtr> outer_pass_storage;
30 std::map<registry::PassPosition, std::vector<std::string>> external_assigned_passes;
31 std::mutex pass_mutex;
RegPass(const std::string & pass_name,const PassBasePtr & pass)32 void RegPass(const std::string &pass_name, const PassBasePtr &pass) {
33   if (pass == nullptr) {
34     MS_LOG(ERROR) << "pass is nullptr.";
35     return;
36   }
37   std::unique_lock<std::mutex> lock(pass_mutex);
38   if (outer_pass_storage.size() == kPassNumLimit) {
39     MS_LOG(WARNING) << "pass's number is up to the limitation. The pass will not be registered.";
40     return;
41   }
42   outer_pass_storage[pass_name] = pass;
43 }
44 }  // namespace
45 
PassRegistry(const std::vector<char> & pass_name,const PassBasePtr & pass)46 PassRegistry::PassRegistry(const std::vector<char> &pass_name, const PassBasePtr &pass) {
47   RegPass(CharToString(pass_name), pass);
48 }
49 
PassRegistry(PassPosition position,const std::vector<std::vector<char>> & names)50 PassRegistry::PassRegistry(PassPosition position, const std::vector<std::vector<char>> &names) {
51   if (position < POSITION_BEGIN || position > POSITION_END) {
52     MS_LOG(ERROR) << "ILLEGAL position: position must be POSITION_BEGIN or POSITION_END.";
53     return;
54   }
55   std::unique_lock<std::mutex> lock(pass_mutex);
56   external_assigned_passes[position] = VectorCharToString(names);
57 }
58 
GetOuterScheduleTaskInner(PassPosition position)59 std::vector<std::vector<char>> PassRegistry::GetOuterScheduleTaskInner(PassPosition position) {
60   MS_CHECK_TRUE_MSG(position == POSITION_END || position == POSITION_BEGIN, {},
61                     "position must be POSITION_END or POSITION_BEGIN.");
62   return VectorStringToChar(external_assigned_passes[position]);
63 }
64 
GetPassFromStoreRoom(const std::vector<char> & pass_name_char)65 PassBasePtr PassRegistry::GetPassFromStoreRoom(const std::vector<char> &pass_name_char) {
66   std::string pass_name = CharToString(pass_name_char);
67   return outer_pass_storage.find(pass_name) == outer_pass_storage.end() ? nullptr : outer_pass_storage[pass_name];
68 }
69 }  // namespace registry
70 }  // namespace mindspore
71