/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_platform_info.h"

#include "tensorflow/compiler/xla/client/client_library.h"

namespace tensorflow {

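// Parses the GPU `visible_device_list` string from the session's GPUOptions
// into a set of platform device ids. Returns absl::nullopt when the list is
// empty (all devices visible), or an InvalidArgument error when an entry
// cannot be parsed as an integer.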
xla::StatusOr<absl::optional<std::set<int>>> ParseVisibleDeviceList(
    absl::string_view visible_device_list) {
  std::set<int> gpu_ids;
  if (visible_device_list.empty()) {
    return {{absl::nullopt}};
  }
  const std::vector<string> visible_devices =
      absl::StrSplit(visible_device_list, ',');
  for (const string& platform_device_id_str : visible_devices) {
    int32_t platform_device_id;
    if (!absl::SimpleAtoi(platform_device_id_str, &platform_device_id)) {
      return errors::InvalidArgument(
          "Could not parse entry in 'visible_device_list': '",
          platform_device_id_str,
          "'. visible_device_list = ", visible_device_list);
    }
    gpu_ids.insert(platform_device_id);
  }
  return {{gpu_ids}};
}

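// Builds an XlaCompilationCache for `device`. On an XlaDevice the cache reuses
// the client and JIT device type recorded in the device metadata; otherwise a
// local XLA client is created (or fetched) for the device's platform,
// restricted to the GPUs named in the session's visible_device_list.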
Status BuildXlaCompilationCache(DeviceBase* device, FunctionLibraryRuntime* flr,
                                const XlaPlatformInfo& platform_info,
                                XlaCompilationCache** cache) {
  if (platform_info.xla_device_metadata()) {
    *cache = new XlaCompilationCache(
        platform_info.xla_device_metadata()->client(),
        platform_info.xla_device_metadata()->jit_device_type());
    return Status::OK();
  }

  auto platform =
      se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
  if (!platform.ok()) {
    return platform.status();
  }

  StatusOr<xla::Compiler*> compiler_for_platform =
      xla::Compiler::GetForPlatform(platform.ValueOrDie());
  if (!compiler_for_platform.ok()) {
    // In some rare cases (usually in unit tests with very small clusters) we
    // may end up transforming an XLA cluster with at least one GPU operation
    // (which would normally force the cluster to be compiled using XLA:GPU)
    // into an XLA cluster with no GPU operations (i.e. containing only CPU
    // operations). Such a cluster can fail compilation (in a way that
    // MarkForCompilation could not have detected) if the CPU JIT is not linked
    // in.
    //
    // So bail out of _XlaCompile in this case, and let the executor handle the
    // situation for us.
    const Status& status = compiler_for_platform.status();
    if (status.code() == error::NOT_FOUND) {
      return errors::Unimplemented("Could not find compiler for platform ",
                                   platform.ValueOrDie()->Name(), ": ",
                                   status.ToString());
    }
  }

  xla::LocalClientOptions client_options;
  client_options.set_platform(platform.ValueOrDie());
  client_options.set_intra_op_parallelism_threads(
      device->tensorflow_cpu_worker_threads()->num_threads);

  string allowed_gpus =
      flr->config_proto()->gpu_options().visible_device_list();
  TF_ASSIGN_OR_RETURN(absl::optional<std::set<int>> gpu_ids,
                      ParseVisibleDeviceList(allowed_gpus));
  client_options.set_allowed_devices(gpu_ids);

  auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
  if (!client.ok()) {
    return client.status();
  }
  const XlaOpRegistry::DeviceRegistration* registration;
  if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
                                           &registration)) {
    return errors::InvalidArgument("No JIT device registered for ",
                                   platform_info.device_type().type());
  }
  *cache = new XlaCompilationCache(
      client.ValueOrDie(), DeviceType(registration->compilation_device_name));
  return Status::OK();
}

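// Derives an XlaPlatformInfo from `device_base`: the platform id for CPU and
// GPU devices, or the XlaDevice::Metadata (plus the XLA backend's memory
// allocator) when the device is an XlaDevice.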
XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
  auto device = static_cast<Device*>(device_base);
  se::Platform::Id platform_id = nullptr;
  const XlaDevice::Metadata* xla_device_metadata = nullptr;
  std::shared_ptr<se::DeviceMemoryAllocator> custom_allocator;

  if (device->device_type() == DEVICE_CPU) {
    platform_id = se::host::kHostPlatformId;
  } else if (device->device_type() == DEVICE_GPU) {
    platform_id = device->tensorflow_gpu_device_info()
                      ->stream->parent()
                      ->platform()
                      ->id();
  } else if (XlaDevice::GetMetadataFromDevice(device, &xla_device_metadata)
                 .ok()) {
    // If we are on an XlaDevice, use the underlying XLA platform's allocator
    // directly. We could use the StreamExecutor's allocator, which might
    // theoretically be more correct, but XLA returns a nice OOM message in a
    // Status and StreamExecutor does not.
    //
    // Importantly, we can't use ctx->device()->GetAllocator() as the
    // allocator, because on an XlaDevice that is a dummy allocator returning
    // XlaTensor objects. The XlaCompiler needs a real allocator to allocate
    // real buffers.
    platform_id = xla_device_metadata->platform()->id();
    custom_allocator =
        xla_device_metadata->client()->backend().shared_memory_allocator();
  }

  return XlaPlatformInfo(DeviceType(device->device_type()), platform_id,
                         xla_device_metadata, custom_allocator);
}

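// Returns the device memory allocator to use for XLA on `device`: the
// platform's custom allocator when one is set, otherwise the device's
// TensorFlow allocator wrapped in a se::TfAllocatorAdapter (keyed on the
// platform when `stream` is null, e.g. on the host, and on `stream`
// otherwise).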
std::shared_ptr<se::DeviceMemoryAllocator> GetAllocator(
    DeviceBase* device, se::Stream* stream,
    const XlaPlatformInfo& platform_info) {
  if (platform_info.custom_allocator()) {
    return platform_info.custom_allocator();
  }
  auto* alloc = device->GetAllocator({});
  if (!stream) {
    // Stream is not set for the host platform.
    se::Platform* platform =
        se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
            .ValueOrDie();
    return std::make_shared<se::TfAllocatorAdapter>(alloc, platform);
  }
  return std::make_shared<se::TfAllocatorAdapter>(alloc, stream);
}

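// Assembles the XlaCompiler::Options used to compile a cluster: the cache's
// client and device type, the function library's definitions and graph_def
// version, the allocator chosen by GetAllocator(), and aliasing of passthrough
// parameters when there are no reference variables and we are not on an
// XlaDevice.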
XlaCompiler::Options GenerateCompilerOptions(
    const XlaCompilationCache& cache,
    const FunctionLibraryRuntime& function_library, DeviceBase* device,
    se::Stream* stream, const XlaPlatformInfo& platform_info,
    bool has_ref_vars) {
  XlaCompiler::Options options;
  options.client = static_cast<xla::LocalClient*>(cache.client());
  if (stream != nullptr) {
    options.device_ordinal = stream->parent()->device_ordinal();
  }
  options.device_type = cache.device_type();
  options.flib_def = function_library.GetFunctionLibraryDefinition();
  options.graph_def_version = function_library.graph_def_version();
  options.allow_cpu_custom_calls =
      (platform_info.platform_id() == se::host::kHostPlatformId);
  options.device_allocator = GetAllocator(device, stream, platform_info);
  if (platform_info.xla_device_metadata()) {
    options.shape_representation_fn =
        platform_info.xla_device_metadata()->shape_representation_fn();
  }
  // If reference variables are not present in the graph, we can safely alias
  // passthrough parameters without performing a copy.
  options.alias_passthrough_params =
      !has_ref_vars && !platform_info.is_on_xla_device();
  return options;
}

}  // namespace tensorflow