/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/platform_util.h"

#include <algorithm>
#include <string>
#include <utility>

#include "absl/algorithm/container.h"  // for absl::c_all_of
#include "absl/strings/ascii.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {

// Minimum supported CUDA compute capability is 3.5.
constexpr int kMinCudaComputeCapabilityMajor = 3;
constexpr int kMinCudaComputeCapabilityMinor = 5;

// The name of the interpreter platform.
constexpr char kInterpreter[] = "interpreter";

namespace {

string CanonicalPlatformName(const string& name) {
  string platform_str = absl::AsciiStrToLower(name);
  // "cpu" and "host" mean the same thing.
  if (platform_str == "cpu") {
    platform_str = "host";
  }
  // "gpu" and "cuda" mean the same thing.
  if (platform_str == "gpu") {
    platform_str = "cuda";
  }
  return platform_str;
}
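
// Illustrative sketch (not part of the original file): how the helper above
// maps user-facing aliases onto canonical platform names. The inputs are
// assumed examples.
//
//   CanonicalPlatformName("CPU");   // -> "host"
//   CanonicalPlatformName("gpu");   // -> "cuda"
//   CanonicalPlatformName("host");  // -> "host" (already canonical)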

}  // namespace

/* static */ StatusOr<std::vector<se::Platform*>>
PlatformUtil::GetSupportedPlatforms() {
  std::vector<se::Platform*> all_platforms =
      se::MultiPlatformManager::AllPlatforms();
  if (all_platforms.empty()) {
    LOG(WARNING) << "no executor platforms available: platform map is empty";
  }

  // Gather all platforms which have an XLA compiler.
  std::vector<se::Platform*> platforms;
  for (se::Platform* platform : all_platforms) {
    auto compiler_status = Compiler::GetForPlatform(platform);
    if (compiler_status.ok()) {
      if (!platform->Initialized()) {
        TF_RETURN_IF_ERROR(platform->Initialize({}));
      }
      platforms.push_back(platform);
    } else {
      LOG(INFO) << "platform " << platform->Name() << " present but no "
                << "XLA compiler available: "
                << compiler_status.status().error_message();
    }
  }
  return platforms;
}
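
// Example usage (a minimal sketch, not part of the original file; it assumes
// a caller that itself returns a Status or StatusOr so TF_ASSIGN_OR_RETURN
// applies):
//
//   TF_ASSIGN_OR_RETURN(std::vector<se::Platform*> platforms,
//                       PlatformUtil::GetSupportedPlatforms());
//   for (se::Platform* platform : platforms) {
//     LOG(INFO) << "XLA-capable platform: " << platform->Name();
//   }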

/* static */ StatusOr<se::Platform*> PlatformUtil::GetSolePlatform() {
  TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms());
  if (platforms.empty()) {
    return NotFound("no platforms found");
  } else if (platforms.size() == 1) {
    se::Platform* platform = platforms[0];
    if (!platform->Initialized()) {
      TF_RETURN_IF_ERROR(platform->Initialize({}));
    }
    return platform;
  }

  // Multiple platforms present and we can't pick a reasonable default.
  string platforms_string = absl::StrJoin(
      platforms, ", ",
      [](string* out, const se::Platform* p) { out->append(p->Name()); });
  return InvalidArgument(
      "must specify platform because more than one platform found: %s",
      platforms_string);
}
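
// Example usage (an illustrative sketch, not part of the original file): a
// caller that requires exactly one XLA-capable platform can inspect the
// StatusOr directly.
//
//   StatusOr<se::Platform*> platform_or = PlatformUtil::GetSolePlatform();
//   if (!platform_or.ok()) {
//     // Zero or multiple platforms are present; fall back to
//     // GetDefaultPlatform() or ask the user to name one explicitly.
//   }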

/* static */ StatusOr<se::Platform*> PlatformUtil::GetDefaultPlatform() {
  TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms());

  se::Platform* platform = nullptr;
  if (platforms.empty()) {
    return NotFound("no platforms found");
  } else if (platforms.size() == 1) {
    platform = platforms[0];
  } else if (platforms.size() == 2) {
    for (int i = 0; i < 2; i++) {
      if (absl::AsciiStrToLower(platforms[i]->Name()) == kInterpreter &&
          absl::AsciiStrToLower(platforms[1 - i]->Name()) != kInterpreter) {
        platform = platforms[1 - i];
        break;
      }
    }
  }
  if (platform != nullptr) {
    if (!platform->Initialized()) {
      TF_RETURN_IF_ERROR(platform->Initialize({}));
    }
    return platform;
  }

  // Multiple platforms present and we can't pick a reasonable default.
  string platforms_string = absl::StrJoin(
      platforms, ", ",
      [](string* out, const se::Platform* p) { out->append(p->Name()); });
  return InvalidArgument(
      "must specify platform because more than one platform (except for the "
      "interpreter platform) found: %s",
      platforms_string);
}
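
// Behavior sketch for the tie-break above (the platform sets are assumed
// examples):
//   {host}              -> host
//   {interpreter, cuda} -> cuda  (the interpreter loses the tie-break)
//   {host, cuda}        -> error: the caller must name a platform explicitly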

/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatform(
    const string& platform_name) {
  string platform_str = CanonicalPlatformName(platform_name);
  TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
  for (se::Platform* platform : platforms) {
    if (absl::AsciiStrToLower(platform->Name()) == platform_str) {
      if (!platform->Initialized()) {
        TF_RETURN_IF_ERROR(platform->Initialize({}));
      }
      return platform;
    }
  }
  return InvalidArgument("platform %s not found", platform_name);
}
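
// Example usage (a minimal sketch): because of CanonicalPlatformName above,
// the user-facing aliases resolve to the underlying platforms.
//
//   PlatformUtil::GetPlatform("cpu");  // resolves to the "host" platform
//   PlatformUtil::GetPlatform("gpu");  // resolves to the "cuda" platform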

/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatformExceptFor(
    const string& platform_name) {
  string platform_str = CanonicalPlatformName(platform_name);

  TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
  std::vector<se::Platform*> matched;
  for (se::Platform* platform : platforms) {
    // Compare against the canonicalized name so aliases such as "cpu" and
    // "gpu" are excluded correctly.
    if (absl::AsciiStrToLower(platform->Name()) != platform_str) {
      matched.push_back(platform);
    }
  }
  if (matched.empty()) {
    return InvalidArgument("unable to find platform that is not %s",
                           platform_name);
  }
  if (matched.size() == 1) {
    auto platform = matched[0];
    if (!platform->Initialized()) {
      TF_RETURN_IF_ERROR(platform->Initialize({}));
    }
    return platform;
  }
  string matched_string = absl::StrJoin(
      matched, ", ",
      [](string* out, const se::Platform* p) { out->append(p->Name()); });
  return InvalidArgument(
      "found multiple platforms %s, but expected one platform except for %s",
      matched_string, platform_name);
}
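
// Illustrative sketch: with {interpreter, cuda} registered,
// GetPlatformExceptFor("interpreter") returns the cuda platform, and
// GetPlatformExceptFor("cuda") returns the interpreter platform.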

// Returns whether the device underlying the given StreamExecutor is supported
// by XLA.
static bool IsDeviceSupported(se::StreamExecutor* executor) {
  const auto& description = executor->GetDeviceDescription();
  if (executor->platform()->id() == se::cuda::kCudaPlatformId) {
    // CUDA devices must have a minimum compute capability.
    int major_version, minor_version;
    if (description.cuda_compute_capability(&major_version, &minor_version)) {
      if (major_version < kMinCudaComputeCapabilityMajor ||
          (major_version == kMinCudaComputeCapabilityMajor &&
           minor_version < kMinCudaComputeCapabilityMinor)) {
        LOG(INFO) << "StreamExecutor cuda device ("
                  << executor->device_ordinal() << ") is of "
                  << "insufficient compute capability: "
                  << kMinCudaComputeCapabilityMajor << "."
                  << kMinCudaComputeCapabilityMinor << " required, "
                  << "device is " << major_version << "." << minor_version;
        return false;
      }
    }
  }
  return true;
}
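
// Worked example of the compute-capability check above (the versions are
// assumed sample values):
//   2.0 -> rejected  (major 2 < 3)
//   3.0 -> rejected  (major ties at 3, minor 0 < 5)
//   3.5 -> supported
//   7.0 -> supported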

/* static */ StatusOr<std::vector<se::StreamExecutor*>>
PlatformUtil::GetStreamExecutors(
    se::Platform* platform,
    const absl::optional<std::set<int>>& allowed_devices) {
  int device_count = platform->VisibleDeviceCount();
  if (device_count <= 0) {
    return NotFound("no %s devices found", platform->Name());
  }
  if (platform->id() == se::host::kHostPlatformId) {
    // On host "devices", StreamExecutor exports a device for each hardware
    // thread. Because we parallelize a single computation across threads, it
    // doesn't make sense to expose these as separate devices, so by default we
    // fix the number of devices to one. However, we do let the user override
    // this behavior to help run tests on the host that run models in parallel
    // across multiple devices.
    device_count =
        GetDebugOptionsFromFlags().xla_force_host_platform_device_count();
  }
  std::vector<se::StreamExecutor*> stream_executors(device_count, nullptr);
  VLOG(1) << "Initializing devices";
  {
    tensorflow::thread::ThreadPool thread_pool(
        tensorflow::Env::Default(), "device_initialization", device_count);
    for (int i = 0; i < device_count; ++i) {
      // Once a stream executor is instantiated, it will cause allocations on
      // the device; for example, for GPUs the CUDA context, cuDNN handles,
      // etc. will be constructed. By constructing stream executors only for
      // the allowed_devices, we don't make any allocations on other devices.
      // This helps with multi-process executions on the same host (e.g.
      // Horovod) and with shared hosts.
      if (allowed_devices && allowed_devices->count(i) == 0) {
        VLOG(1) << "Not initializing StreamExecutor for device " << i
                << " since it is not in the visible device list";
        continue;
      }
      thread_pool.Schedule([platform, i, &stream_executors]() {
        VLOG(1) << "Started device init " << i;
        se::StreamExecutorConfig config;
        config.ordinal = i;
        auto executor_status = platform->GetExecutor(config);
        if (executor_status.ok()) {
          se::StreamExecutor* executor = executor_status.ValueOrDie();
          if (IsDeviceSupported(executor)) {
            stream_executors[i] = executor;
          }
        } else {
          LOG(WARNING) << "unable to create StreamExecutor for "
                       << platform->Name() << ":" << i << ": "
                       << executor_status.status().error_message();
        }
        VLOG(1) << "Finished device init " << i;
      });
    }
    // Block here in the thread_pool destructor until all devices are
    // initialized.
  }
  VLOG(1) << "Device initialization complete";
  if (absl::c_all_of(stream_executors,
                     [](se::StreamExecutor* s) { return s == nullptr; })) {
    return InternalError("no supported devices found for platform %s",
                         platform->Name());
  }
  return stream_executors;
}
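
// Example usage combining the helpers in this file (an illustrative sketch;
// the caller is assumed to return a Status or StatusOr):
//
//   TF_ASSIGN_OR_RETURN(se::Platform* platform,
//                       PlatformUtil::GetDefaultPlatform());
//   TF_ASSIGN_OR_RETURN(std::vector<se::StreamExecutor*> executors,
//                       PlatformUtil::GetStreamExecutors(
//                           platform, /*allowed_devices=*/absl::nullopt));
//   // Entries may be nullptr for devices that exist but are not supported.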

}  // namespace xla