1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/xla_platform_info.h"
17
18 #include "tensorflow/compiler/xla/client/client_library.h"
19
20 namespace tensorflow {
21
BuildXlaCompilationCache(DeviceBase * device,const XlaPlatformInfo & platform_info,XlaCompilationCache ** cache)22 Status BuildXlaCompilationCache(DeviceBase* device,
23 const XlaPlatformInfo& platform_info,
24 XlaCompilationCache** cache) {
25 if (platform_info.xla_device_metadata()) {
26 *cache = new XlaCompilationCache(
27 platform_info.xla_device_metadata()->client(),
28 platform_info.xla_device_metadata()->jit_device_type());
29 return Status::OK();
30 }
31
32 auto platform =
33 se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
34 if (!platform.ok()) {
35 return platform.status();
36 }
37
38 xla::StatusOr<xla::Compiler*> compiler_for_platform =
39 xla::Compiler::GetForPlatform(platform.ValueOrDie());
40 if (!compiler_for_platform.ok()) {
41 // In some rare cases (usually in unit tests with very small clusters) we
42 // may end up transforming an XLA cluster with at least one GPU operation
43 // (which would normally force the cluster to be compiled using XLA:GPU)
44 // into an XLA cluster with no GPU operations (i.e. containing only CPU
45 // operations). Such a cluster can fail compilation (in way that
46 // MarkForCompilation could not have detected) if the CPU JIT is not linked
47 // in.
48 //
49 // So bail out of _XlaCompile in this case, and let the executor handle the
50 // situation for us.
51 const Status& status = compiler_for_platform.status();
52 if (status.code() == error::NOT_FOUND) {
53 return errors::Unimplemented("Could not find compiler for platform ",
54 platform.ValueOrDie()->Name(), ": ",
55 status.ToString());
56 }
57 }
58
59 xla::LocalClientOptions client_options;
60 client_options.set_platform(platform.ValueOrDie());
61 client_options.set_intra_op_parallelism_threads(
62 device->tensorflow_cpu_worker_threads()->num_threads);
63 auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
64 if (!client.ok()) {
65 return client.status();
66 }
67 const XlaOpRegistry::DeviceRegistration* registration;
68 if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
69 ®istration)) {
70 return errors::InvalidArgument("No JIT device registered for ",
71 platform_info.device_type().type());
72 }
73 *cache = new XlaCompilationCache(
74 client.ValueOrDie(), DeviceType(registration->compilation_device_name));
75 return Status::OK();
76 }
77
XlaPlatformInfoFromDevice(DeviceBase * device_base)78 XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
79 auto device = static_cast<Device*>(device_base);
80 se::Platform::Id platform_id = nullptr;
81 const XlaDevice::Metadata* xla_device_metadata = nullptr;
82 se::DeviceMemoryAllocator* custom_allocator = nullptr;
83
84 if (device->device_type() == DEVICE_CPU) {
85 platform_id = se::host::kHostPlatformId;
86 } else if (device->device_type() == DEVICE_GPU) {
87 platform_id = device->tensorflow_gpu_device_info()
88 ->stream->parent()
89 ->platform()
90 ->id();
91 } else if (XlaDevice::GetMetadataFromDevice(device, &xla_device_metadata)
92 .ok()) {
93 // If we are on an XlaDevice, use the underlying XLA platform's allocator
94 // directly. We could use the StreamExecutor's allocator which may
95 // theoretically be more correct, but XLA returns a nice OOM message in a
96 // Status and StreamExecutor does not.
97 //
98 // Importantly we can't use ctx->device()->GetAllocator() as the allocator
99 // (which xla_allocator above uses) as on an XlaDevice, this is a dummy
100 // allocator that returns XlaTensor objects. The XlaCompiler needs a real
101 // allocator to allocate real buffers.
102 platform_id = xla_device_metadata->platform()->id();
103 custom_allocator =
104 xla_device_metadata->client()->backend().memory_allocator();
105 }
106
107 return XlaPlatformInfo(DeviceType(device->device_type()), platform_id,
108 xla_device_metadata, custom_allocator);
109 }
110
GetAllocator(absl::optional<se::TfAllocatorAdapter> * tf_allocator_adapter,DeviceBase * device,se::Stream * stream,const XlaPlatformInfo & platform_info)111 se::DeviceMemoryAllocator* GetAllocator(
112 absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
113 DeviceBase* device, se::Stream* stream,
114 const XlaPlatformInfo& platform_info) {
115 if (platform_info.custom_allocator()) {
116 return platform_info.custom_allocator();
117 }
118 if (!stream) {
119 // Stream is not set for the host platform.
120 se::Platform* platform =
121 se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
122 .ValueOrDie();
123 tf_allocator_adapter->emplace(device->GetAllocator({}), platform);
124 return &tf_allocator_adapter->value();
125 }
126 tf_allocator_adapter->emplace(device->GetAllocator({}), stream);
127 return &tf_allocator_adapter->value();
128 }
129
GenerateCompilerOptions(const XlaCompilationCache & cache,const FunctionLibraryRuntime & function_library,DeviceBase * device,se::Stream * stream,const XlaPlatformInfo & platform_info,bool has_ref_vars,absl::optional<se::TfAllocatorAdapter> * tf_allocator_adapter)130 XlaCompiler::Options GenerateCompilerOptions(
131 const XlaCompilationCache& cache,
132 const FunctionLibraryRuntime& function_library, DeviceBase* device,
133 se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars,
134 absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter) {
135 XlaCompiler::Options options;
136 options.client = static_cast<xla::LocalClient*>(cache.client());
137 if (stream != nullptr) {
138 options.device_ordinal = stream->parent()->device_ordinal();
139 }
140 options.device_type = cache.device_type();
141 options.flib_def = function_library.GetFunctionLibraryDefinition();
142 options.graph_def_version = function_library.graph_def_version();
143 options.allow_cpu_custom_calls =
144 (platform_info.platform_id() == se::host::kHostPlatformId);
145 options.device_allocator =
146 GetAllocator(tf_allocator_adapter, device, stream, platform_info);
147 if (platform_info.xla_device_metadata()) {
148 options.shape_representation_fn =
149 platform_info.xla_device_metadata()->shape_representation_fn();
150 }
151 // If reference variables are not present in the graph, we can safely alias
152 // passthrough parameters without performing a copy.
153 options.alias_passthrough_params =
154 !has_ref_vars && !platform_info.is_on_xla_device();
155 return options;
156 }
157
158 } // namespace tensorflow
159