1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/platform_util.h"
17
18 #include <algorithm>
19 #include <string>
20 #include <utility>
21
22 #include "absl/strings/ascii.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/compiler/xla/debug_options_flags.h"
25 #include "tensorflow/compiler/xla/service/compiler.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/core/lib/core/threadpool.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
33
34 namespace xla {
35
// Minimum supported CUDA compute capability is 3.5. IsDeviceSupported()
// rejects devices on the CUDA platform that report a lower (major, minor)
// pair than the one given by these two constants.
constexpr int kMinCudaComputeCapabilityMajor = 3;
constexpr int kMinCudaComputeCapabilityMinor = 5;

// The name of the interpreter platform. GetDefaultPlatform() compares it
// against lower-cased platform names, so it must stay lower case.
constexpr char kInterpreter[] = "interpreter";
42
43 namespace {
44
CanonicalPlatformName(const string & name)45 string CanonicalPlatformName(const string& name) {
46 string platform_str = absl::AsciiStrToLower(name);
47 // "cpu" and "host" mean the same thing.
48 if (platform_str == "cpu") {
49 platform_str = "host";
50 }
51 // "gpu" and "cuda" mean the same thing.
52 if (platform_str == "gpu") {
53 platform_str = "cuda";
54 }
55 return platform_str;
56 }
57
58 } // namespace
59
60 /* static */ StatusOr<std::vector<se::Platform*>>
GetSupportedPlatforms()61 PlatformUtil::GetSupportedPlatforms() {
62 std::vector<se::Platform*> all_platforms =
63 se::MultiPlatformManager::AllPlatforms();
64 if (all_platforms.empty()) {
65 LOG(WARNING) << "no executor platforms available: platform map is empty";
66 }
67
68 // Gather all platforms which have an XLA compiler.
69 std::vector<se::Platform*> platforms;
70 for (se::Platform* platform : all_platforms) {
71 auto compiler_status = Compiler::GetForPlatform(platform);
72 if (compiler_status.ok()) {
73 if (!platform->Initialized()) {
74 TF_RETURN_IF_ERROR(platform->Initialize({}));
75 }
76 platforms.push_back(platform);
77 } else {
78 LOG(INFO) << "platform " << platform->Name() << " present but no "
79 << "XLA compiler available: "
80 << compiler_status.status().error_message();
81 }
82 }
83 return platforms;
84 }
85
GetSolePlatform()86 /* static */ StatusOr<se::Platform*> PlatformUtil::GetSolePlatform() {
87 TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms());
88 if (platforms.empty()) {
89 return NotFound("no platforms found");
90 } else if (platforms.size() == 1) {
91 se::Platform* platform = platforms[0];
92 if (!platform->Initialized()) {
93 TF_RETURN_IF_ERROR(platform->Initialize({}));
94 }
95 return platform;
96 }
97
98 // Multiple platforms present and we can't pick a reasonable default.
99 string platforms_string = absl::StrJoin(
100 platforms, ", ",
101 [](string* out, const se::Platform* p) { out->append(p->Name()); });
102 return InvalidArgument(
103 "must specify platform because more than one platform found: %s",
104 platforms_string);
105 }
106
GetDefaultPlatform()107 /* static */ StatusOr<se::Platform*> PlatformUtil::GetDefaultPlatform() {
108 TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms());
109
110 se::Platform* platform = nullptr;
111 if (platforms.empty()) {
112 return NotFound("no platforms found");
113 } else if (platforms.size() == 1) {
114 platform = platforms[0];
115 } else if (platforms.size() == 2) {
116 for (int i = 0; i < 2; i++) {
117 if (absl::AsciiStrToLower(platforms[i]->Name()) == kInterpreter &&
118 absl::AsciiStrToLower(platforms[1 - i]->Name()) != kInterpreter) {
119 platform = platforms[1 - i];
120 break;
121 }
122 }
123 }
124 if (platform != nullptr) {
125 if (!platform->Initialized()) {
126 TF_RETURN_IF_ERROR(platform->Initialize({}));
127 }
128 return platform;
129 }
130
131 // Multiple platforms present and we can't pick a reasonable default.
132 string platforms_string = absl::StrJoin(
133 platforms, ", ",
134 [](string* out, const se::Platform* p) { out->append(p->Name()); });
135 return InvalidArgument(
136 "must specify platform because more than one platform (except for the "
137 "interpreter platform) found: %s",
138 platforms_string);
139 }
140
GetPlatform(const string & platform_name)141 /*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatform(
142 const string& platform_name) {
143 string platform_str = CanonicalPlatformName(platform_name);
144 TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
145 for (se::Platform* platform : platforms) {
146 if (absl::AsciiStrToLower(platform->Name()) == platform_str) {
147 if (!platform->Initialized()) {
148 TF_RETURN_IF_ERROR(platform->Initialize({}));
149 }
150 return platform;
151 }
152 }
153 return InvalidArgument("platform %s not found", platform_name);
154 }
155
GetPlatformExceptFor(const string & platform_name)156 /*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatformExceptFor(
157 const string& platform_name) {
158 string platform_str = CanonicalPlatformName(platform_name);
159
160 TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
161 std::vector<se::Platform*> matched;
162 for (se::Platform* platform : platforms) {
163 if (absl::AsciiStrToLower(platform->Name()) != platform_name) {
164 matched.push_back(platform);
165 }
166 }
167 if (matched.empty()) {
168 return InvalidArgument("unable to find platform that is not %s",
169 platform_name);
170 }
171 if (matched.size() == 1) {
172 auto platform = matched[0];
173 if (!platform->Initialized()) {
174 TF_RETURN_IF_ERROR(platform->Initialize({}));
175 }
176 return platform;
177 }
178 string matched_string = absl::StrJoin(
179 matched, ", ",
180 [](string* out, const se::Platform* p) { out->append(p->Name()); });
181 return InvalidArgument(
182 "found multiple platforms %s, but expected one platform except for %s",
183 matched_string, platform_name);
184 }
185
186 // Returns whether the device underlying the given StreamExecutor is supported
187 // by XLA.
IsDeviceSupported(se::StreamExecutor * executor)188 static bool IsDeviceSupported(se::StreamExecutor* executor) {
189 const auto& description = executor->GetDeviceDescription();
190 if (executor->platform()->id() == se::cuda::kCudaPlatformId) {
191 // CUDA devices must have a minimum compute capability.
192 int major_version, minor_version;
193 if (description.cuda_compute_capability(&major_version, &minor_version)) {
194 if (major_version < kMinCudaComputeCapabilityMajor ||
195 (major_version == kMinCudaComputeCapabilityMajor &&
196 minor_version < kMinCudaComputeCapabilityMinor)) {
197 LOG(INFO) << "StreamExecutor cuda device ("
198 << executor->device_ordinal() << ") is of "
199 << "insufficient compute capability: "
200 << kMinCudaComputeCapabilityMajor << "."
201 << kMinCudaComputeCapabilityMinor << " required, "
202 << "device is " << major_version << "." << minor_version;
203 return false;
204 }
205 }
206 }
207 return true;
208 }
209
210 /* static */ StatusOr<std::vector<se::StreamExecutor*>>
GetStreamExecutors(se::Platform * platform,const absl::optional<std::set<int>> & allowed_devices)211 PlatformUtil::GetStreamExecutors(
212 se::Platform* platform,
213 const absl::optional<std::set<int>>& allowed_devices) {
214 int device_count = platform->VisibleDeviceCount();
215 if (device_count <= 0) {
216 return NotFound("no %s devices found", platform->Name());
217 }
218 if (platform->id() == se::host::kHostPlatformId) {
219 // On host "devices", StreamExecutor exports a device for each hardware
220 // thread. Because we parallelize a single computation across threads, it
221 // doesn't make sense to expose these as separate devices, so by default we
222 // fix the number of devices to one. However we do let the user override
223 // this behavior to help run tests on the host that run models in parallel
224 // across multiple devices.
225 device_count =
226 GetDebugOptionsFromFlags().xla_force_host_platform_device_count();
227 }
228 std::vector<se::StreamExecutor*> stream_executors(device_count, nullptr);
229 VLOG(1) << "Initializing devices";
230 {
231 tensorflow::thread::ThreadPool thread_pool(
232 tensorflow::Env::Default(), "device_initialization", device_count);
233 for (int i = 0; i < device_count; ++i) {
234 // Once a stream executor is instantiated it will cause allocations on
235 // the device, for example for GPUs cuda context, cudnn handles etc. will
236 // be constructed. By constructing stream executors only on the
237 // allowed_devices, we don't make any allocations on other devices.
238 // This helps in multi-process executions on the same host like horovod or
239 // shared hosts.
240 if (allowed_devices && allowed_devices->count(i) == 0) {
241 VLOG(1) << "Not initializing StreamExecutor for device " << i
242 << " since it is not in the visible device list";
243 continue;
244 }
245 thread_pool.Schedule([platform, i, &stream_executors]() {
246 VLOG(1) << "Started device init " << i;
247 se::StreamExecutorConfig config;
248 config.ordinal = i;
249 auto executor_status = platform->GetExecutor(config);
250 if (executor_status.ok()) {
251 se::StreamExecutor* executor = executor_status.ValueOrDie();
252 if (IsDeviceSupported(executor)) {
253 stream_executors[i] = executor;
254 }
255 } else {
256 LOG(WARNING) << "unable to create StreamExecutor for "
257 << platform->Name() << ":" << i << ": "
258 << executor_status.status().error_message();
259 }
260 VLOG(1) << "Finished device init " << i;
261 });
262 }
263 // Block here in thread_pool destructor until all devices are initialized.
264 }
265 VLOG(1) << "Device initialization complete";
266 if (absl::c_all_of(stream_executors,
267 [](se::StreamExecutor* s) { return s == nullptr; })) {
268 return InternalError("no supported devices found for platform %s",
269 platform->Name());
270 }
271 return stream_executors;
272 }
273
274 } // namespace xla
275