1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/stream_executor/plugin_registry.h"
17
18 #include "tensorflow/stream_executor/lib/error.h"
19 #include "tensorflow/stream_executor/lib/stringprintf.h"
20 #include "tensorflow/stream_executor/multi_platform_manager.h"
21
22 namespace stream_executor {
23
24 const PluginId kNullPlugin = nullptr;
25
26 // Returns the string representation of the specified PluginKind.
PluginKindString(PluginKind plugin_kind)27 string PluginKindString(PluginKind plugin_kind) {
28 switch (plugin_kind) {
29 case PluginKind::kBlas:
30 return "BLAS";
31 case PluginKind::kDnn:
32 return "DNN";
33 case PluginKind::kFft:
34 return "FFT";
35 case PluginKind::kRng:
36 return "RNG";
37 case PluginKind::kInvalid:
38 default:
39 return "kInvalid";
40 }
41 }
42
DefaultFactories()43 PluginRegistry::DefaultFactories::DefaultFactories() :
44 blas(kNullPlugin), dnn(kNullPlugin), fft(kNullPlugin), rng(kNullPlugin) { }
45
GetPluginRegistryMutex()46 static mutex& GetPluginRegistryMutex() {
47 static mutex* mu = new mutex;
48 return *mu;
49 }
50
51 /* static */ PluginRegistry* PluginRegistry::instance_ = nullptr;
52
PluginRegistry()53 PluginRegistry::PluginRegistry() {}
54
Instance()55 /* static */ PluginRegistry* PluginRegistry::Instance() {
56 mutex_lock lock{GetPluginRegistryMutex()};
57 if (instance_ == nullptr) {
58 instance_ = new PluginRegistry();
59 }
60 return instance_;
61 }
62
MapPlatformKindToId(PlatformKind platform_kind,Platform::Id platform_id)63 void PluginRegistry::MapPlatformKindToId(PlatformKind platform_kind,
64 Platform::Id platform_id) {
65 platform_id_by_kind_[platform_kind] = platform_id;
66 }
67
68 template <typename FACTORY_TYPE>
RegisterFactoryInternal(PluginId plugin_id,const string & plugin_name,FACTORY_TYPE factory,std::map<PluginId,FACTORY_TYPE> * factories)69 port::Status PluginRegistry::RegisterFactoryInternal(
70 PluginId plugin_id, const string& plugin_name, FACTORY_TYPE factory,
71 std::map<PluginId, FACTORY_TYPE>* factories) {
72 mutex_lock lock{GetPluginRegistryMutex()};
73
74 if (factories->find(plugin_id) != factories->end()) {
75 return port::Status(
76 port::error::ALREADY_EXISTS,
77 port::Printf("Attempting to register factory for plugin %s when "
78 "one has already been registered",
79 plugin_name.c_str()));
80 }
81
82 (*factories)[plugin_id] = factory;
83 plugin_names_[plugin_id] = plugin_name;
84 return port::Status::OK();
85 }
86
87 template <typename FACTORY_TYPE>
GetFactoryInternal(PluginId plugin_id,const std::map<PluginId,FACTORY_TYPE> & factories,const std::map<PluginId,FACTORY_TYPE> & generic_factories) const88 port::StatusOr<FACTORY_TYPE> PluginRegistry::GetFactoryInternal(
89 PluginId plugin_id, const std::map<PluginId, FACTORY_TYPE>& factories,
90 const std::map<PluginId, FACTORY_TYPE>& generic_factories) const {
91 auto iter = factories.find(plugin_id);
92 if (iter == factories.end()) {
93 iter = generic_factories.find(plugin_id);
94 if (iter == generic_factories.end()) {
95 return port::Status(
96 port::error::NOT_FOUND,
97 port::Printf("Plugin ID %p not registered.", plugin_id));
98 }
99 }
100
101 return iter->second;
102 }
103
SetDefaultFactory(Platform::Id platform_id,PluginKind plugin_kind,PluginId plugin_id)104 bool PluginRegistry::SetDefaultFactory(Platform::Id platform_id,
105 PluginKind plugin_kind,
106 PluginId plugin_id) {
107 if (!HasFactory(platform_id, plugin_kind, plugin_id)) {
108 port::StatusOr<Platform*> status =
109 MultiPlatformManager::PlatformWithId(platform_id);
110 string platform_name = "<unregistered platform>";
111 if (status.ok()) {
112 platform_name = status.ValueOrDie()->Name();
113 }
114
115 LOG(ERROR) << "A factory must be registered for a platform before being "
116 << "set as default! "
117 << "Platform name: " << platform_name
118 << ", PluginKind: " << PluginKindString(plugin_kind)
119 << ", PluginId: " << plugin_id;
120 return false;
121 }
122
123 switch (plugin_kind) {
124 case PluginKind::kBlas:
125 default_factories_[platform_id].blas = plugin_id;
126 break;
127 case PluginKind::kDnn:
128 default_factories_[platform_id].dnn = plugin_id;
129 break;
130 case PluginKind::kFft:
131 default_factories_[platform_id].fft = plugin_id;
132 break;
133 case PluginKind::kRng:
134 default_factories_[platform_id].rng = plugin_id;
135 break;
136 default:
137 LOG(ERROR) << "Invalid plugin kind specified: "
138 << static_cast<int>(plugin_kind);
139 return false;
140 }
141
142 return true;
143 }
144
HasFactory(const PluginFactories & factories,PluginKind plugin_kind,PluginId plugin_id) const145 bool PluginRegistry::HasFactory(const PluginFactories& factories,
146 PluginKind plugin_kind,
147 PluginId plugin_id) const {
148 switch (plugin_kind) {
149 case PluginKind::kBlas:
150 return factories.blas.find(plugin_id) != factories.blas.end();
151 case PluginKind::kDnn:
152 return factories.dnn.find(plugin_id) != factories.dnn.end();
153 case PluginKind::kFft:
154 return factories.fft.find(plugin_id) != factories.fft.end();
155 case PluginKind::kRng:
156 return factories.rng.find(plugin_id) != factories.rng.end();
157 default:
158 LOG(ERROR) << "Invalid plugin kind specified: "
159 << PluginKindString(plugin_kind);
160 return false;
161 }
162 }
163
HasFactory(Platform::Id platform_id,PluginKind plugin_kind,PluginId plugin_id) const164 bool PluginRegistry::HasFactory(Platform::Id platform_id,
165 PluginKind plugin_kind,
166 PluginId plugin_id) const {
167 auto iter = factories_.find(platform_id);
168 if (iter != factories_.end()) {
169 if (HasFactory(iter->second, plugin_kind, plugin_id)) {
170 return true;
171 }
172 }
173
174 return HasFactory(generic_factories_, plugin_kind, plugin_id);
175 }
176
177 // Explicit instantiations to support types exposed in user/public API.
178 #define EMIT_PLUGIN_SPECIALIZATIONS(FACTORY_TYPE, FACTORY_VAR, PLUGIN_STRING) \
179 template port::StatusOr<PluginRegistry::FACTORY_TYPE> \
180 PluginRegistry::GetFactoryInternal<PluginRegistry::FACTORY_TYPE>( \
181 PluginId plugin_id, \
182 const std::map<PluginId, PluginRegistry::FACTORY_TYPE>& factories, \
183 const std::map<PluginId, PluginRegistry::FACTORY_TYPE>& \
184 generic_factories) const; \
185 \
186 template port::Status \
187 PluginRegistry::RegisterFactoryInternal<PluginRegistry::FACTORY_TYPE>( \
188 PluginId plugin_id, const string& plugin_name, \
189 PluginRegistry::FACTORY_TYPE factory, \
190 std::map<PluginId, PluginRegistry::FACTORY_TYPE>* factories); \
191 \
192 template <> \
193 port::Status PluginRegistry::RegisterFactory<PluginRegistry::FACTORY_TYPE>( \
194 Platform::Id platform_id, PluginId plugin_id, const string& name, \
195 PluginRegistry::FACTORY_TYPE factory) { \
196 return RegisterFactoryInternal(plugin_id, name, factory, \
197 &factories_[platform_id].FACTORY_VAR); \
198 } \
199 \
200 template <> \
201 port::Status PluginRegistry::RegisterFactoryForAllPlatforms< \
202 PluginRegistry::FACTORY_TYPE>(PluginId plugin_id, const string& name, \
203 PluginRegistry::FACTORY_TYPE factory) { \
204 return RegisterFactoryInternal(plugin_id, name, factory, \
205 &generic_factories_.FACTORY_VAR); \
206 } \
207 \
208 template <> \
209 port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory( \
210 Platform::Id platform_id, PluginId plugin_id) { \
211 if (plugin_id == PluginConfig::kDefault) { \
212 plugin_id = default_factories_[platform_id].FACTORY_VAR; \
213 \
214 if (plugin_id == kNullPlugin) { \
215 return port::Status( \
216 port::error::FAILED_PRECONDITION, \
217 "No suitable " PLUGIN_STRING \
218 " plugin registered. Have you linked in a " PLUGIN_STRING \
219 "-providing plugin?"); \
220 } else { \
221 VLOG(2) << "Selecting default " PLUGIN_STRING " plugin, " \
222 << plugin_names_[plugin_id]; \
223 } \
224 } \
225 return GetFactoryInternal(plugin_id, factories_[platform_id].FACTORY_VAR, \
226 generic_factories_.FACTORY_VAR); \
227 } \
228 \
229 /* TODO(b/22689637): Also temporary WRT MultiPlatformManager */ \
230 template <> \
231 port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory( \
232 PlatformKind platform_kind, PluginId plugin_id) { \
233 auto iter = platform_id_by_kind_.find(platform_kind); \
234 if (iter == platform_id_by_kind_.end()) { \
235 return port::Status(port::error::FAILED_PRECONDITION, \
236 port::Printf("Platform kind %d not registered.", \
237 static_cast<int>(platform_kind))); \
238 } \
239 return GetFactory<PluginRegistry::FACTORY_TYPE>(iter->second, plugin_id); \
240 }
241
242 EMIT_PLUGIN_SPECIALIZATIONS(BlasFactory, blas, "BLAS");
243 EMIT_PLUGIN_SPECIALIZATIONS(DnnFactory, dnn, "DNN");
244 EMIT_PLUGIN_SPECIALIZATIONS(FftFactory, fft, "FFT");
245 EMIT_PLUGIN_SPECIALIZATIONS(RngFactory, rng, "RNG");
246
247 } // namespace stream_executor
248