1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/stream_executor/multi_platform_manager.h"
17
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/strings/ascii.h"
20 #include "absl/strings/str_format.h"
21 #include "absl/strings/str_join.h"
22 #include "absl/strings/string_view.h"
23 #include "absl/synchronization/mutex.h"
24 #include "tensorflow/core/platform/thread_annotations.h"
25 #include "tensorflow/stream_executor/lib/error.h"
26 #include "tensorflow/stream_executor/lib/initialize.h"
27
28 namespace stream_executor {
29 namespace {
30
31 class MultiPlatformManagerImpl {
32 public:
33 port::Status RegisterPlatform(std::unique_ptr<Platform> platform)
34 TF_LOCKS_EXCLUDED(mu_);
35
36 port::StatusOr<Platform*> PlatformWithName(absl::string_view target)
37 TF_LOCKS_EXCLUDED(mu_);
38
39 port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id)
40 TF_LOCKS_EXCLUDED(mu_);
41
42 port::StatusOr<Platform*> PlatformWithName(absl::string_view target,
43 bool initialize_platform)
44 TF_LOCKS_EXCLUDED(mu_);
45
46 port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id,
47 bool initialize_platform)
48 TF_LOCKS_EXCLUDED(mu_);
49
50 port::StatusOr<Platform*> InitializePlatformWithName(
51 absl::string_view target,
52 const std::map<std::string, std::string>& options) TF_LOCKS_EXCLUDED(mu_);
53 port::StatusOr<Platform*> InitializePlatformWithId(
54 const Platform::Id& id, const std::map<std::string, std::string>& options)
55 TF_LOCKS_EXCLUDED(mu_);
56
57 port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
58 const std::function<bool(const Platform*)>& filter,
59 bool initialize_platform) TF_LOCKS_EXCLUDED(mu_);
60
61 using Listener = MultiPlatformManager::Listener;
62 port::Status RegisterListener(std::unique_ptr<Listener> listener)
63 TF_LOCKS_EXCLUDED(mu_);
64
65 private:
66 // Looks up the platform object with the given name. Assumes the Platforms
67 // mutex is held.
68 port::StatusOr<Platform*> LookupByNameLocked(absl::string_view target)
69 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
70
71 // Looks up the platform object with the given id. Assumes the Platforms
72 // mutex is held.
73 port::StatusOr<Platform*> LookupByIdLocked(const Platform::Id& id)
74 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
75
76 // Returns the names of the initialied platforms satisfying the given filter.
77 // By default, it will return all initialized platform names.
78 std::vector<std::string> InitializedPlatformNamesWithFilter(
__anon233c95820202(const Platform*) 79 const std::function<bool(const Platform*)>& filter = [](const Platform*) {
80 return true;
81 }) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
82
83 absl::Mutex mu_;
84 std::vector<std::unique_ptr<Listener>> listeners_ TF_GUARDED_BY(mu_);
85 absl::flat_hash_map<Platform::Id, Platform*> id_map_ TF_GUARDED_BY(mu_);
86 absl::flat_hash_map<std::string, Platform*> name_map_ TF_GUARDED_BY(mu_);
87 };
88
RegisterPlatform(std::unique_ptr<Platform> platform)89 port::Status MultiPlatformManagerImpl::RegisterPlatform(
90 std::unique_ptr<Platform> platform) {
91 CHECK(platform != nullptr);
92 std::string key = absl::AsciiStrToLower(platform->Name());
93 absl::MutexLock lock(&mu_);
94 if (name_map_.find(key) != name_map_.end()) {
95 return port::Status(port::error::INTERNAL,
96 "platform is already registered with name: \"" +
97 platform->Name() + "\"");
98 }
99 Platform* platform_ptr = platform.get();
100 CHECK(id_map_.emplace(platform->id(), platform_ptr).second);
101 // Release ownership/uniqueness to prevent destruction on program exit.
102 // This avoids Platforms "cleaning up" on program exit, because otherwise,
103 // there are _very_ tricky races between StreamExecutor and underlying
104 // platforms (CUDA, OpenCL) during exit. Since these are fixed-size and 1x per
105 // program, these are deemed acceptable.
106 name_map_[key] = platform.release();
107 for (const auto& listener : listeners_) {
108 listener->PlatformRegistered(platform_ptr);
109 }
110 return port::Status::OK();
111 }
112
PlatformWithName(absl::string_view target)113 port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithName(
114 absl::string_view target) {
115 return PlatformWithName(target, /*initialize_platform=*/true);
116 }
117
PlatformWithId(const Platform::Id & id)118 port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithId(
119 const Platform::Id& id) {
120 return PlatformWithId(id, /*initialize_platform=*/true);
121 }
122
PlatformWithName(absl::string_view target,bool initialize_platform)123 port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithName(
124 absl::string_view target, bool initialize_platform) {
125 absl::MutexLock lock(&mu_);
126
127 SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
128 if (initialize_platform && !platform->Initialized()) {
129 SE_RETURN_IF_ERROR(platform->Initialize({}));
130 }
131
132 return platform;
133 }
134
PlatformWithId(const Platform::Id & id,bool initialize_platform)135 port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithId(
136 const Platform::Id& id, bool initialize_platform) {
137 absl::MutexLock lock(&mu_);
138
139 SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
140 if (initialize_platform && !platform->Initialized()) {
141 SE_RETURN_IF_ERROR(platform->Initialize({}));
142 }
143
144 return platform;
145 }
146
InitializePlatformWithName(absl::string_view target,const std::map<std::string,std::string> & options)147 port::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithName(
148 absl::string_view target,
149 const std::map<std::string, std::string>& options) {
150 absl::MutexLock lock(&mu_);
151
152 SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
153 if (platform->Initialized()) {
154 return port::Status(
155 port::error::FAILED_PRECONDITION,
156 absl::StrCat("platform \"", target, "\" is already initialized"));
157 }
158
159 SE_RETURN_IF_ERROR(platform->Initialize(options));
160
161 return platform;
162 }
163
InitializePlatformWithId(const Platform::Id & id,const std::map<std::string,std::string> & options)164 port::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithId(
165 const Platform::Id& id, const std::map<std::string, std::string>& options) {
166 absl::MutexLock lock(&mu_);
167
168 SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
169 if (platform->Initialized()) {
170 return port::Status(
171 port::error::FAILED_PRECONDITION,
172 absl::StrFormat("platform with id %p is already initialized", id));
173 }
174
175 SE_RETURN_IF_ERROR(platform->Initialize(options));
176
177 return platform;
178 }
179
RegisterListener(std::unique_ptr<Listener> listener)180 port::Status MultiPlatformManagerImpl::RegisterListener(
181 std::unique_ptr<Listener> listener) {
182 absl::MutexLock lock(&mu_);
183 CHECK(id_map_.empty());
184 CHECK(name_map_.empty());
185 listeners_.push_back(std::move(listener));
186 return port::Status::OK();
187 }
188
189 port::StatusOr<std::vector<Platform*>>
PlatformsWithFilter(const std::function<bool (const Platform *)> & filter,bool initialize_platform)190 MultiPlatformManagerImpl::PlatformsWithFilter(
191 const std::function<bool(const Platform*)>& filter,
192 bool initialize_platform) {
193 absl::MutexLock lock(&mu_);
194 CHECK_EQ(id_map_.size(), name_map_.size());
195 std::vector<Platform*> platforms;
196 platforms.reserve(id_map_.size());
197 for (const auto& entry : id_map_) {
198 Platform* platform = entry.second;
199 if (filter(platform)) {
200 if (initialize_platform && !platform->Initialized()) {
201 SE_RETURN_IF_ERROR(platform->Initialize({}));
202 }
203 platforms.push_back(platform);
204 }
205 }
206 return platforms;
207 }
208
209 std::vector<std::string>
InitializedPlatformNamesWithFilter(const std::function<bool (const Platform *)> & filter)210 MultiPlatformManagerImpl::InitializedPlatformNamesWithFilter(
211 const std::function<bool(const Platform*)>& filter) {
212 CHECK_EQ(id_map_.size(), name_map_.size());
213 std::vector<std::string> initialized_platforms_names;
214 initialized_platforms_names.reserve(id_map_.size());
215 for (const auto& entry : id_map_) {
216 Platform* platform = entry.second;
217 if (filter(platform)) {
218 if (platform->Initialized()) {
219 initialized_platforms_names.push_back(platform->Name());
220 }
221 }
222 }
223 return initialized_platforms_names;
224 }
225
LookupByNameLocked(absl::string_view target)226 port::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByNameLocked(
227 absl::string_view target) {
228 auto it = name_map_.find(absl::AsciiStrToLower(target));
229 if (it == name_map_.end()) {
230 return port::Status(
231 port::error::NOT_FOUND,
232 absl::StrCat("Could not find registered platform with name: \"", target,
233 "\". Available platform names are: ",
234 absl::StrJoin(InitializedPlatformNamesWithFilter(), " ")));
235 }
236 return it->second;
237 }
238
LookupByIdLocked(const Platform::Id & id)239 port::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByIdLocked(
240 const Platform::Id& id) {
241 auto it = id_map_.find(id);
242 if (it == id_map_.end()) {
243 return port::Status(
244 port::error::NOT_FOUND,
245 absl::StrFormat("could not find registered platform with id: %p", id));
246 }
247 return it->second;
248 }
249
Impl()250 MultiPlatformManagerImpl& Impl() {
251 static MultiPlatformManagerImpl* impl = new MultiPlatformManagerImpl;
252 return *impl;
253 }
254
255 } // namespace
256
RegisterPlatform(std::unique_ptr<Platform> platform)257 /*static*/ port::Status MultiPlatformManager::RegisterPlatform(
258 std::unique_ptr<Platform> platform) {
259 return Impl().RegisterPlatform(std::move(platform));
260 }
261
PlatformWithName(absl::string_view target)262 /*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithName(
263 absl::string_view target) {
264 return Impl().PlatformWithName(target);
265 }
266
PlatformWithId(const Platform::Id & id)267 /*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithId(
268 const Platform::Id& id) {
269 return Impl().PlatformWithId(id);
270 }
271
PlatformWithId(const Platform::Id & id,bool initialize_platform)272 /*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithId(
273 const Platform::Id& id, bool initialize_platform) {
274 return Impl().PlatformWithId(id, initialize_platform);
275 }
276
PlatformWithName(absl::string_view target,bool initialize_platform)277 /*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithName(
278 absl::string_view target, bool initialize_platform) {
279 return Impl().PlatformWithName(target, initialize_platform);
280 }
281
282 /*static*/ port::StatusOr<Platform*>
InitializePlatformWithName(absl::string_view target,const std::map<std::string,std::string> & options)283 MultiPlatformManager::InitializePlatformWithName(
284 absl::string_view target,
285 const std::map<std::string, std::string>& options) {
286 return Impl().InitializePlatformWithName(target, options);
287 }
288
289 /*static*/ port::StatusOr<Platform*>
InitializePlatformWithId(const Platform::Id & id,const std::map<std::string,std::string> & options)290 MultiPlatformManager::InitializePlatformWithId(
291 const Platform::Id& id, const std::map<std::string, std::string>& options) {
292 return Impl().InitializePlatformWithId(id, options);
293 }
294
RegisterListener(std::unique_ptr<Listener> listener)295 /*static*/ port::Status MultiPlatformManager::RegisterListener(
296 std::unique_ptr<Listener> listener) {
297 return Impl().RegisterListener(std::move(listener));
298 }
299
300 /*static*/ port::StatusOr<std::vector<Platform*>>
PlatformsWithFilter(const std::function<bool (const Platform *)> & filter)301 MultiPlatformManager::PlatformsWithFilter(
302 const std::function<bool(const Platform*)>& filter) {
303 return PlatformsWithFilter(filter, /*initialize_platform=*/true);
304 }
305
306 /*static*/ port::StatusOr<std::vector<Platform*>>
PlatformsWithFilter(const std::function<bool (const Platform *)> & filter,bool initialize_platform)307 MultiPlatformManager::PlatformsWithFilter(
308 const std::function<bool(const Platform*)>& filter,
309 bool initialize_platform) {
310 return Impl().PlatformsWithFilter(filter, initialize_platform);
311 }
312
313 } // namespace stream_executor
314
// Module initializer with an intentionally empty body. It exists only as a
// named sequencing point for Platform subclasses that register themselves
// with the MultiPlatformManager.
REGISTER_MODULE_INITIALIZER(
    multi_platform_manager,
    {
        // Nothing -- this is just a module initializer
        // definition to reference for sequencing
        // purposes from Platform subclasses that register
        // themselves with the MultiPlatformManager.
    });

// Module initializer with an intentionally empty body. It exists only as a
// named sequencing point for code that registers listeners with the
// MultiPlatformManager.
REGISTER_MODULE_INITIALIZER(
    multi_platform_manager_listener,
    {
        // Nothing -- this is just a module initializer definition to reference
        // for sequencing registration of listeners with the
        // MultiPlatformManager.
    });

// Listener registration should happen before platform registration.
// (RegisterListener() CHECK-fails if any platform is already registered.)
REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener,
                                     multi_platform_manager);
335