• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/stream_executor/plugin_registry.h"
17 
18 #include "absl/base/const_init.h"
19 #include "absl/strings/str_cat.h"
20 #include "absl/strings/str_format.h"
21 #include "absl/synchronization/mutex.h"
22 #include "tensorflow/stream_executor/lib/error.h"
23 #include "tensorflow/stream_executor/multi_platform_manager.h"
24 
25 namespace stream_executor {
26 
27 const PluginId kNullPlugin = nullptr;
28 
29 // Returns the string representation of the specified PluginKind.
PluginKindString(PluginKind plugin_kind)30 string PluginKindString(PluginKind plugin_kind) {
31   switch (plugin_kind) {
32     case PluginKind::kBlas:
33       return "BLAS";
34     case PluginKind::kDnn:
35       return "DNN";
36     case PluginKind::kFft:
37       return "FFT";
38     case PluginKind::kRng:
39       return "RNG";
40     case PluginKind::kInvalid:
41     default:
42       return "kInvalid";
43   }
44 }
45 
DefaultFactories()46 PluginRegistry::DefaultFactories::DefaultFactories() :
47     blas(kNullPlugin), dnn(kNullPlugin), fft(kNullPlugin), rng(kNullPlugin) { }
48 
GetPluginRegistryMutex()49 static absl::Mutex& GetPluginRegistryMutex() {
50   static absl::Mutex mu(absl::kConstInit);
51   return mu;
52 }
53 
54 /* static */ PluginRegistry* PluginRegistry::instance_ = nullptr;
55 
PluginRegistry()56 PluginRegistry::PluginRegistry() {}
57 
Instance()58 /* static */ PluginRegistry* PluginRegistry::Instance() {
59   absl::MutexLock lock{&GetPluginRegistryMutex()};
60   if (instance_ == nullptr) {
61     instance_ = new PluginRegistry();
62   }
63   return instance_;
64 }
65 
MapPlatformKindToId(PlatformKind platform_kind,Platform::Id platform_id)66 void PluginRegistry::MapPlatformKindToId(PlatformKind platform_kind,
67                                          Platform::Id platform_id) {
68   platform_id_by_kind_[platform_kind] = platform_id;
69 }
70 
71 template <typename FACTORY_TYPE>
RegisterFactoryInternal(PluginId plugin_id,const string & plugin_name,FACTORY_TYPE factory,std::map<PluginId,FACTORY_TYPE> * factories)72 port::Status PluginRegistry::RegisterFactoryInternal(
73     PluginId plugin_id, const string& plugin_name, FACTORY_TYPE factory,
74     std::map<PluginId, FACTORY_TYPE>* factories) {
75   absl::MutexLock lock{&GetPluginRegistryMutex()};
76 
77   if (factories->find(plugin_id) != factories->end()) {
78     return port::Status(
79         port::error::ALREADY_EXISTS,
80         absl::StrFormat("Attempting to register factory for plugin %s when "
81                         "one has already been registered",
82                         plugin_name));
83   }
84 
85   (*factories)[plugin_id] = factory;
86   plugin_names_[plugin_id] = plugin_name;
87   return port::Status::OK();
88 }
89 
90 template <typename FACTORY_TYPE>
GetFactoryInternal(PluginId plugin_id,const std::map<PluginId,FACTORY_TYPE> & factories,const std::map<PluginId,FACTORY_TYPE> & generic_factories) const91 port::StatusOr<FACTORY_TYPE> PluginRegistry::GetFactoryInternal(
92     PluginId plugin_id, const std::map<PluginId, FACTORY_TYPE>& factories,
93     const std::map<PluginId, FACTORY_TYPE>& generic_factories) const {
94   auto iter = factories.find(plugin_id);
95   if (iter == factories.end()) {
96     iter = generic_factories.find(plugin_id);
97     if (iter == generic_factories.end()) {
98       return port::Status(
99           port::error::NOT_FOUND,
100           absl::StrFormat("Plugin ID %p not registered.", plugin_id));
101     }
102   }
103 
104   return iter->second;
105 }
106 
SetDefaultFactory(Platform::Id platform_id,PluginKind plugin_kind,PluginId plugin_id)107 bool PluginRegistry::SetDefaultFactory(Platform::Id platform_id,
108                                        PluginKind plugin_kind,
109                                        PluginId plugin_id) {
110   if (!HasFactory(platform_id, plugin_kind, plugin_id)) {
111     port::StatusOr<Platform*> status =
112         MultiPlatformManager::PlatformWithId(platform_id);
113     string platform_name = "<unregistered platform>";
114     if (status.ok()) {
115       platform_name = status.ValueOrDie()->Name();
116     }
117 
118     LOG(ERROR) << "A factory must be registered for a platform before being "
119                << "set as default! "
120                << "Platform name: " << platform_name
121                << ", PluginKind: " << PluginKindString(plugin_kind)
122                << ", PluginId: " << plugin_id;
123     return false;
124   }
125 
126   switch (plugin_kind) {
127     case PluginKind::kBlas:
128       default_factories_[platform_id].blas = plugin_id;
129       break;
130     case PluginKind::kDnn:
131       default_factories_[platform_id].dnn = plugin_id;
132       break;
133     case PluginKind::kFft:
134       default_factories_[platform_id].fft = plugin_id;
135       break;
136     case PluginKind::kRng:
137       default_factories_[platform_id].rng = plugin_id;
138       break;
139     default:
140       LOG(ERROR) << "Invalid plugin kind specified: "
141                  << static_cast<int>(plugin_kind);
142       return false;
143   }
144 
145   return true;
146 }
147 
HasFactory(const PluginFactories & factories,PluginKind plugin_kind,PluginId plugin_id) const148 bool PluginRegistry::HasFactory(const PluginFactories& factories,
149                                 PluginKind plugin_kind,
150                                 PluginId plugin_id) const {
151   switch (plugin_kind) {
152     case PluginKind::kBlas:
153       return factories.blas.find(plugin_id) != factories.blas.end();
154     case PluginKind::kDnn:
155       return factories.dnn.find(plugin_id) != factories.dnn.end();
156     case PluginKind::kFft:
157       return factories.fft.find(plugin_id) != factories.fft.end();
158     case PluginKind::kRng:
159       return factories.rng.find(plugin_id) != factories.rng.end();
160     default:
161       LOG(ERROR) << "Invalid plugin kind specified: "
162                  << PluginKindString(plugin_kind);
163       return false;
164   }
165 }
166 
HasFactory(Platform::Id platform_id,PluginKind plugin_kind,PluginId plugin_id) const167 bool PluginRegistry::HasFactory(Platform::Id platform_id,
168                                 PluginKind plugin_kind,
169                                 PluginId plugin_id) const {
170   auto iter = factories_.find(platform_id);
171   if (iter != factories_.end()) {
172     if (HasFactory(iter->second, plugin_kind, plugin_id)) {
173       return true;
174     }
175   }
176 
177   return HasFactory(generic_factories_, plugin_kind, plugin_id);
178 }
179 
180 // Explicit instantiations to support types exposed in user/public API.
181 #define EMIT_PLUGIN_SPECIALIZATIONS(FACTORY_TYPE, FACTORY_VAR, PLUGIN_STRING) \
182   template port::StatusOr<PluginRegistry::FACTORY_TYPE>                       \
183   PluginRegistry::GetFactoryInternal<PluginRegistry::FACTORY_TYPE>(           \
184       PluginId plugin_id,                                                     \
185       const std::map<PluginId, PluginRegistry::FACTORY_TYPE>& factories,      \
186       const std::map<PluginId, PluginRegistry::FACTORY_TYPE>&                 \
187           generic_factories) const;                                           \
188                                                                               \
189   template port::Status                                                       \
190   PluginRegistry::RegisterFactoryInternal<PluginRegistry::FACTORY_TYPE>(      \
191       PluginId plugin_id, const string& plugin_name,                          \
192       PluginRegistry::FACTORY_TYPE factory,                                   \
193       std::map<PluginId, PluginRegistry::FACTORY_TYPE>* factories);           \
194                                                                               \
195   template <>                                                                 \
196   port::Status PluginRegistry::RegisterFactory<PluginRegistry::FACTORY_TYPE>( \
197       Platform::Id platform_id, PluginId plugin_id, const string& name,       \
198       PluginRegistry::FACTORY_TYPE factory) {                                 \
199     return RegisterFactoryInternal(plugin_id, name, factory,                  \
200                                    &factories_[platform_id].FACTORY_VAR);     \
201   }                                                                           \
202                                                                               \
203   template <>                                                                 \
204   port::Status PluginRegistry::RegisterFactoryForAllPlatforms<                \
205       PluginRegistry::FACTORY_TYPE>(PluginId plugin_id, const string& name,   \
206                                     PluginRegistry::FACTORY_TYPE factory) {   \
207     return RegisterFactoryInternal(plugin_id, name, factory,                  \
208                                    &generic_factories_.FACTORY_VAR);          \
209   }                                                                           \
210                                                                               \
211   template <>                                                                 \
212   port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory(    \
213       Platform::Id platform_id, PluginId plugin_id) {                         \
214     if (plugin_id == PluginConfig::kDefault) {                                \
215       plugin_id = default_factories_[platform_id].FACTORY_VAR;                \
216                                                                               \
217       if (plugin_id == kNullPlugin) {                                         \
218         return port::Status(                                                  \
219             port::error::FAILED_PRECONDITION,                                 \
220             "No suitable " PLUGIN_STRING                                      \
221             " plugin registered. Have you linked in a " PLUGIN_STRING         \
222             "-providing plugin?");                                            \
223       } else {                                                                \
224         VLOG(2) << "Selecting default " PLUGIN_STRING " plugin, "             \
225                 << plugin_names_[plugin_id];                                  \
226       }                                                                       \
227     }                                                                         \
228     return GetFactoryInternal(plugin_id, factories_[platform_id].FACTORY_VAR, \
229                               generic_factories_.FACTORY_VAR);                \
230   }                                                                           \
231                                                                               \
232   /* TODO(b/22689637): Also temporary WRT MultiPlatformManager */             \
233   template <>                                                                 \
234   port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory(    \
235       PlatformKind platform_kind, PluginId plugin_id) {                       \
236     auto iter = platform_id_by_kind_.find(platform_kind);                     \
237     if (iter == platform_id_by_kind_.end()) {                                 \
238       return port::Status(port::error::FAILED_PRECONDITION,                   \
239                           absl::StrFormat("Platform kind %d not registered.", \
240                                           static_cast<int>(platform_kind)));  \
241     }                                                                         \
242     return GetFactory<PluginRegistry::FACTORY_TYPE>(iter->second, plugin_id); \
243   }
244 
245 EMIT_PLUGIN_SPECIALIZATIONS(BlasFactory, blas, "BLAS");
246 EMIT_PLUGIN_SPECIALIZATIONS(DnnFactory, dnn, "DNN");
247 EMIT_PLUGIN_SPECIALIZATIONS(FftFactory, fft, "FFT");
248 EMIT_PLUGIN_SPECIALIZATIONS(RngFactory, rng, "RNG");
249 
250 }  // namespace stream_executor
251