• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/stream_executor/plugin_registry.h"
17 
18 #include "absl/base/const_init.h"
19 #include "absl/strings/str_cat.h"
20 #include "absl/strings/str_format.h"
21 #include "absl/synchronization/mutex.h"
22 #include "tensorflow/stream_executor/lib/error.h"
23 #include "tensorflow/stream_executor/multi_platform_manager.h"
24 
25 namespace stream_executor {
26 
27 const PluginId kNullPlugin = nullptr;
28 
29 // Returns the string representation of the specified PluginKind.
PluginKindString(PluginKind plugin_kind)30 std::string PluginKindString(PluginKind plugin_kind) {
31   switch (plugin_kind) {
32     case PluginKind::kBlas:
33       return "BLAS";
34     case PluginKind::kDnn:
35       return "DNN";
36     case PluginKind::kFft:
37       return "FFT";
38     case PluginKind::kRng:
39       return "RNG";
40     case PluginKind::kInvalid:
41     default:
42       return "kInvalid";
43   }
44 }
45 
DefaultFactories()46 PluginRegistry::DefaultFactories::DefaultFactories() :
47     blas(kNullPlugin), dnn(kNullPlugin), fft(kNullPlugin), rng(kNullPlugin) { }
48 
GetPluginRegistryMutex()49 static absl::Mutex& GetPluginRegistryMutex() {
50   static absl::Mutex mu(absl::kConstInit);
51   return mu;
52 }
53 
54 /* static */ PluginRegistry* PluginRegistry::instance_ = nullptr;
55 
PluginRegistry()56 PluginRegistry::PluginRegistry() {}
57 
Instance()58 /* static */ PluginRegistry* PluginRegistry::Instance() {
59   absl::MutexLock lock{&GetPluginRegistryMutex()};
60   if (instance_ == nullptr) {
61     instance_ = new PluginRegistry();
62   }
63   return instance_;
64 }
65 
MapPlatformKindToId(PlatformKind platform_kind,Platform::Id platform_id)66 void PluginRegistry::MapPlatformKindToId(PlatformKind platform_kind,
67                                          Platform::Id platform_id) {
68   platform_id_by_kind_[platform_kind] = platform_id;
69 }
70 
71 template <typename FACTORY_TYPE>
RegisterFactoryInternal(PluginId plugin_id,const std::string & plugin_name,FACTORY_TYPE factory,std::map<PluginId,FACTORY_TYPE> * factories)72 port::Status PluginRegistry::RegisterFactoryInternal(
73     PluginId plugin_id, const std::string& plugin_name, FACTORY_TYPE factory,
74     std::map<PluginId, FACTORY_TYPE>* factories) {
75   absl::MutexLock lock{&GetPluginRegistryMutex()};
76 
77   if (factories->find(plugin_id) != factories->end()) {
78     return port::Status(
79         port::error::ALREADY_EXISTS,
80         absl::StrFormat("Attempting to register factory for plugin %s when "
81                         "one has already been registered",
82                         plugin_name));
83   }
84 
85   (*factories)[plugin_id] = factory;
86   plugin_names_[plugin_id] = plugin_name;
87   return port::Status::OK();
88 }
89 
90 template <typename FACTORY_TYPE>
GetFactoryInternal(PluginId plugin_id,const std::map<PluginId,FACTORY_TYPE> & factories,const std::map<PluginId,FACTORY_TYPE> & generic_factories) const91 port::StatusOr<FACTORY_TYPE> PluginRegistry::GetFactoryInternal(
92     PluginId plugin_id, const std::map<PluginId, FACTORY_TYPE>& factories,
93     const std::map<PluginId, FACTORY_TYPE>& generic_factories) const {
94   auto iter = factories.find(plugin_id);
95   if (iter == factories.end()) {
96     iter = generic_factories.find(plugin_id);
97     if (iter == generic_factories.end()) {
98       return port::Status(
99           port::error::NOT_FOUND,
100           absl::StrFormat("Plugin ID %p not registered.", plugin_id));
101     }
102   }
103 
104   return iter->second;
105 }
106 
SetDefaultFactory(Platform::Id platform_id,PluginKind plugin_kind,PluginId plugin_id)107 bool PluginRegistry::SetDefaultFactory(Platform::Id platform_id,
108                                        PluginKind plugin_kind,
109                                        PluginId plugin_id) {
110   if (!HasFactory(platform_id, plugin_kind, plugin_id)) {
111     port::StatusOr<Platform*> status =
112         MultiPlatformManager::PlatformWithId(platform_id);
113     std::string platform_name = "<unregistered platform>";
114     if (status.ok()) {
115       platform_name = status.ValueOrDie()->Name();
116     }
117 
118     LOG(ERROR) << "A factory must be registered for a platform before being "
119                << "set as default! "
120                << "Platform name: " << platform_name
121                << ", PluginKind: " << PluginKindString(plugin_kind)
122                << ", PluginId: " << plugin_id;
123     return false;
124   }
125 
126   switch (plugin_kind) {
127     case PluginKind::kBlas:
128       default_factories_[platform_id].blas = plugin_id;
129       break;
130     case PluginKind::kDnn:
131       default_factories_[platform_id].dnn = plugin_id;
132       break;
133     case PluginKind::kFft:
134       default_factories_[platform_id].fft = plugin_id;
135       break;
136     case PluginKind::kRng:
137       default_factories_[platform_id].rng = plugin_id;
138       break;
139     default:
140       LOG(ERROR) << "Invalid plugin kind specified: "
141                  << static_cast<int>(plugin_kind);
142       return false;
143   }
144 
145   return true;
146 }
147 
HasFactory(const PluginFactories & factories,PluginKind plugin_kind,PluginId plugin_id) const148 bool PluginRegistry::HasFactory(const PluginFactories& factories,
149                                 PluginKind plugin_kind,
150                                 PluginId plugin_id) const {
151   switch (plugin_kind) {
152     case PluginKind::kBlas:
153       return factories.blas.find(plugin_id) != factories.blas.end();
154     case PluginKind::kDnn:
155       return factories.dnn.find(plugin_id) != factories.dnn.end();
156     case PluginKind::kFft:
157       return factories.fft.find(plugin_id) != factories.fft.end();
158     case PluginKind::kRng:
159       return factories.rng.find(plugin_id) != factories.rng.end();
160     default:
161       LOG(ERROR) << "Invalid plugin kind specified: "
162                  << PluginKindString(plugin_kind);
163       return false;
164   }
165 }
166 
HasFactory(Platform::Id platform_id,PluginKind plugin_kind,PluginId plugin_id) const167 bool PluginRegistry::HasFactory(Platform::Id platform_id,
168                                 PluginKind plugin_kind,
169                                 PluginId plugin_id) const {
170   auto iter = factories_.find(platform_id);
171   if (iter != factories_.end()) {
172     if (HasFactory(iter->second, plugin_kind, plugin_id)) {
173       return true;
174     }
175   }
176 
177   return HasFactory(generic_factories_, plugin_kind, plugin_id);
178 }
179 
180 // Explicit instantiations to support types exposed in user/public API.
181 #define EMIT_PLUGIN_SPECIALIZATIONS(FACTORY_TYPE, FACTORY_VAR, PLUGIN_STRING) \
182   template port::StatusOr<PluginRegistry::FACTORY_TYPE>                       \
183   PluginRegistry::GetFactoryInternal<PluginRegistry::FACTORY_TYPE>(           \
184       PluginId plugin_id,                                                     \
185       const std::map<PluginId, PluginRegistry::FACTORY_TYPE>& factories,      \
186       const std::map<PluginId, PluginRegistry::FACTORY_TYPE>&                 \
187           generic_factories) const;                                           \
188                                                                               \
189   template port::Status                                                       \
190   PluginRegistry::RegisterFactoryInternal<PluginRegistry::FACTORY_TYPE>(      \
191       PluginId plugin_id, const std::string& plugin_name,                     \
192       PluginRegistry::FACTORY_TYPE factory,                                   \
193       std::map<PluginId, PluginRegistry::FACTORY_TYPE>* factories);           \
194                                                                               \
195   template <>                                                                 \
196   port::Status PluginRegistry::RegisterFactory<PluginRegistry::FACTORY_TYPE>( \
197       Platform::Id platform_id, PluginId plugin_id, const std::string& name,  \
198       PluginRegistry::FACTORY_TYPE factory) {                                 \
199     return RegisterFactoryInternal(plugin_id, name, factory,                  \
200                                    &factories_[platform_id].FACTORY_VAR);     \
201   }                                                                           \
202                                                                               \
203   template <>                                                                 \
204   port::Status PluginRegistry::RegisterFactoryForAllPlatforms<                \
205       PluginRegistry::FACTORY_TYPE>(PluginId plugin_id,                       \
206                                     const std::string& name,                  \
207                                     PluginRegistry::FACTORY_TYPE factory) {   \
208     return RegisterFactoryInternal(plugin_id, name, factory,                  \
209                                    &generic_factories_.FACTORY_VAR);          \
210   }                                                                           \
211                                                                               \
212   template <>                                                                 \
213   port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory(    \
214       Platform::Id platform_id, PluginId plugin_id) {                         \
215     if (plugin_id == PluginConfig::kDefault) {                                \
216       plugin_id = default_factories_[platform_id].FACTORY_VAR;                \
217                                                                               \
218       if (plugin_id == kNullPlugin) {                                         \
219         return port::Status(                                                  \
220             port::error::FAILED_PRECONDITION,                                 \
221             "No suitable " PLUGIN_STRING                                      \
222             " plugin registered. Have you linked in a " PLUGIN_STRING         \
223             "-providing plugin?");                                            \
224       } else {                                                                \
225         VLOG(2) << "Selecting default " PLUGIN_STRING " plugin, "             \
226                 << plugin_names_[plugin_id];                                  \
227       }                                                                       \
228     }                                                                         \
229     return GetFactoryInternal(plugin_id, factories_[platform_id].FACTORY_VAR, \
230                               generic_factories_.FACTORY_VAR);                \
231   }                                                                           \
232                                                                               \
233   /* TODO(b/22689637): Also temporary WRT MultiPlatformManager */             \
234   template <>                                                                 \
235   port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory(    \
236       PlatformKind platform_kind, PluginId plugin_id) {                       \
237     auto iter = platform_id_by_kind_.find(platform_kind);                     \
238     if (iter == platform_id_by_kind_.end()) {                                 \
239       return port::Status(port::error::FAILED_PRECONDITION,                   \
240                           absl::StrFormat("Platform kind %d not registered.", \
241                                           static_cast<int>(platform_kind)));  \
242     }                                                                         \
243     return GetFactory<PluginRegistry::FACTORY_TYPE>(iter->second, plugin_id); \
244   }
245 
246 EMIT_PLUGIN_SPECIALIZATIONS(BlasFactory, blas, "BLAS");
247 EMIT_PLUGIN_SPECIALIZATIONS(DnnFactory, dnn, "DNN");
248 EMIT_PLUGIN_SPECIALIZATIONS(FftFactory, fft, "FFT");
249 EMIT_PLUGIN_SPECIALIZATIONS(RngFactory, rng, "RNG");
250 
251 }  // namespace stream_executor
252