/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/local_client.h"

#include <utility>

#include "absl/memory/memory.h"
#include "llvm/ADT/Triple.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
#include "tensorflow/compiler/xla/service/source_map_util.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/status_macros.h"

using xla::source_map_util::InvalidParameterArgument;

namespace xla {

namespace {
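// Borrows a stream from the backend's stream pool for `device_ordinal`; a
// negative ordinal selects the backend's default device.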
StatusOr<StreamPool::Ptr> BorrowStreamForDevice(int device_ordinal,
                                                Backend* backend) {
  if (device_ordinal < 0) {
    device_ordinal = backend->default_device_ordinal();
  }
  return backend->BorrowStream(device_ordinal);
}
}  // namespace

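// A LocalExecutable wraps a compiled Executable together with the backend it
// will run on and the options it was built with, which must name a valid
// device ordinal.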
LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
                                 Backend* backend,
                                 ExecutableBuildOptions build_options)
    : executable_(std::move(executable)),
      backend_(backend),
      build_options_(std::move(build_options)) {
  CHECK_GE(build_options_.device_ordinal(), 0)
      << "Must have a valid device ordinal that the executable was built for.";
}

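// Checks that `run_options` is compatible with the service backend and with
// the device this executable was built for: any provided stream must be
// healthy, on the service platform, and not combined with an explicit device
// ordinal; the run device must be equivalent to the build device; and an
// allocator on the service platform must be supplied.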
Status LocalExecutable::ValidateExecutionOptions(
    const ExecutableRunOptions& run_options, const Backend& backend) {
  if (run_options.stream() != nullptr) {
    if (!run_options.stream()->ok()) {
      return InvalidArgument("stream is uninitialized or in an error state");
    }

    // Check stream matches service platform.
    const se::Platform* stream_platform =
        run_options.stream()->parent()->platform();
    if (stream_platform != backend_->platform()) {
      return InvalidArgument(
          "stream is for platform %s, but service targets platform %s",
          stream_platform->Name(), backend_->platform()->Name());
    }

    // Cannot specify device_ordinal with a stream. The stream determines these
    // values.
    if (run_options.device_ordinal() != -1) {
      return InvalidArgument(
          "cannot set both device ordinal and stream options in "
          "ExecutableRunOptions; the stream determines the device ordinal");
    }
  }

  // Verify that the device the executable was built for is equivalent
  // to the device it will run on.
  int run_device_ordinal = run_options.device_ordinal();
  if (run_device_ordinal == -1) {
    run_device_ordinal = run_options.stream() != nullptr
                             ? run_options.stream()->parent()->device_ordinal()
                             : backend_->default_device_ordinal();
  }
  TF_ASSIGN_OR_RETURN(bool devices_equivalent,
                      backend_->devices_equivalent(
                          run_device_ordinal, build_options_.device_ordinal()));
  if (!devices_equivalent) {
    TF_ASSIGN_OR_RETURN(se::StreamExecutor * run_executor,
                        backend_->stream_executor(run_device_ordinal));
    TF_ASSIGN_OR_RETURN(se::StreamExecutor * build_executor,
                        backend_->stream_executor(build_device_ordinal()));
    return InvalidArgument(
        "executable is built for device %s of type \"%s\"; cannot run it on "
        "device %s of type \"%s\"",
        backend_->device_name(build_device_ordinal()),
        build_executor->GetDeviceDescription().name(),
        backend_->device_name(run_device_ordinal),
        run_executor->GetDeviceDescription().name());
  }

  if (!run_options.allocator()) {
    return InvalidArgument("an allocator must be provided to ExecuteLocally");
  }

  if (run_options.allocator()->platform() != backend.platform()) {
    return InvalidArgument(
        "allocator platform (%s) does not match service platform (%s)",
        run_options.allocator()->platform()->Name(),
        backend.platform()->Name());
  }

  return Status::OK();
}

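// Validates the argument shapes against the computation's entry layout, then
// fills in a borrowed stream and a default allocator if `run_options` lacks
// them. The borrowed stream is returned alongside the service run options so
// that the two share a lifetime.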
StatusOr<std::pair<ServiceExecutableRunOptions, StreamPool::Ptr>>
LocalExecutable::RunHelper(const absl::Span<const Shape* const> argument_shapes,
                           ExecutableRunOptions run_options) {
  const ComputationLayout& computation_layout =
      executable_->module_config().entry_computation_layout();

  // Check argument number, shapes, and layouts.
  const int argument_shapes_size = argument_shapes.size();
  if (argument_shapes_size != computation_layout.parameter_count()) {
    return InvalidArgument(
        "invalid number of arguments for computation: expected %d, got %u",
        computation_layout.parameter_count(), argument_shapes.size());
  }
  for (int i = 0, end = argument_shapes.size(); i < end; ++i) {
    // TODO(b/187081154): Compare tiling info also.
    if (!computation_layout.parameter_layout(i).MatchesLayoutInShape(
            *argument_shapes[i], /*minor_to_major_only=*/false,
            /*ignore_fully_empty_tiling=*/true)) {
      return InvalidParameterArgument(
          executable_.get(), i,
          "Argument does not match host shape or layout of computation "
          "parameter "
          "%d: want %s, got %s",
          i,
          ShapeUtil::HumanStringWithLayout(
              computation_layout.parameter_layout(i).shape()),
          ShapeUtil::HumanStringWithLayout(*argument_shapes[i]));
    }
  }

  TF_RETURN_IF_ERROR(ValidateExecutionOptions(run_options, *backend_));

  StreamPool::Ptr stream;
  if (run_options.stream() == nullptr) {
    // NB!  The lifetime of `stream` needs to match the lifetime of
    // `service_options` (otherwise we will end up using a returned stream in
    // ExecuteOnStreamWrapper), which is why it isn't declared in the inner "if"
    // scope.
    TF_ASSIGN_OR_RETURN(
        stream, BorrowStreamForDevice(run_options.device_ordinal(), backend_));
    run_options.set_stream(stream.get());
  }
  if (run_options.allocator() == nullptr) {
    run_options.set_allocator(backend_->memory_allocator());
  }

  // For local client execution on CPU backends:
  // *) The thread pool used for eigen CPU ops is from
  //    ExecutableRunOptions.eigen_intra_op_thread_pool.
  // *) The thread pool used for XLA CPU ops is from
  //    backend_->eigen_intra_op_thread_pool().
  ServiceExecutableRunOptions service_options(run_options,
                                              backend_->StreamBorrower());
  return std::make_pair(service_options, std::move(stream));
}

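// Synchronous execution over ShapedBuffer arguments: dispatches to RunAsync
// and blocks the host until the result is ready.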
StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
    const absl::Span<const ShapedBuffer* const> arguments,
    ExecutableRunOptions run_options) {
  std::vector<const Shape*> argument_shapes;
  argument_shapes.reserve(arguments.size());
  for (const ShapedBuffer* const arg : arguments) {
    argument_shapes.push_back(&arg->on_device_shape());
  }
  return AsyncCallAndBlockHostUntilDone<xla::ScopedShapedBuffer>(
      argument_shapes, run_options, [&](const ExecutableRunOptions& options) {
        return RunAsync(arguments, options);
      });
}

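// Synchronous execution over ExecutionInputs; like the ShapedBuffer overload,
// this wraps RunAsync and blocks until completion.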
StatusOr<ExecutionOutput> LocalExecutable::Run(
    std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options) {
  std::vector<const Shape*> argument_shapes;
  argument_shapes.reserve(arguments.size());
  for (const ExecutionInput& arg : arguments) {
    argument_shapes.push_back(&arg.shape());
  }
  return AsyncCallAndBlockHostUntilDone<ExecutionOutput>(
      argument_shapes, run_options, [&](const ExecutableRunOptions& options) {
        return RunAsync(argument_shapes, std::move(arguments), options);
      });
}

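// Begins an HLO snapshot of this execution: records the platform and the HLO
// proto, and enqueues asynchronous device-to-host transfers of the argument
// literals into the snapshot.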
static std::shared_ptr<HloSnapshot> DumpArguments(
    const Backend* backend, const Executable* executable,
    const absl::Span<const ShapedBuffer* const> arguments, se::Stream* stream) {
  auto snapshot = std::make_shared<HloSnapshot>();
  snapshot->set_execution_platform(backend->platform()->Name());
  *snapshot->mutable_hlo() = *executable->hlo_proto();
  for (const ShapedBuffer* arg : arguments) {
    auto literal = std::make_shared<Literal>(arg->on_host_shape());
    backend->transfer_manager()->TransferLiteralFromDevice(
        stream, *arg, literal.get(), [snapshot, literal](Status status) {
          if (!status.ok()) {
            LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot inputs "
                          "failed: "
                       << status;
            return;
          }
          *snapshot->add_arguments() = literal->ToProto();
        });
  }
  return snapshot;
}

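// Enqueues an asynchronous device-to-host transfer of the outputs into the
// snapshot and, once the transfer completes, dumps the snapshot to disk if
// dumping is enabled in the debug options.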
static void DumpOutputsAndSaveSnapshot(const Backend* backend,
                                       const ShapedBuffer& outputs,
                                       std::shared_ptr<HloSnapshot> snapshot,
                                       se::Stream* stream) {
  auto literal = std::make_shared<Literal>(outputs.on_host_shape());
  backend->transfer_manager()->TransferLiteralFromDevice(
      stream, outputs, literal.get(),
      [snapshot{std::move(snapshot)}, literal](Status status) {
        if (status.ok()) {
          *snapshot->mutable_result() = literal->ToProto();
        } else {
          LOG(ERROR)
              << "TransferLiteralFromDevice for HLO snapshot outputs failed: "
              << status;
        }
        DumpHloSnapshotIfEnabled(*snapshot, GetDebugOptionsFromFlags());
      });
}

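// Asynchronous execution over ShapedBuffer arguments. If the executable was
// built with snapshot dumping enabled, the arguments and outputs are also
// captured into an HloSnapshot.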
StatusOr<ScopedShapedBuffer> LocalExecutable::RunAsync(
    const absl::Span<const ShapedBuffer* const> arguments,
    ExecutableRunOptions run_options) {
  std::vector<const Shape*> argument_shapes;
  argument_shapes.reserve(arguments.size());
  for (const ShapedBuffer* const arg : arguments) {
    argument_shapes.push_back(&arg->on_device_shape());
  }
  TF_ASSIGN_OR_RETURN(auto options_and_stream,
                      RunHelper(argument_shapes, run_options));
  se::Stream* stream = run_options.stream();

  std::shared_ptr<HloSnapshot> snapshot;
  if (executable_->dumping_snapshot()) {
    snapshot = DumpArguments(backend_, executable_.get(), arguments, stream);
  }

  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer outputs,
                      executable_->ExecuteAsyncOnStreamWrapper(
                          &options_and_stream.first, arguments));

  // Transfer the outputs and save the snapshot to disk.
  if (snapshot) {
    DumpOutputsAndSaveSnapshot(backend_, outputs, std::move(snapshot), stream);
  }

  return std::move(outputs);
}

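// Builds a non-owning ShapedBuffer view over a tree of maybe-owning device
// memory so that the buffers can be snapshotted without taking ownership.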
static ShapedBuffer MaybeOwningShapeTreeToShapedBuffer(
    const ShapeTree<MaybeOwningDeviceMemory>& tree, int device_ordinal) {
  ShapedBuffer result(tree.shape(), device_ordinal);
  auto it = tree.begin();
  auto out_it = result.buffers().begin();
  for (; it != tree.end(); ++it, ++out_it) {
    out_it->second = it->second.AsDeviceMemoryBase();
  }
  return result;
}

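// Asynchronous execution over ExecutionInputs with caller-supplied host
// shapes; the number of shapes must match the number of arguments.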
StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
    absl::Span<Shape const* const> argument_host_shapes,
    std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options) {
  if (argument_host_shapes.size() != arguments.size()) {
    return InvalidArgument(
        "Number of argument host shapes not equal to number of arguments (%d "
        "vs %d)",
        argument_host_shapes.size(), arguments.size());
  }
  TF_ASSIGN_OR_RETURN(auto options_and_stream,
                      RunHelper(argument_host_shapes, run_options));
  se::Stream* stream = run_options.stream();

  std::shared_ptr<HloSnapshot> snapshot;
  if (executable_->dumping_snapshot()) {
    std::vector<ShapedBuffer> shaped_buffers;
    std::vector<const ShapedBuffer*> shaped_buffer_ptrs;
    shaped_buffers.reserve(arguments.size());
    shaped_buffer_ptrs.reserve(arguments.size());
    for (size_t i = 0; i < arguments.size(); ++i) {
      shaped_buffers.push_back(MaybeOwningShapeTreeToShapedBuffer(
          arguments[i].Buffers(), stream->parent()->device_ordinal()));
      shaped_buffer_ptrs.push_back(&shaped_buffers.back());
    }

    snapshot =
        DumpArguments(backend_, executable_.get(), shaped_buffer_ptrs, stream);
  }

  TF_ASSIGN_OR_RETURN(ExecutionOutput outputs,
                      executable_->ExecuteAsyncOnStreamWrapper(
                          &options_and_stream.first, std::move(arguments)));

  // Transfer the outputs and save the snapshot to disk.
  if (snapshot) {
    DumpOutputsAndSaveSnapshot(backend_, outputs.Result(), std::move(snapshot),
                               stream);
  }

  return std::move(outputs);
}

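// Convenience overload that derives the argument shapes from the
// ExecutionInputs themselves.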
StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
    std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options) {
  std::vector<const Shape*> argument_shapes;
  argument_shapes.reserve(arguments.size());
  for (const ExecutionInput& arg : arguments) {
    argument_shapes.push_back(&arg.shape());
  }
  return RunAsync(argument_shapes, std::move(arguments), run_options);
}

se::Platform* LocalClient::platform() const {
  return local_service_->backend().platform();
}

int LocalClient::device_count() const {
  return local_service_->backend().device_count();
}

bool LocalClient::device_ordinal_supported(int device_ordinal) const {
  return local_service_->backend().device_ordinal_supported(device_ordinal);
}

int LocalClient::default_device_ordinal() const {
  return local_service_->backend().default_device_ordinal();
}

const Backend& LocalClient::backend() const {
  return local_service_->backend();
}

Backend* LocalClient::mutable_backend() {
  return local_service_->mutable_backend();
}

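// Compiles `computation` with the given argument layouts and build options,
// defaulting the device ordinal if unset and validating any explicit device
// assignment against the requested replica and partition counts.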
StatusOr<std::vector<std::unique_ptr<LocalExecutable>>> LocalClient::Compile(
    const XlaComputation& computation,
    const absl::Span<const Shape* const> argument_layouts,
    const ExecutableBuildOptions& options) {
  ExecutableBuildOptions updated_options = options;
  if (options.device_ordinal() == -1) {
    updated_options.set_device_ordinal(default_device_ordinal());
    VLOG(3) << "Set device ordinal to default value of: "
            << updated_options.device_ordinal();
  }
  if (options.has_device_assignment()) {
    if (options.device_assignment().replica_count() != options.num_replicas()) {
      return InvalidArgument(
          "Mismatched number of replicas for device "
          "assignment and computation (%d vs %d).\n%s",
          options.device_assignment().replica_count(), options.num_replicas(),
          options.device_assignment().ToString());
    }
    if (options.device_assignment().computation_count() !=
        options.num_partitions()) {
      return InvalidArgument(
          "Mismatched number of partitions for device "
          "assignment and computation (%d vs %d).\n%s",
          options.device_assignment().computation_count(),
          options.num_partitions(), options.device_assignment().ToString());
    }
  }
  TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
                      local_service_->CompileExecutables(
                          computation, argument_layouts, updated_options));

  std::vector<std::unique_ptr<LocalExecutable>> local_executables;
  local_executables.reserve(executables.size());

  for (auto& executable : executables) {
    local_executables.push_back(absl::make_unique<LocalExecutable>(
        std::move(executable), local_service_->mutable_backend(),
        updated_options));
  }

  return std::move(local_executables);
}

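// Allocates a device buffer shaped like `literal` and transfers the literal
// into it over a borrowed stream.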
StatusOr<ScopedShapedBuffer> LocalClient::LiteralToShapedBuffer(
    const LiteralSlice& literal, int device_ordinal,
    se::DeviceMemoryAllocator* allocator) {
  if (allocator == nullptr) {
    allocator = backend().memory_allocator();
  }
  TF_ASSIGN_OR_RETURN(auto scoped_buffer,
                      backend().transfer_manager()->AllocateScopedShapedBuffer(
                          literal.shape(), allocator, device_ordinal));
  TF_ASSIGN_OR_RETURN(auto stream,
                      mutable_backend()->BorrowStream(device_ordinal));
  TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
      stream.get(), literal, scoped_buffer));
  return std::move(scoped_buffer);
}

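// Transfers the contents of a device buffer back to the host as a Literal.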
StatusOr<Literal> LocalClient::ShapedBufferToLiteral(
    const ShapedBuffer& shaped_buffer) {
  TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream(
                                       shaped_buffer.device_ordinal()));
  return backend().transfer_manager()->TransferLiteralFromDevice(stream.get(),
                                                                 shaped_buffer);
}

StatusOr<const ShapedBuffer*> LocalClient::GlobalDataToShapedBuffer(
    const GlobalDataHandle& data, int replica_number) {
  return local_service_->GlobalDataToShapedBuffer(data, replica_number);
}

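// Transfers `literal` to the infeed queue of the device at `device_ordinal`.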
Status LocalClient::TransferToInfeedLocal(const LiteralSlice& literal,
                                          int device_ordinal) {
  TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                      backend().stream_executor(device_ordinal));
  return backend().transfer_manager()->TransferLiteralToInfeed(executor,
                                                               literal);
}

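// Transfers a literal from the outfeed of the device at `device_ordinal`
// into `literal`.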
Status LocalClient::TransferFromOutfeedLocal(int device_ordinal,
                                             MutableBorrowingLiteral literal) {
  TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                      backend().stream_executor(device_ordinal));
  return backend().transfer_manager()->TransferLiteralFromOutfeed(executor,
                                                                  literal);
}

StatusOr<int> LocalClient::ReplicaNumberToDeviceOrdinal(int replica_number) {
  return local_service_->ReplicaNumberToDeviceOrdinal(replica_number);
}

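// Copies `literal` onto the given device and registers the resulting buffer
// with the local service, returning a handle addressable via the client API.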
StatusOr<TransferToServerResponse> LocalClient::TransferToLocalServer(
    const ::xla::BorrowingLiteral& literal, int device_ordinal) {
  const ::xla::Shape& shape = literal.shape();

  TF_ASSIGN_OR_RETURN(::xla::ScopedShapedBuffer shaped_buffer,
                      backend().transfer_manager()->AllocateScopedShapedBuffer(
                          shape, backend().memory_allocator(), device_ordinal));
  TF_ASSIGN_OR_RETURN(auto stream,
                      mutable_backend()->BorrowStream(device_ordinal));
  TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
      stream.get(), literal, shaped_buffer));
  std::vector<::xla::ScopedShapedBuffer> replicated_buffer;
  replicated_buffer.emplace_back(std::move(shaped_buffer));
  ::xla::TransferToServerResponse result;
  TF_ASSIGN_OR_RETURN(*result.mutable_data(),
                      local_service_->RegisterReplicatedBuffers(
                          std::move(replicated_buffer),
                          absl::StrCat("TransferToServer literal of shape ",
                                       ::xla::ShapeUtil::HumanString(shape))));

  return result;
}

}  // namespace xla