1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/stream_executor/multi_platform_manager.h"
17
18 #include "absl/base/thread_annotations.h"
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/strings/string_view.h"
21 #include "absl/synchronization/mutex.h"
22 #include "tensorflow/stream_executor/lib/error.h"
23 #include "tensorflow/stream_executor/lib/initialize.h"
24 #include "tensorflow/stream_executor/lib/str_util.h"
25 #include "tensorflow/stream_executor/lib/stringprintf.h"
26
27 namespace stream_executor {
28 namespace {
29
30 class MultiPlatformManagerImpl {
31 public:
32 port::Status RegisterPlatform(std::unique_ptr<Platform> platform)
33 LOCKS_EXCLUDED(mu_);
34
35 port::StatusOr<Platform*> PlatformWithName(absl::string_view target)
36 LOCKS_EXCLUDED(mu_);
37
38 port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id)
39 LOCKS_EXCLUDED(mu_);
40
41 port::StatusOr<Platform*> InitializePlatformWithName(
42 absl::string_view target, const std::map<string, string>& options)
43 LOCKS_EXCLUDED(mu_);
44 port::StatusOr<Platform*> InitializePlatformWithId(
45 const Platform::Id& id, const std::map<string, string>& options)
46 LOCKS_EXCLUDED(mu_);
47
48 std::vector<Platform*> AllPlatforms() LOCKS_EXCLUDED(mu_);
49
50 using Listener = MultiPlatformManager::Listener;
51 port::Status RegisterListener(std::unique_ptr<Listener> listener)
52 LOCKS_EXCLUDED(mu_);
53
54 private:
55 // Looks up the platform object with the given name. Assumes the Platforms
56 // mutex is held.
57 port::StatusOr<Platform*> LookupByNameLocked(absl::string_view target)
58 EXCLUSIVE_LOCKS_REQUIRED(mu_);
59
60 // Looks up the platform object with the given id. Assumes the Platforms
61 // mutex is held.
62 port::StatusOr<Platform*> LookupByIdLocked(const Platform::Id& id)
63 EXCLUSIVE_LOCKS_REQUIRED(mu_);
64
65 absl::Mutex mu_;
66 std::vector<std::unique_ptr<Listener>> listeners_ GUARDED_BY(mu_);
67 absl::flat_hash_map<Platform::Id, Platform*> id_map_ GUARDED_BY(mu_);
68 absl::flat_hash_map<string, Platform*> name_map_ GUARDED_BY(mu_);
69 };
70
RegisterPlatform(std::unique_ptr<Platform> platform)71 port::Status MultiPlatformManagerImpl::RegisterPlatform(
72 std::unique_ptr<Platform> platform) {
73 CHECK(platform != nullptr);
74 string key = port::Lowercase(platform->Name());
75 absl::MutexLock lock(&mu_);
76 if (name_map_.find(key) != name_map_.end()) {
77 return port::Status(port::error::INTERNAL,
78 "platform is already registered with name: \"" +
79 platform->Name() + "\"");
80 }
81 Platform* platform_ptr = platform.get();
82 CHECK(id_map_.emplace(platform->id(), platform_ptr).second);
83 // Release ownership/uniqueness to prevent destruction on program exit.
84 // This avoids Platforms "cleaning up" on program exit, because otherwise,
85 // there are _very_ tricky races between StreamExecutor and underlying
86 // platforms (CUDA, OpenCL) during exit. Since these are fixed-size and 1x per
87 // program, these are deemed acceptable.
88 name_map_[key] = platform.release();
89 for (const auto& listener : listeners_) {
90 listener->PlatformRegistered(platform_ptr);
91 }
92 return port::Status::OK();
93 }
94
PlatformWithName(absl::string_view target)95 port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithName(
96 absl::string_view target) {
97 absl::MutexLock lock(&mu_);
98
99 SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
100 if (!platform->Initialized()) {
101 SE_RETURN_IF_ERROR(platform->Initialize({}));
102 }
103
104 return platform;
105 }
106
PlatformWithId(const Platform::Id & id)107 port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithId(
108 const Platform::Id& id) {
109 absl::MutexLock lock(&mu_);
110
111 SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
112 if (!platform->Initialized()) {
113 SE_RETURN_IF_ERROR(platform->Initialize({}));
114 }
115
116 return platform;
117 }
118
InitializePlatformWithName(absl::string_view target,const std::map<string,string> & options)119 port::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithName(
120 absl::string_view target, const std::map<string, string>& options) {
121 absl::MutexLock lock(&mu_);
122
123 SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
124 if (platform->Initialized()) {
125 return port::Status(
126 port::error::FAILED_PRECONDITION,
127 absl::StrCat("platform \"", target, "\" is already initialized"));
128 }
129
130 SE_RETURN_IF_ERROR(platform->Initialize(options));
131
132 return platform;
133 }
134
InitializePlatformWithId(const Platform::Id & id,const std::map<string,string> & options)135 port::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithId(
136 const Platform::Id& id, const std::map<string, string>& options) {
137 absl::MutexLock lock(&mu_);
138
139 SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
140 if (platform->Initialized()) {
141 return port::Status(
142 port::error::FAILED_PRECONDITION,
143 port::Printf("platform with id 0x%p is already initialized", id));
144 }
145
146 SE_RETURN_IF_ERROR(platform->Initialize(options));
147
148 return platform;
149 }
150
RegisterListener(std::unique_ptr<Listener> listener)151 port::Status MultiPlatformManagerImpl::RegisterListener(
152 std::unique_ptr<Listener> listener) {
153 absl::MutexLock lock(&mu_);
154 CHECK(id_map_.empty());
155 CHECK(name_map_.empty());
156 listeners_.push_back(std::move(listener));
157 return port::Status::OK();
158 }
159
AllPlatforms()160 std::vector<Platform*> MultiPlatformManagerImpl::AllPlatforms() {
161 absl::MutexLock lock(&mu_);
162 CHECK_EQ(id_map_.size(), name_map_.size());
163 std::vector<Platform*> platforms;
164 platforms.reserve(id_map_.size());
165 for (const auto& entry : id_map_) {
166 platforms.push_back(entry.second);
167 }
168 return platforms;
169 }
170
LookupByNameLocked(absl::string_view target)171 port::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByNameLocked(
172 absl::string_view target) {
173 auto it = name_map_.find(port::Lowercase(target));
174 if (it == name_map_.end()) {
175 return port::Status(
176 port::error::NOT_FOUND,
177 absl::StrCat("Could not find registered platform with name: \"", target,
178 "\""));
179 }
180 return it->second;
181 }
182
LookupByIdLocked(const Platform::Id & id)183 port::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByIdLocked(
184 const Platform::Id& id) {
185 auto it = id_map_.find(id);
186 if (it == id_map_.end()) {
187 return port::Status(
188 port::error::NOT_FOUND,
189 port::Printf("could not find registered platform with id: 0x%p", id));
190 }
191 return it->second;
192 }
193
Impl()194 MultiPlatformManagerImpl& Impl() {
195 static MultiPlatformManagerImpl* impl = new MultiPlatformManagerImpl;
196 return *impl;
197 }
198
199 } // namespace
200
RegisterPlatform(std::unique_ptr<Platform> platform)201 /*static*/ port::Status MultiPlatformManager::RegisterPlatform(
202 std::unique_ptr<Platform> platform) {
203 return Impl().RegisterPlatform(std::move(platform));
204 }
205
PlatformWithName(absl::string_view target)206 /*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithName(
207 absl::string_view target) {
208 return Impl().PlatformWithName(target);
209 }
210
PlatformWithId(const Platform::Id & id)211 /*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithId(
212 const Platform::Id& id) {
213 return Impl().PlatformWithId(id);
214 }
215
216 /*static*/ port::StatusOr<Platform*>
InitializePlatformWithName(absl::string_view target,const std::map<string,string> & options)217 MultiPlatformManager::InitializePlatformWithName(
218 absl::string_view target, const std::map<string, string>& options) {
219 return Impl().InitializePlatformWithName(target, options);
220 }
221
222 /*static*/ port::StatusOr<Platform*>
InitializePlatformWithId(const Platform::Id & id,const std::map<string,string> & options)223 MultiPlatformManager::InitializePlatformWithId(
224 const Platform::Id& id, const std::map<string, string>& options) {
225 return Impl().InitializePlatformWithId(id, options);
226 }
227
RegisterListener(std::unique_ptr<Listener> listener)228 /*static*/ port::Status MultiPlatformManager::RegisterListener(
229 std::unique_ptr<Listener> listener) {
230 return Impl().RegisterListener(std::move(listener));
231 }
232
AllPlatforms()233 /*static*/ std::vector<Platform*> MultiPlatformManager::AllPlatforms() {
234 return Impl().AllPlatforms();
235 }
236
237 } // namespace stream_executor
238
// Empty module initializer that exists only as a named sequencing anchor.
// Platform subclasses that register themselves with the MultiPlatformManager
// can reference it to order their own registration relative to this module.
REGISTER_MODULE_INITIALIZER(
    multi_platform_manager,
    {
        // Intentionally empty: this initializer only provides a name to
        // sequence against; it performs no work itself.
    });

// Empty module initializer that serves as the sequencing anchor for code that
// registers MultiPlatformManager listeners.
REGISTER_MODULE_INITIALIZER(
    multi_platform_manager_listener,
    {
        // Intentionally empty: this initializer only provides a name to
        // sequence against; it performs no work itself.
    });

// Listener registration should happen before platform registration:
// MultiPlatformManagerImpl::RegisterListener CHECK-fails once any platform has
// been registered. NOTE(review): this relies on the first argument of
// REGISTER_MODULE_INITIALIZER_SEQUENCE running before the second -- confirm
// against lib/initialize.h.
REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener,
                                     multi_platform_manager);
259