1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/stream_executor/plugin_registry.h"
17
18 #include "absl/base/const_init.h"
19 #include "absl/strings/str_cat.h"
20 #include "absl/strings/str_format.h"
21 #include "absl/synchronization/mutex.h"
22 #include "tensorflow/stream_executor/lib/error.h"
23 #include "tensorflow/stream_executor/multi_platform_manager.h"
24
25 namespace stream_executor {
26
27 const PluginId kNullPlugin = nullptr;
28
29 // Returns the string representation of the specified PluginKind.
PluginKindString(PluginKind plugin_kind)30 std::string PluginKindString(PluginKind plugin_kind) {
31 switch (plugin_kind) {
32 case PluginKind::kBlas:
33 return "BLAS";
34 case PluginKind::kDnn:
35 return "DNN";
36 case PluginKind::kFft:
37 return "FFT";
38 case PluginKind::kRng:
39 return "RNG";
40 case PluginKind::kInvalid:
41 default:
42 return "kInvalid";
43 }
44 }
45
DefaultFactories()46 PluginRegistry::DefaultFactories::DefaultFactories() :
47 blas(kNullPlugin), dnn(kNullPlugin), fft(kNullPlugin), rng(kNullPlugin) { }
48
GetPluginRegistryMutex()49 static absl::Mutex& GetPluginRegistryMutex() {
50 static absl::Mutex mu(absl::kConstInit);
51 return mu;
52 }
53
54 /* static */ PluginRegistry* PluginRegistry::instance_ = nullptr;
55
PluginRegistry()56 PluginRegistry::PluginRegistry() {}
57
Instance()58 /* static */ PluginRegistry* PluginRegistry::Instance() {
59 absl::MutexLock lock{&GetPluginRegistryMutex()};
60 if (instance_ == nullptr) {
61 instance_ = new PluginRegistry();
62 }
63 return instance_;
64 }
65
MapPlatformKindToId(PlatformKind platform_kind,Platform::Id platform_id)66 void PluginRegistry::MapPlatformKindToId(PlatformKind platform_kind,
67 Platform::Id platform_id) {
68 platform_id_by_kind_[platform_kind] = platform_id;
69 }
70
71 template <typename FACTORY_TYPE>
RegisterFactoryInternal(PluginId plugin_id,const std::string & plugin_name,FACTORY_TYPE factory,std::map<PluginId,FACTORY_TYPE> * factories)72 port::Status PluginRegistry::RegisterFactoryInternal(
73 PluginId plugin_id, const std::string& plugin_name, FACTORY_TYPE factory,
74 std::map<PluginId, FACTORY_TYPE>* factories) {
75 absl::MutexLock lock{&GetPluginRegistryMutex()};
76
77 if (factories->find(plugin_id) != factories->end()) {
78 return port::Status(
79 port::error::ALREADY_EXISTS,
80 absl::StrFormat("Attempting to register factory for plugin %s when "
81 "one has already been registered",
82 plugin_name));
83 }
84
85 (*factories)[plugin_id] = factory;
86 plugin_names_[plugin_id] = plugin_name;
87 return port::Status::OK();
88 }
89
90 template <typename FACTORY_TYPE>
GetFactoryInternal(PluginId plugin_id,const std::map<PluginId,FACTORY_TYPE> & factories,const std::map<PluginId,FACTORY_TYPE> & generic_factories) const91 port::StatusOr<FACTORY_TYPE> PluginRegistry::GetFactoryInternal(
92 PluginId plugin_id, const std::map<PluginId, FACTORY_TYPE>& factories,
93 const std::map<PluginId, FACTORY_TYPE>& generic_factories) const {
94 auto iter = factories.find(plugin_id);
95 if (iter == factories.end()) {
96 iter = generic_factories.find(plugin_id);
97 if (iter == generic_factories.end()) {
98 return port::Status(
99 port::error::NOT_FOUND,
100 absl::StrFormat("Plugin ID %p not registered.", plugin_id));
101 }
102 }
103
104 return iter->second;
105 }
106
SetDefaultFactory(Platform::Id platform_id,PluginKind plugin_kind,PluginId plugin_id)107 bool PluginRegistry::SetDefaultFactory(Platform::Id platform_id,
108 PluginKind plugin_kind,
109 PluginId plugin_id) {
110 if (!HasFactory(platform_id, plugin_kind, plugin_id)) {
111 port::StatusOr<Platform*> status =
112 MultiPlatformManager::PlatformWithId(platform_id);
113 std::string platform_name = "<unregistered platform>";
114 if (status.ok()) {
115 platform_name = status.ValueOrDie()->Name();
116 }
117
118 LOG(ERROR) << "A factory must be registered for a platform before being "
119 << "set as default! "
120 << "Platform name: " << platform_name
121 << ", PluginKind: " << PluginKindString(plugin_kind)
122 << ", PluginId: " << plugin_id;
123 return false;
124 }
125
126 switch (plugin_kind) {
127 case PluginKind::kBlas:
128 default_factories_[platform_id].blas = plugin_id;
129 break;
130 case PluginKind::kDnn:
131 default_factories_[platform_id].dnn = plugin_id;
132 break;
133 case PluginKind::kFft:
134 default_factories_[platform_id].fft = plugin_id;
135 break;
136 case PluginKind::kRng:
137 default_factories_[platform_id].rng = plugin_id;
138 break;
139 default:
140 LOG(ERROR) << "Invalid plugin kind specified: "
141 << static_cast<int>(plugin_kind);
142 return false;
143 }
144
145 return true;
146 }
147
HasFactory(const PluginFactories & factories,PluginKind plugin_kind,PluginId plugin_id) const148 bool PluginRegistry::HasFactory(const PluginFactories& factories,
149 PluginKind plugin_kind,
150 PluginId plugin_id) const {
151 switch (plugin_kind) {
152 case PluginKind::kBlas:
153 return factories.blas.find(plugin_id) != factories.blas.end();
154 case PluginKind::kDnn:
155 return factories.dnn.find(plugin_id) != factories.dnn.end();
156 case PluginKind::kFft:
157 return factories.fft.find(plugin_id) != factories.fft.end();
158 case PluginKind::kRng:
159 return factories.rng.find(plugin_id) != factories.rng.end();
160 default:
161 LOG(ERROR) << "Invalid plugin kind specified: "
162 << PluginKindString(plugin_kind);
163 return false;
164 }
165 }
166
HasFactory(Platform::Id platform_id,PluginKind plugin_kind,PluginId plugin_id) const167 bool PluginRegistry::HasFactory(Platform::Id platform_id,
168 PluginKind plugin_kind,
169 PluginId plugin_id) const {
170 auto iter = factories_.find(platform_id);
171 if (iter != factories_.end()) {
172 if (HasFactory(iter->second, plugin_kind, plugin_id)) {
173 return true;
174 }
175 }
176
177 return HasFactory(generic_factories_, plugin_kind, plugin_id);
178 }
179
180 // Explicit instantiations to support types exposed in user/public API.
181 #define EMIT_PLUGIN_SPECIALIZATIONS(FACTORY_TYPE, FACTORY_VAR, PLUGIN_STRING) \
182 template port::StatusOr<PluginRegistry::FACTORY_TYPE> \
183 PluginRegistry::GetFactoryInternal<PluginRegistry::FACTORY_TYPE>( \
184 PluginId plugin_id, \
185 const std::map<PluginId, PluginRegistry::FACTORY_TYPE>& factories, \
186 const std::map<PluginId, PluginRegistry::FACTORY_TYPE>& \
187 generic_factories) const; \
188 \
189 template port::Status \
190 PluginRegistry::RegisterFactoryInternal<PluginRegistry::FACTORY_TYPE>( \
191 PluginId plugin_id, const std::string& plugin_name, \
192 PluginRegistry::FACTORY_TYPE factory, \
193 std::map<PluginId, PluginRegistry::FACTORY_TYPE>* factories); \
194 \
195 template <> \
196 port::Status PluginRegistry::RegisterFactory<PluginRegistry::FACTORY_TYPE>( \
197 Platform::Id platform_id, PluginId plugin_id, const std::string& name, \
198 PluginRegistry::FACTORY_TYPE factory) { \
199 return RegisterFactoryInternal(plugin_id, name, factory, \
200 &factories_[platform_id].FACTORY_VAR); \
201 } \
202 \
203 template <> \
204 port::Status PluginRegistry::RegisterFactoryForAllPlatforms< \
205 PluginRegistry::FACTORY_TYPE>(PluginId plugin_id, \
206 const std::string& name, \
207 PluginRegistry::FACTORY_TYPE factory) { \
208 return RegisterFactoryInternal(plugin_id, name, factory, \
209 &generic_factories_.FACTORY_VAR); \
210 } \
211 \
212 template <> \
213 port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory( \
214 Platform::Id platform_id, PluginId plugin_id) { \
215 if (plugin_id == PluginConfig::kDefault) { \
216 plugin_id = default_factories_[platform_id].FACTORY_VAR; \
217 \
218 if (plugin_id == kNullPlugin) { \
219 return port::Status( \
220 port::error::FAILED_PRECONDITION, \
221 "No suitable " PLUGIN_STRING \
222 " plugin registered. Have you linked in a " PLUGIN_STRING \
223 "-providing plugin?"); \
224 } else { \
225 VLOG(2) << "Selecting default " PLUGIN_STRING " plugin, " \
226 << plugin_names_[plugin_id]; \
227 } \
228 } \
229 return GetFactoryInternal(plugin_id, factories_[platform_id].FACTORY_VAR, \
230 generic_factories_.FACTORY_VAR); \
231 } \
232 \
233 /* TODO(b/22689637): Also temporary WRT MultiPlatformManager */ \
234 template <> \
235 port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory( \
236 PlatformKind platform_kind, PluginId plugin_id) { \
237 auto iter = platform_id_by_kind_.find(platform_kind); \
238 if (iter == platform_id_by_kind_.end()) { \
239 return port::Status(port::error::FAILED_PRECONDITION, \
240 absl::StrFormat("Platform kind %d not registered.", \
241 static_cast<int>(platform_kind))); \
242 } \
243 return GetFactory<PluginRegistry::FACTORY_TYPE>(iter->second, plugin_id); \
244 }
245
246 EMIT_PLUGIN_SPECIALIZATIONS(BlasFactory, blas, "BLAS");
247 EMIT_PLUGIN_SPECIALIZATIONS(DnnFactory, dnn, "DNN");
248 EMIT_PLUGIN_SPECIALIZATIONS(FftFactory, fft, "FFT");
249 EMIT_PLUGIN_SPECIALIZATIONS(RngFactory, rng, "RNG");
250
251 } // namespace stream_executor
252