/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_platform_info.h"

#include "tensorflow/compiler/xla/client/client_library.h"

namespace tensorflow {

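// Builds an XlaCompilationCache for `device`. If the device carries XLA
// device metadata, the metadata's client and JIT device type are reused;
// otherwise a LocalClient is created (or fetched) for the device's platform
// and the compilation device is looked up in the XlaOpRegistry.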
Status BuildXlaCompilationCache(DeviceBase* device,
                                const XlaPlatformInfo& platform_info,
                                XlaCompilationCache** cache) {
  if (platform_info.xla_device_metadata()) {
    *cache = new XlaCompilationCache(
        platform_info.xla_device_metadata()->client(),
        platform_info.xla_device_metadata()->jit_device_type());
    return Status::OK();
  }

  auto platform =
      se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
  if (!platform.ok()) {
    return platform.status();
  }

  xla::StatusOr<xla::Compiler*> compiler_for_platform =
      xla::Compiler::GetForPlatform(platform.ValueOrDie());
  if (!compiler_for_platform.ok()) {
    // In some rare cases (usually in unit tests with very small clusters) we
    // may end up transforming an XLA cluster with at least one GPU operation
    // (which would normally force the cluster to be compiled using XLA:GPU)
    // into an XLA cluster with no GPU operations (i.e. containing only CPU
    // operations).  Such a cluster can fail compilation (in a way that
    // MarkForCompilation could not have detected) if the CPU JIT is not linked
    // in.
    //
    // So bail out of _XlaCompile in this case, and let the executor handle the
    // situation for us.
    const Status& status = compiler_for_platform.status();
    if (status.code() == error::NOT_FOUND) {
      return errors::Unimplemented("Could not find compiler for platform ",
                                   platform.ValueOrDie()->Name(), ": ",
                                   status.ToString());
    }
  }

  xla::LocalClientOptions client_options;
  client_options.set_platform(platform.ValueOrDie());
  client_options.set_intra_op_parallelism_threads(
      device->tensorflow_cpu_worker_threads()->num_threads);
  auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
  if (!client.ok()) {
    return client.status();
  }
  const XlaOpRegistry::DeviceRegistration* registration;
  if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
                                           &registration)) {
    return errors::InvalidArgument("No JIT device registered for ",
                                   platform_info.device_type().type());
  }
  *cache = new XlaCompilationCache(
      client.ValueOrDie(), DeviceType(registration->compilation_device_name));
  return Status::OK();
}

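// Derives an XlaPlatformInfo from `device_base`: the platform id for CPU and
// GPU devices, and, for XLA devices, the device metadata together with the
// XLA client's backend allocator.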
XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
  auto device = static_cast<Device*>(device_base);
  se::Platform::Id platform_id = nullptr;
  const XlaDevice::Metadata* xla_device_metadata = nullptr;
  se::DeviceMemoryAllocator* custom_allocator = nullptr;

  if (device->device_type() == DEVICE_CPU) {
    platform_id = se::host::kHostPlatformId;
  } else if (device->device_type() == DEVICE_GPU) {
    platform_id = device->tensorflow_gpu_device_info()
                      ->stream->parent()
                      ->platform()
                      ->id();
  } else if (XlaDevice::GetMetadataFromDevice(device, &xla_device_metadata)
                 .ok()) {
    // If we are on an XlaDevice, use the underlying XLA platform's allocator
    // directly. We could use the StreamExecutor's allocator which may
    // theoretically be more correct, but XLA returns a nice OOM message in a
    // Status and StreamExecutor does not.
    //
    // Importantly we can't use ctx->device()->GetAllocator() as the allocator
    // (which xla_allocator above uses) as on an XlaDevice, this is a dummy
    // allocator that returns XlaTensor objects. The XlaCompiler needs a real
    // allocator to allocate real buffers.
    platform_id = xla_device_metadata->platform()->id();
    custom_allocator =
        xla_device_metadata->client()->backend().memory_allocator();
  }

  return XlaPlatformInfo(DeviceType(device->device_type()), platform_id,
                         xla_device_metadata, custom_allocator);
}

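// Picks the device memory allocator to use: the custom allocator from
// `platform_info` when one is set, otherwise the device's TensorFlow
// allocator wrapped in a TfAllocatorAdapter (keyed on the platform when no
// stream is available, e.g. on the host).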
se::DeviceMemoryAllocator* GetAllocator(
    absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
    DeviceBase* device, se::Stream* stream,
    const XlaPlatformInfo& platform_info) {
  if (platform_info.custom_allocator()) {
    return platform_info.custom_allocator();
  }
  if (!stream) {
    // Stream is not set for the host platform.
    se::Platform* platform =
        se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
            .ValueOrDie();
    tf_allocator_adapter->emplace(device->GetAllocator({}), platform);
    return &tf_allocator_adapter->value();
  }
  tf_allocator_adapter->emplace(device->GetAllocator({}), stream);
  return &tf_allocator_adapter->value();
}

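// Assembles XlaCompiler::Options from the compilation cache, function
// library, device/stream, and platform info. CPU custom calls are only
// allowed on the host platform, and passthrough parameters are aliased only
// when the graph has no reference variables and we are not on an XLA device.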
XlaCompiler::Options GenerateCompilerOptions(
    const XlaCompilationCache& cache,
    const FunctionLibraryRuntime& function_library, DeviceBase* device,
    se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars,
    absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter) {
  XlaCompiler::Options options;
  options.client = static_cast<xla::LocalClient*>(cache.client());
  if (stream != nullptr) {
    options.device_ordinal = stream->parent()->device_ordinal();
  }
  options.device_type = cache.device_type();
  options.flib_def = function_library.GetFunctionLibraryDefinition();
  options.graph_def_version = function_library.graph_def_version();
  options.allow_cpu_custom_calls =
      (platform_info.platform_id() == se::host::kHostPlatformId);
  options.device_allocator =
      GetAllocator(tf_allocator_adapter, device, stream, platform_info);
  if (platform_info.xla_device_metadata()) {
    options.shape_representation_fn =
        platform_info.xla_device_metadata()->shape_representation_fn();
  }
  // If reference variables are not present in the graph, we can safely alias
  // passthrough parameters without performing a copy.
  options.alias_passthrough_params =
      !has_ref_vars && !platform_info.is_on_xla_device();
  return options;
}

}  // namespace tensorflow
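
// A rough usage sketch (illustrative only, not part of this file): `ctx`
// below stands for a hypothetical OpKernelContext, and a caller such as a
// JIT kernel would typically derive the platform info from its device, build
// a compilation cache, and then generate compiler options:
//
//   XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(ctx->device());
//   XlaCompilationCache* cache = nullptr;
//   TF_RETURN_IF_ERROR(
//       BuildXlaCompilationCache(ctx->device(), platform_info, &cache));
//   core::ScopedUnref cache_ref(cache);
//   se::Stream* stream =
//       ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
//   absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
//   XlaCompiler::Options options = GenerateCompilerOptions(
//       *cache, *ctx->function_library(), ctx->device(), stream, platform_info,
//       /*has_ref_vars=*/true, &tf_allocator_adapter);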