• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/stream_executor/multi_platform_manager.h"

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/initialize.h"
27 
28 namespace stream_executor {
29 namespace {
30 
31 class MultiPlatformManagerImpl {
32  public:
33   port::Status RegisterPlatform(std::unique_ptr<Platform> platform)
34       TF_LOCKS_EXCLUDED(mu_);
35 
36   port::StatusOr<Platform*> PlatformWithName(absl::string_view target)
37       TF_LOCKS_EXCLUDED(mu_);
38 
39   port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id)
40       TF_LOCKS_EXCLUDED(mu_);
41 
42   port::StatusOr<Platform*> PlatformWithName(absl::string_view target,
43                                              bool initialize_platform)
44       TF_LOCKS_EXCLUDED(mu_);
45 
46   port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id,
47                                            bool initialize_platform)
48       TF_LOCKS_EXCLUDED(mu_);
49 
50   port::StatusOr<Platform*> InitializePlatformWithName(
51       absl::string_view target,
52       const std::map<std::string, std::string>& options) TF_LOCKS_EXCLUDED(mu_);
53   port::StatusOr<Platform*> InitializePlatformWithId(
54       const Platform::Id& id, const std::map<std::string, std::string>& options)
55       TF_LOCKS_EXCLUDED(mu_);
56 
57   port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
58       const std::function<bool(const Platform*)>& filter,
59       bool initialize_platform) TF_LOCKS_EXCLUDED(mu_);
60 
61   using Listener = MultiPlatformManager::Listener;
62   port::Status RegisterListener(std::unique_ptr<Listener> listener)
63       TF_LOCKS_EXCLUDED(mu_);
64 
65  private:
66   // Looks up the platform object with the given name.  Assumes the Platforms
67   // mutex is held.
68   port::StatusOr<Platform*> LookupByNameLocked(absl::string_view target)
69       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
70 
71   // Looks up the platform object with the given id.  Assumes the Platforms
72   // mutex is held.
73   port::StatusOr<Platform*> LookupByIdLocked(const Platform::Id& id)
74       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
75 
76   // Returns the names of the initialied platforms satisfying the given filter.
77   // By default, it will return all initialized platform names.
78   std::vector<std::string> InitializedPlatformNamesWithFilter(
__anon233c95820202(const Platform*) 79       const std::function<bool(const Platform*)>& filter = [](const Platform*) {
80         return true;
81       }) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
82 
83   absl::Mutex mu_;
84   std::vector<std::unique_ptr<Listener>> listeners_ TF_GUARDED_BY(mu_);
85   absl::flat_hash_map<Platform::Id, Platform*> id_map_ TF_GUARDED_BY(mu_);
86   absl::flat_hash_map<std::string, Platform*> name_map_ TF_GUARDED_BY(mu_);
87 };
88 
RegisterPlatform(std::unique_ptr<Platform> platform)89 port::Status MultiPlatformManagerImpl::RegisterPlatform(
90     std::unique_ptr<Platform> platform) {
91   CHECK(platform != nullptr);
92   std::string key = absl::AsciiStrToLower(platform->Name());
93   absl::MutexLock lock(&mu_);
94   if (name_map_.find(key) != name_map_.end()) {
95     return port::Status(port::error::INTERNAL,
96                         "platform is already registered with name: \"" +
97                             platform->Name() + "\"");
98   }
99   Platform* platform_ptr = platform.get();
100   CHECK(id_map_.emplace(platform->id(), platform_ptr).second);
101   // Release ownership/uniqueness to prevent destruction on program exit.
102   // This avoids Platforms "cleaning up" on program exit, because otherwise,
103   // there are _very_ tricky races between StreamExecutor and underlying
104   // platforms (CUDA, OpenCL) during exit. Since these are fixed-size and 1x per
105   // program, these are deemed acceptable.
106   name_map_[key] = platform.release();
107   for (const auto& listener : listeners_) {
108     listener->PlatformRegistered(platform_ptr);
109   }
110   return port::Status::OK();
111 }
112 
PlatformWithName(absl::string_view target)113 port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithName(
114     absl::string_view target) {
115   return PlatformWithName(target, /*initialize_platform=*/true);
116 }
117 
PlatformWithId(const Platform::Id & id)118 port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithId(
119     const Platform::Id& id) {
120   return PlatformWithId(id, /*initialize_platform=*/true);
121 }
122 
PlatformWithName(absl::string_view target,bool initialize_platform)123 port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithName(
124     absl::string_view target, bool initialize_platform) {
125   absl::MutexLock lock(&mu_);
126 
127   SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
128   if (initialize_platform && !platform->Initialized()) {
129     SE_RETURN_IF_ERROR(platform->Initialize({}));
130   }
131 
132   return platform;
133 }
134 
PlatformWithId(const Platform::Id & id,bool initialize_platform)135 port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithId(
136     const Platform::Id& id, bool initialize_platform) {
137   absl::MutexLock lock(&mu_);
138 
139   SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
140   if (initialize_platform && !platform->Initialized()) {
141     SE_RETURN_IF_ERROR(platform->Initialize({}));
142   }
143 
144   return platform;
145 }
146 
InitializePlatformWithName(absl::string_view target,const std::map<std::string,std::string> & options)147 port::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithName(
148     absl::string_view target,
149     const std::map<std::string, std::string>& options) {
150   absl::MutexLock lock(&mu_);
151 
152   SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
153   if (platform->Initialized()) {
154     return port::Status(
155         port::error::FAILED_PRECONDITION,
156         absl::StrCat("platform \"", target, "\" is already initialized"));
157   }
158 
159   SE_RETURN_IF_ERROR(platform->Initialize(options));
160 
161   return platform;
162 }
163 
InitializePlatformWithId(const Platform::Id & id,const std::map<std::string,std::string> & options)164 port::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithId(
165     const Platform::Id& id, const std::map<std::string, std::string>& options) {
166   absl::MutexLock lock(&mu_);
167 
168   SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
169   if (platform->Initialized()) {
170     return port::Status(
171         port::error::FAILED_PRECONDITION,
172         absl::StrFormat("platform with id %p is already initialized", id));
173   }
174 
175   SE_RETURN_IF_ERROR(platform->Initialize(options));
176 
177   return platform;
178 }
179 
RegisterListener(std::unique_ptr<Listener> listener)180 port::Status MultiPlatformManagerImpl::RegisterListener(
181     std::unique_ptr<Listener> listener) {
182   absl::MutexLock lock(&mu_);
183   CHECK(id_map_.empty());
184   CHECK(name_map_.empty());
185   listeners_.push_back(std::move(listener));
186   return port::Status::OK();
187 }
188 
189 port::StatusOr<std::vector<Platform*>>
PlatformsWithFilter(const std::function<bool (const Platform *)> & filter,bool initialize_platform)190 MultiPlatformManagerImpl::PlatformsWithFilter(
191     const std::function<bool(const Platform*)>& filter,
192     bool initialize_platform) {
193   absl::MutexLock lock(&mu_);
194   CHECK_EQ(id_map_.size(), name_map_.size());
195   std::vector<Platform*> platforms;
196   platforms.reserve(id_map_.size());
197   for (const auto& entry : id_map_) {
198     Platform* platform = entry.second;
199     if (filter(platform)) {
200       if (initialize_platform && !platform->Initialized()) {
201         SE_RETURN_IF_ERROR(platform->Initialize({}));
202       }
203       platforms.push_back(platform);
204     }
205   }
206   return platforms;
207 }
208 
209 std::vector<std::string>
InitializedPlatformNamesWithFilter(const std::function<bool (const Platform *)> & filter)210 MultiPlatformManagerImpl::InitializedPlatformNamesWithFilter(
211     const std::function<bool(const Platform*)>& filter) {
212   CHECK_EQ(id_map_.size(), name_map_.size());
213   std::vector<std::string> initialized_platforms_names;
214   initialized_platforms_names.reserve(id_map_.size());
215   for (const auto& entry : id_map_) {
216     Platform* platform = entry.second;
217     if (filter(platform)) {
218       if (platform->Initialized()) {
219         initialized_platforms_names.push_back(platform->Name());
220       }
221     }
222   }
223   return initialized_platforms_names;
224 }
225 
LookupByNameLocked(absl::string_view target)226 port::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByNameLocked(
227     absl::string_view target) {
228   auto it = name_map_.find(absl::AsciiStrToLower(target));
229   if (it == name_map_.end()) {
230     return port::Status(
231         port::error::NOT_FOUND,
232         absl::StrCat("Could not find registered platform with name: \"", target,
233                      "\". Available platform names are: ",
234                      absl::StrJoin(InitializedPlatformNamesWithFilter(), " ")));
235   }
236   return it->second;
237 }
238 
LookupByIdLocked(const Platform::Id & id)239 port::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByIdLocked(
240     const Platform::Id& id) {
241   auto it = id_map_.find(id);
242   if (it == id_map_.end()) {
243     return port::Status(
244         port::error::NOT_FOUND,
245         absl::StrFormat("could not find registered platform with id: %p", id));
246   }
247   return it->second;
248 }
249 
Impl()250 MultiPlatformManagerImpl& Impl() {
251   static MultiPlatformManagerImpl* impl = new MultiPlatformManagerImpl;
252   return *impl;
253 }
254 
255 }  // namespace
256 
RegisterPlatform(std::unique_ptr<Platform> platform)257 /*static*/ port::Status MultiPlatformManager::RegisterPlatform(
258     std::unique_ptr<Platform> platform) {
259   return Impl().RegisterPlatform(std::move(platform));
260 }
261 
PlatformWithName(absl::string_view target)262 /*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithName(
263     absl::string_view target) {
264   return Impl().PlatformWithName(target);
265 }
266 
PlatformWithId(const Platform::Id & id)267 /*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithId(
268     const Platform::Id& id) {
269   return Impl().PlatformWithId(id);
270 }
271 
PlatformWithId(const Platform::Id & id,bool initialize_platform)272 /*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithId(
273     const Platform::Id& id, bool initialize_platform) {
274   return Impl().PlatformWithId(id, initialize_platform);
275 }
276 
PlatformWithName(absl::string_view target,bool initialize_platform)277 /*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithName(
278     absl::string_view target, bool initialize_platform) {
279   return Impl().PlatformWithName(target, initialize_platform);
280 }
281 
282 /*static*/ port::StatusOr<Platform*>
InitializePlatformWithName(absl::string_view target,const std::map<std::string,std::string> & options)283 MultiPlatformManager::InitializePlatformWithName(
284     absl::string_view target,
285     const std::map<std::string, std::string>& options) {
286   return Impl().InitializePlatformWithName(target, options);
287 }
288 
289 /*static*/ port::StatusOr<Platform*>
InitializePlatformWithId(const Platform::Id & id,const std::map<std::string,std::string> & options)290 MultiPlatformManager::InitializePlatformWithId(
291     const Platform::Id& id, const std::map<std::string, std::string>& options) {
292   return Impl().InitializePlatformWithId(id, options);
293 }
294 
RegisterListener(std::unique_ptr<Listener> listener)295 /*static*/ port::Status MultiPlatformManager::RegisterListener(
296     std::unique_ptr<Listener> listener) {
297   return Impl().RegisterListener(std::move(listener));
298 }
299 
300 /*static*/ port::StatusOr<std::vector<Platform*>>
PlatformsWithFilter(const std::function<bool (const Platform *)> & filter)301 MultiPlatformManager::PlatformsWithFilter(
302     const std::function<bool(const Platform*)>& filter) {
303   return PlatformsWithFilter(filter, /*initialize_platform=*/true);
304 }
305 
306 /*static*/ port::StatusOr<std::vector<Platform*>>
PlatformsWithFilter(const std::function<bool (const Platform *)> & filter,bool initialize_platform)307 MultiPlatformManager::PlatformsWithFilter(
308     const std::function<bool(const Platform*)>& filter,
309     bool initialize_platform) {
310   return Impl().PlatformsWithFilter(filter, initialize_platform);
311 }
312 
313 }  // namespace stream_executor
314 
315 REGISTER_MODULE_INITIALIZER(
316     multi_platform_manager,
317     {
318         // Nothing -- this is just a module initializer
319         // definition to reference for sequencing
320         // purposes from Platform subclasses that register
321         // themselves with the MultiPlatformManager.
322     });
323 
324 REGISTER_MODULE_INITIALIZER(
325     multi_platform_manager_listener,
326     {
327         // Nothing -- this is just a module initializer definition to reference
328         // for sequencing registration of listeners with the
329         // MultiPlatformManager.
330     });
331 
332 // Listener registration should happen before platform registration.
333 REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener,
334                                      multi_platform_manager);
335