/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_platform_info.h"

#include "tensorflow/compiler/xla/client/client_library.h"

namespace tensorflow {

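// Parses a comma-separated GPU "visible_device_list" string into a set of
// platform device ids. An empty string yields absl::nullopt, meaning all
// devices are visible.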
xla::StatusOr<absl::optional<std::set<int>>> ParseVisibleDeviceList(
    absl::string_view visible_device_list) {
  std::set<int> gpu_ids;
  if (visible_device_list.empty()) {
    return {{absl::nullopt}};
  }
  const std::vector<string> visible_devices =
      absl::StrSplit(visible_device_list, ',');
  for (const string& platform_device_id_str : visible_devices) {
    int32_t platform_device_id;
    if (!absl::SimpleAtoi(platform_device_id_str, &platform_device_id)) {
      return errors::InvalidArgument(
          "Could not parse entry in 'visible_device_list': '",
          platform_device_id_str,
          "'. visible_device_list = ", visible_device_list);
    }
    gpu_ids.insert(platform_device_id);
  }
  return {{gpu_ids}};
}

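// Builds an XlaCompilationCache for `device`. On XLA devices the cache reuses
// the device's existing XLA client; otherwise a LocalClient is obtained (or
// created) for the device's platform, restricted to the GPUs allowed by the
// session's visible_device_list.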
Status BuildXlaCompilationCache(DeviceBase* device, FunctionLibraryRuntime* flr,
                                const XlaPlatformInfo& platform_info,
                                XlaCompilationCache** cache) {
  if (platform_info.xla_device_metadata()) {
    *cache = new XlaCompilationCache(
        platform_info.xla_device_metadata()->client(),
        platform_info.xla_device_metadata()->jit_device_type());
    return Status::OK();
  }

  auto platform =
      se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
  if (!platform.ok()) {
    return platform.status();
  }

  StatusOr<xla::Compiler*> compiler_for_platform =
      xla::Compiler::GetForPlatform(platform.ValueOrDie());
  if (!compiler_for_platform.ok()) {
    // In some rare cases (usually in unit tests with very small clusters) we
    // may end up transforming an XLA cluster with at least one GPU operation
    // (which would normally force the cluster to be compiled using XLA:GPU)
    // into an XLA cluster with no GPU operations (i.e. containing only CPU
    // operations).  Such a cluster can fail compilation (in a way that
    // MarkForCompilation could not have detected) if the CPU JIT is not linked
    // in.
    //
    // So bail out of _XlaCompile in this case, and let the executor handle the
    // situation for us.
    const Status& status = compiler_for_platform.status();
    if (status.code() == error::NOT_FOUND) {
      return errors::Unimplemented("Could not find compiler for platform ",
                                   platform.ValueOrDie()->Name(), ": ",
                                   status.ToString());
    }
  }

  xla::LocalClientOptions client_options;
  client_options.set_platform(platform.ValueOrDie());
  client_options.set_intra_op_parallelism_threads(
      device->tensorflow_cpu_worker_threads()->num_threads);

  string allowed_gpus =
      flr->config_proto()->gpu_options().visible_device_list();
  TF_ASSIGN_OR_RETURN(absl::optional<std::set<int>> gpu_ids,
                      ParseVisibleDeviceList(allowed_gpus));
  client_options.set_allowed_devices(gpu_ids);

  auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
  if (!client.ok()) {
    return client.status();
  }
  const XlaOpRegistry::DeviceRegistration* registration;
  if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
                                           &registration)) {
    return errors::InvalidArgument("No JIT device registered for ",
                                   platform_info.device_type().type());
  }
  *cache = new XlaCompilationCache(
      client.ValueOrDie(), DeviceType(registration->compilation_device_name));
  return Status::OK();
}

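// Collects the platform id, XLA device metadata (if any), and custom
// allocator needed to build an XlaPlatformInfo for `device_base`.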
XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
  auto device = static_cast<Device*>(device_base);
  se::Platform::Id platform_id = nullptr;
  const XlaDevice::Metadata* xla_device_metadata = nullptr;
  std::shared_ptr<se::DeviceMemoryAllocator> custom_allocator;

  if (device->device_type() == DEVICE_CPU) {
    platform_id = se::host::kHostPlatformId;
  } else if (device->device_type() == DEVICE_GPU) {
    platform_id = device->tensorflow_gpu_device_info()
                      ->stream->parent()
                      ->platform()
                      ->id();
  } else if (XlaDevice::GetMetadataFromDevice(device, &xla_device_metadata)
                 .ok()) {
    // If we are on an XlaDevice, use the underlying XLA platform's allocator
    // directly. We could use the StreamExecutor's allocator which may
    // theoretically be more correct, but XLA returns a nice OOM message in a
    // Status and StreamExecutor does not.
    //
    // Importantly we can't use ctx->device()->GetAllocator() as the allocator
    // (which xla_allocator above uses) as on an XlaDevice, this is a dummy
    // allocator that returns XlaTensor objects. The XlaCompiler needs a real
    // allocator to allocate real buffers.
    platform_id = xla_device_metadata->platform()->id();
    custom_allocator =
        xla_device_metadata->client()->backend().shared_memory_allocator();
  }

  return XlaPlatformInfo(DeviceType(device->device_type()), platform_id,
                         xla_device_metadata, custom_allocator);
}

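// Returns the allocator to use for XLA computations: the platform's custom
// allocator if one is set, otherwise the device's allocator wrapped in a
// TfAllocatorAdapter (keyed on the platform when no stream is available).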
std::shared_ptr<se::DeviceMemoryAllocator> GetAllocator(
    DeviceBase* device, se::Stream* stream,
    const XlaPlatformInfo& platform_info) {
  if (platform_info.custom_allocator()) {
    return platform_info.custom_allocator();
  }
  auto* alloc = device->GetAllocator({});
  if (!stream) {
    // Stream is not set for the host platform.
    se::Platform* platform =
        se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
            .ValueOrDie();
    return std::make_shared<se::TfAllocatorAdapter>(alloc, platform);
  }
  return std::make_shared<se::TfAllocatorAdapter>(alloc, stream);
}

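// Assembles XlaCompiler::Options from the compilation cache, function
// library, device, stream, and platform information.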
XlaCompiler::Options GenerateCompilerOptions(
    const XlaCompilationCache& cache,
    const FunctionLibraryRuntime& function_library, DeviceBase* device,
    se::Stream* stream, const XlaPlatformInfo& platform_info,
    bool has_ref_vars) {
  XlaCompiler::Options options;
  options.client = static_cast<xla::LocalClient*>(cache.client());
  if (stream != nullptr) {
    options.device_ordinal = stream->parent()->device_ordinal();
  }
  options.device_type = cache.device_type();
  options.flib_def = function_library.GetFunctionLibraryDefinition();
  options.graph_def_version = function_library.graph_def_version();
  options.allow_cpu_custom_calls =
      (platform_info.platform_id() == se::host::kHostPlatformId);
  options.device_allocator = GetAllocator(device, stream, platform_info);
  if (platform_info.xla_device_metadata()) {
    options.shape_representation_fn =
        platform_info.xla_device_metadata()->shape_representation_fn();
  }
  // If reference variables are not present in the graph, we can safely alias
  // passthrough parameters without performing a copy.
  options.alias_passthrough_params =
      !has_ref_vars && !platform_info.is_on_xla_device();
  return options;
}

}  // namespace tensorflow