• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/stream_executor/multi_platform_manager.h"
17 
18 #include "absl/base/thread_annotations.h"
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/strings/string_view.h"
21 #include "absl/synchronization/mutex.h"
22 #include "tensorflow/stream_executor/lib/error.h"
23 #include "tensorflow/stream_executor/lib/initialize.h"
24 #include "tensorflow/stream_executor/lib/str_util.h"
25 #include "tensorflow/stream_executor/lib/stringprintf.h"
26 
27 namespace stream_executor {
28 namespace {
29 
// Process-wide registry of StreamExecutor platforms, indexed both by
// lowercased platform name and by platform id. All mutable state is guarded
// by mu_. There is no removal API: once registered, a platform stays in both
// maps for the life of the process.
class MultiPlatformManagerImpl {
 public:
  // Adds `platform` to both indices and notifies every registered listener.
  // Returns an error status if the platform's (lowercased) name is taken.
  port::Status RegisterPlatform(std::unique_ptr<Platform> platform)
      LOCKS_EXCLUDED(mu_);

  // Returns the platform whose lowercased name equals lowercased `target`,
  // initializing it with empty options first if it is not yet initialized.
  port::StatusOr<Platform*> PlatformWithName(absl::string_view target)
      LOCKS_EXCLUDED(mu_);

  // Same as PlatformWithName, but looks the platform up by id.
  port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id)
      LOCKS_EXCLUDED(mu_);

  // Initializes the named platform with `options`. Fails with
  // FAILED_PRECONDITION if the platform is already initialized.
  port::StatusOr<Platform*> InitializePlatformWithName(
      absl::string_view target, const std::map<string, string>& options)
      LOCKS_EXCLUDED(mu_);
  // Same as InitializePlatformWithName, but looks the platform up by id.
  port::StatusOr<Platform*> InitializePlatformWithId(
      const Platform::Id& id, const std::map<string, string>& options)
      LOCKS_EXCLUDED(mu_);

  // Returns every registered platform; the order is unspecified (hash map
  // iteration order).
  std::vector<Platform*> AllPlatforms() LOCKS_EXCLUDED(mu_);

  using Listener = MultiPlatformManager::Listener;
  // Takes ownership of `listener`, which is notified of every subsequent
  // platform registration. CHECK-fails if any platform is already registered.
  port::Status RegisterListener(std::unique_ptr<Listener> listener)
      LOCKS_EXCLUDED(mu_);

 private:
  // Looks up the platform object with the given name.  Assumes the Platforms
  // mutex is held.
  port::StatusOr<Platform*> LookupByNameLocked(absl::string_view target)
      EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Looks up the platform object with the given id.  Assumes the Platforms
  // mutex is held.
  port::StatusOr<Platform*> LookupByIdLocked(const Platform::Id& id)
      EXCLUSIVE_LOCKS_REQUIRED(mu_);

  absl::Mutex mu_;
  // Owned listeners, notified in insertion order on each registration.
  std::vector<std::unique_ptr<Listener>> listeners_ GUARDED_BY(mu_);
  // Non-owning pointers; RegisterPlatform deliberately leaks the platforms.
  absl::flat_hash_map<Platform::Id, Platform*> id_map_ GUARDED_BY(mu_);
  // Keyed by port::Lowercase(platform->Name()).
  absl::flat_hash_map<string, Platform*> name_map_ GUARDED_BY(mu_);
};
70 
RegisterPlatform(std::unique_ptr<Platform> platform)71 port::Status MultiPlatformManagerImpl::RegisterPlatform(
72     std::unique_ptr<Platform> platform) {
73   CHECK(platform != nullptr);
74   string key = port::Lowercase(platform->Name());
75   absl::MutexLock lock(&mu_);
76   if (name_map_.find(key) != name_map_.end()) {
77     return port::Status(port::error::INTERNAL,
78                         "platform is already registered with name: \"" +
79                             platform->Name() + "\"");
80   }
81   Platform* platform_ptr = platform.get();
82   CHECK(id_map_.emplace(platform->id(), platform_ptr).second);
83   // Release ownership/uniqueness to prevent destruction on program exit.
84   // This avoids Platforms "cleaning up" on program exit, because otherwise,
85   // there are _very_ tricky races between StreamExecutor and underlying
86   // platforms (CUDA, OpenCL) during exit. Since these are fixed-size and 1x per
87   // program, these are deemed acceptable.
88   name_map_[key] = platform.release();
89   for (const auto& listener : listeners_) {
90     listener->PlatformRegistered(platform_ptr);
91   }
92   return port::Status::OK();
93 }
94 
PlatformWithName(absl::string_view target)95 port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithName(
96     absl::string_view target) {
97   absl::MutexLock lock(&mu_);
98 
99   SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
100   if (!platform->Initialized()) {
101     SE_RETURN_IF_ERROR(platform->Initialize({}));
102   }
103 
104   return platform;
105 }
106 
PlatformWithId(const Platform::Id & id)107 port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithId(
108     const Platform::Id& id) {
109   absl::MutexLock lock(&mu_);
110 
111   SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
112   if (!platform->Initialized()) {
113     SE_RETURN_IF_ERROR(platform->Initialize({}));
114   }
115 
116   return platform;
117 }
118 
InitializePlatformWithName(absl::string_view target,const std::map<string,string> & options)119 port::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithName(
120     absl::string_view target, const std::map<string, string>& options) {
121   absl::MutexLock lock(&mu_);
122 
123   SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
124   if (platform->Initialized()) {
125     return port::Status(
126         port::error::FAILED_PRECONDITION,
127         absl::StrCat("platform \"", target, "\" is already initialized"));
128   }
129 
130   SE_RETURN_IF_ERROR(platform->Initialize(options));
131 
132   return platform;
133 }
134 
InitializePlatformWithId(const Platform::Id & id,const std::map<string,string> & options)135 port::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithId(
136     const Platform::Id& id, const std::map<string, string>& options) {
137   absl::MutexLock lock(&mu_);
138 
139   SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
140   if (platform->Initialized()) {
141     return port::Status(
142         port::error::FAILED_PRECONDITION,
143         port::Printf("platform with id 0x%p is already initialized", id));
144   }
145 
146   SE_RETURN_IF_ERROR(platform->Initialize(options));
147 
148   return platform;
149 }
150 
RegisterListener(std::unique_ptr<Listener> listener)151 port::Status MultiPlatformManagerImpl::RegisterListener(
152     std::unique_ptr<Listener> listener) {
153   absl::MutexLock lock(&mu_);
154   CHECK(id_map_.empty());
155   CHECK(name_map_.empty());
156   listeners_.push_back(std::move(listener));
157   return port::Status::OK();
158 }
159 
AllPlatforms()160 std::vector<Platform*> MultiPlatformManagerImpl::AllPlatforms() {
161   absl::MutexLock lock(&mu_);
162   CHECK_EQ(id_map_.size(), name_map_.size());
163   std::vector<Platform*> platforms;
164   platforms.reserve(id_map_.size());
165   for (const auto& entry : id_map_) {
166     platforms.push_back(entry.second);
167   }
168   return platforms;
169 }
170 
LookupByNameLocked(absl::string_view target)171 port::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByNameLocked(
172     absl::string_view target) {
173   auto it = name_map_.find(port::Lowercase(target));
174   if (it == name_map_.end()) {
175     return port::Status(
176         port::error::NOT_FOUND,
177         absl::StrCat("Could not find registered platform with name: \"", target,
178                      "\""));
179   }
180   return it->second;
181 }
182 
LookupByIdLocked(const Platform::Id & id)183 port::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByIdLocked(
184     const Platform::Id& id) {
185   auto it = id_map_.find(id);
186   if (it == id_map_.end()) {
187     return port::Status(
188         port::error::NOT_FOUND,
189         port::Printf("could not find registered platform with id: 0x%p", id));
190   }
191   return it->second;
192 }
193 
Impl()194 MultiPlatformManagerImpl& Impl() {
195   static MultiPlatformManagerImpl* impl = new MultiPlatformManagerImpl;
196   return *impl;
197 }
198 
199 }  // namespace
200 
RegisterPlatform(std::unique_ptr<Platform> platform)201 /*static*/ port::Status MultiPlatformManager::RegisterPlatform(
202     std::unique_ptr<Platform> platform) {
203   return Impl().RegisterPlatform(std::move(platform));
204 }
205 
PlatformWithName(absl::string_view target)206 /*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithName(
207     absl::string_view target) {
208   return Impl().PlatformWithName(target);
209 }
210 
PlatformWithId(const Platform::Id & id)211 /*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithId(
212     const Platform::Id& id) {
213   return Impl().PlatformWithId(id);
214 }
215 
216 /*static*/ port::StatusOr<Platform*>
InitializePlatformWithName(absl::string_view target,const std::map<string,string> & options)217 MultiPlatformManager::InitializePlatformWithName(
218     absl::string_view target, const std::map<string, string>& options) {
219   return Impl().InitializePlatformWithName(target, options);
220 }
221 
222 /*static*/ port::StatusOr<Platform*>
InitializePlatformWithId(const Platform::Id & id,const std::map<string,string> & options)223 MultiPlatformManager::InitializePlatformWithId(
224     const Platform::Id& id, const std::map<string, string>& options) {
225   return Impl().InitializePlatformWithId(id, options);
226 }
227 
RegisterListener(std::unique_ptr<Listener> listener)228 /*static*/ port::Status MultiPlatformManager::RegisterListener(
229     std::unique_ptr<Listener> listener) {
230   return Impl().RegisterListener(std::move(listener));
231 }
232 
AllPlatforms()233 /*static*/ std::vector<Platform*> MultiPlatformManager::AllPlatforms() {
234   return Impl().AllPlatforms();
235 }
236 
237 }  // namespace stream_executor
238 
// Empty initializer whose name Platform implementations can reference to
// order their own self-registration relative to this module.
REGISTER_MODULE_INITIALIZER(
    multi_platform_manager,
    {
        // Nothing -- this is just a module initializer
        // definition to reference for sequencing
        // purposes from Platform subclasses that register
        // themselves with the MultiPlatformManager.
    });

// Empty initializer that listener registrations can reference for ordering.
REGISTER_MODULE_INITIALIZER(
    multi_platform_manager_listener,
    {
        // Nothing -- this is just a module initializer definition to reference
        // for sequencing registration of listeners with the
        // MultiPlatformManager.
    });

// Listener registration should happen before platform registration.
// This matters because MultiPlatformManagerImpl::RegisterListener CHECK-fails
// once any platform is registered. Presumably this macro runs the first
// initializer before the second; confirm against the initialize.h contract.
REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener,
                                     multi_platform_manager);
259