/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/local_client.h"
17 
18 #include <utility>
19 
20 #include "absl/memory/memory.h"
21 #include "llvm/ADT/Triple.h"
22 #include "tensorflow/compiler/xla/client/xla_computation.h"
23 #include "tensorflow/compiler/xla/service/backend.h"
24 #include "tensorflow/compiler/xla/service/dump.h"
25 #include "tensorflow/compiler/xla/service/service_executable_run_options.h"
26 #include "tensorflow/compiler/xla/service/source_map_util.h"
27 #include "tensorflow/compiler/xla/service/stream_pool.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 
30 using xla::source_map_util::InvalidParameterArgument;
31 
namespace xla {

namespace {
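// Borrows a stream from the backend's stream pool for `device_ordinal`,
// falling back to the backend's default device when the ordinal is negative.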
StatusOr<StreamPool::Ptr> BorrowStreamForDevice(int device_ordinal,
                                                Backend* backend) {
  if (device_ordinal < 0) {
    device_ordinal = backend->default_device_ordinal();
  }
  return backend->BorrowStream(device_ordinal);
}
}  // namespace

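// Wraps a compiled Executable together with the Backend it targets and the
// build options it was compiled with; the build options must carry a valid
// device ordinal.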
LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
                                 Backend* backend,
                                 ExecutableBuildOptions build_options)
    : executable_(std::move(executable)),
      backend_(backend),
      build_options_(std::move(build_options)) {
  CHECK_GE(build_options_.device_ordinal(), 0)
      << "Must have a valid device ordinal that the executable was built for.";
}

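// Checks that `run_options` is consistent with `backend` and with the device
// the executable was built for: any caller-provided stream must be healthy
// and on the service platform, the device ordinal may not be set alongside a
// stream, the run device must be equivalent to the build device, and an
// allocator on the service platform must be provided.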
Status LocalExecutable::ValidateExecutionOptions(
    const ExecutableRunOptions& run_options, const Backend& backend) {
  if (run_options.stream() != nullptr) {
    if (!run_options.stream()->ok()) {
      return InvalidArgument("stream is uninitialized or in an error state");
    }

    // Check stream matches service platform.
    const se::Platform* stream_platform =
        run_options.stream()->parent()->platform();
    if (stream_platform != backend_->platform()) {
      return InvalidArgument(
          "stream is for platform %s, but service targets platform %s",
          stream_platform->Name(), backend_->platform()->Name());
    }

    // Cannot specify device_ordinal with a stream. The stream determines these
    // values.
    if (run_options.device_ordinal() != -1) {
      return InvalidArgument(
          "cannot set both device ordinal and stream options in "
          "ExecutableRunOptions; the stream determines the device ordinal");
    }
  }

  // Verify that the device the executable was built for is equivalent
  // to the device it will run on.
  int run_device_ordinal = run_options.device_ordinal();
  if (run_device_ordinal == -1) {
    run_device_ordinal = run_options.stream() != nullptr
                             ? run_options.stream()->parent()->device_ordinal()
                             : backend_->default_device_ordinal();
  }
  TF_ASSIGN_OR_RETURN(bool devices_equivalent,
                      backend_->devices_equivalent(
                          run_device_ordinal, build_options_.device_ordinal()));
  if (!devices_equivalent) {
    TF_ASSIGN_OR_RETURN(se::StreamExecutor * run_executor,
                        backend_->stream_executor(run_device_ordinal));
    TF_ASSIGN_OR_RETURN(se::StreamExecutor * build_executor,
                        backend_->stream_executor(build_device_ordinal()));
    return InvalidArgument(
        "executable is built for device %s of type \"%s\"; cannot run it on "
        "device %s of type \"%s\"",
        backend_->device_name(build_device_ordinal()),
        build_executor->GetDeviceDescription().name(),
        backend_->device_name(run_device_ordinal),
        run_executor->GetDeviceDescription().name());
  }

  if (!run_options.allocator()) {
    return InvalidArgument("an allocator must be provided to ExecuteLocally");
  }

  if (run_options.allocator()->platform() != backend.platform()) {
    return InvalidArgument(
        "allocator platform (%s) does not match service platform (%s)",
        run_options.allocator()->platform()->Name(),
        backend.platform()->Name());
  }

  return Status::OK();
}

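// Shared setup for Run and RunAsync: validates argument shapes against the
// entry computation layout, validates `run_options`, and borrows a stream and
// allocator from the backend when the caller did not supply them.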
StatusOr<std::pair<ServiceExecutableRunOptions, StreamPool::Ptr>>
LocalExecutable::RunHelper(const absl::Span<const Shape* const> argument_shapes,
                           ExecutableRunOptions run_options) {
  const ComputationLayout& computation_layout =
      executable_->module_config().entry_computation_layout();

  // Check argument number, shapes, and layouts.
  if (argument_shapes.size() != computation_layout.parameter_count()) {
    return InvalidArgument(
        "invalid number of arguments for computation: expected %d, got %u",
        computation_layout.parameter_count(), argument_shapes.size());
  }
  for (int i = 0; i < argument_shapes.size(); ++i) {
    if (!computation_layout.parameter_layout(i).MatchesLayoutInShape(
            *argument_shapes[i])) {
      return InvalidParameterArgument(
          executable_.get(), i,
          "Argument does not match host shape or layout of computation "
          "parameter %d: want %s, got %s",
          i,
          ShapeUtil::HumanStringWithLayout(
              computation_layout.parameter_layout(i).shape()),
          ShapeUtil::HumanStringWithLayout(*argument_shapes[i]));
    }
  }

  TF_RETURN_IF_ERROR(ValidateExecutionOptions(run_options, *backend_));

  StreamPool::Ptr stream;
  if (run_options.stream() == nullptr) {
    // NB!  The lifetime of `stream` needs to match the lifetime of
    // `service_options` (otherwise we will end up using a returned stream in
    // ExecuteOnStreamWrapper), which is why it isn't declared in the inner "if"
    // scope.
    TF_ASSIGN_OR_RETURN(
        stream, BorrowStreamForDevice(run_options.device_ordinal(), backend_));
    run_options.set_stream(stream.get());
  }
  if (run_options.allocator() == nullptr) {
    run_options.set_allocator(backend_->memory_allocator());
  }

  // For local client execution on CPU backends:
  // *) The thread pool used for eigen CPU ops is from
  //    ExecutableRunOptions.eigen_intra_op_thread_pool.
  // *) The thread pool used for XLA CPU ops is from
  //    backend_->eigen_intra_op_thread_pool().
  ServiceExecutableRunOptions service_options(run_options,
                                              backend_->StreamBorrower());
  return std::make_pair(service_options, std::move(stream));
}

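// Synchronous wrapper around RunAsync: enqueues the execution and then blocks
// until the stream has drained before returning the result buffer.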
StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
    const absl::Span<const ShapedBuffer* const> arguments,
    ExecutableRunOptions run_options) {
  std::vector<const Shape*> argument_shapes;
  argument_shapes.reserve(arguments.size());
  for (const ShapedBuffer* const arg : arguments) {
    argument_shapes.push_back(&arg->on_host_shape());
  }
  TF_ASSIGN_OR_RETURN(auto options_and_stream,
                      RunHelper(argument_shapes, run_options));
  ExecutableRunOptions options = options_and_stream.first.run_options();
  options.set_device_ordinal(-1);
  auto result = RunAsync(arguments, options);
  Status block_status = options.stream()->BlockHostUntilDone();
  TF_RETURN_IF_ERROR(result.status());
  TF_RETURN_IF_ERROR(block_status);
  return result;
}

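// Transfers the argument buffers back to the host and records them in an
// HloSnapshot for debug dumping. The transfers complete asynchronously on
// `stream`; the snapshot is kept alive by the completion callbacks.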
static std::shared_ptr<HloSnapshot> DumpArguments(
    const Backend* backend, const Executable* executable,
    const absl::Span<const ShapedBuffer* const> arguments, se::Stream* stream) {
  auto snapshot = std::make_shared<HloSnapshot>();
  snapshot->set_execution_platform(backend->platform()->Name());
  *snapshot->mutable_hlo() = *executable->hlo_proto();
  for (const ShapedBuffer* arg : arguments) {
    auto literal = std::make_shared<Literal>(arg->on_host_shape());
    backend->transfer_manager()->TransferLiteralFromDevice(
        stream, *arg, literal.get(), [snapshot, literal](Status status) {
          if (!status.ok()) {
            LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot inputs "
                          "failed: "
                       << status;
            return;
          }
          *snapshot->add_arguments() = literal->ToProto();
        });
  }
  return snapshot;
}

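// Transfers the result buffer back to the host, records it in `snapshot`, and
// dumps the completed snapshot to disk once the transfer finishes.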
static void DumpOutputsAndSaveSnapshot(const Backend* backend,
                                       const ShapedBuffer& outputs,
                                       std::shared_ptr<HloSnapshot> snapshot,
                                       se::Stream* stream) {
  auto literal = std::make_shared<Literal>(outputs.on_host_shape());
  backend->transfer_manager()->TransferLiteralFromDevice(
      stream, outputs, literal.get(),
      [snapshot{std::move(snapshot)}, literal](Status status) {
        if (status.ok()) {
          *snapshot->mutable_result() = literal->ToProto();
        } else {
          LOG(ERROR)
              << "TransferLiteralFromDevice for HLO snapshot outputs failed: "
              << status;
        }
        DumpHloSnapshotIfEnabled(*snapshot, GetDebugOptionsFromFlags());
      });
}

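// Asynchronous execution: enqueues the computation on the stream and returns
// the result buffer without blocking; callers synchronize via the stream.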
StatusOr<ScopedShapedBuffer> LocalExecutable::RunAsync(
    const absl::Span<const ShapedBuffer* const> arguments,
    ExecutableRunOptions run_options) {
  std::vector<const Shape*> argument_shapes;
  argument_shapes.reserve(arguments.size());
  for (const ShapedBuffer* const arg : arguments) {
    argument_shapes.push_back(&arg->on_host_shape());
  }
  TF_ASSIGN_OR_RETURN(auto options_and_stream,
                      RunHelper(argument_shapes, run_options));
  se::Stream* stream = run_options.stream();

  std::shared_ptr<HloSnapshot> snapshot;
  if (executable_->dumping_snapshot()) {
    snapshot = DumpArguments(backend_, executable_.get(), arguments, stream);
  }

  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer outputs,
                      executable_->ExecuteAsyncOnStreamWrapper(
                          &options_and_stream.first, arguments));

  // Transfer the outputs and save the snapshot to disk.
  if (snapshot) {
    DumpOutputsAndSaveSnapshot(backend_, outputs, std::move(snapshot), stream);
  }

  return std::move(outputs);
}

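// Builds a non-owning ShapedBuffer view over a tree of (possibly owning)
// device allocations, so the buffers can be handed to code that expects
// ShapedBuffer arguments.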
static ShapedBuffer MaybeOwningShapeTreeToShapedBuffer(
    Shape const& on_host_shape, const ShapeTree<MaybeOwningDeviceMemory>& tree,
    se::Platform* platform, int device_ordinal) {
  ShapedBuffer result(on_host_shape, tree.shape(), platform, device_ordinal);
  auto it = tree.begin();
  auto out_it = result.buffers().begin();
  for (; it != tree.end(); ++it, ++out_it) {
    out_it->second = it->second.AsDeviceMemoryBase();
  }
  return result;
}

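// Overload of RunAsync that takes maybe-owning device memory, which allows
// the runtime to donate input buffers to the computation when ownership is
// passed in.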
StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
    absl::Span<Shape const* const> argument_host_shapes,
    std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
    ExecutableRunOptions run_options) {
  if (argument_host_shapes.size() != arguments.size()) {
    return InvalidArgument(
        "Number of argument host shapes not equal to number of arguments (%d "
        "vs %d)",
        argument_host_shapes.size(), arguments.size());
  }
  TF_ASSIGN_OR_RETURN(auto options_and_stream,
                      RunHelper(argument_host_shapes, run_options));
  se::Stream* stream = run_options.stream();

  std::shared_ptr<HloSnapshot> snapshot;
  if (executable_->dumping_snapshot()) {
    std::vector<ShapedBuffer> shaped_buffers;
    std::vector<const ShapedBuffer*> shaped_buffer_ptrs;
    shaped_buffers.reserve(arguments.size());
    shaped_buffer_ptrs.reserve(arguments.size());
    for (size_t i = 0; i < arguments.size(); ++i) {
      shaped_buffers.push_back(MaybeOwningShapeTreeToShapedBuffer(
          *argument_host_shapes[i], arguments[i], backend_->platform(),
          stream->parent()->device_ordinal()));
      shaped_buffer_ptrs.push_back(&shaped_buffers.back());
    }

    snapshot =
        DumpArguments(backend_, executable_.get(), shaped_buffer_ptrs, stream);
  }

  TF_ASSIGN_OR_RETURN(ExecutionOutput outputs,
                      executable_->ExecuteAsyncOnStreamWrapper(
                          &options_and_stream.first, std::move(arguments)));

  // Transfer the outputs and save the snapshot to disk.
  if (snapshot) {
    DumpOutputsAndSaveSnapshot(backend_, outputs.Result(), std::move(snapshot),
                               stream);
  }

  return std::move(outputs);
}

se::Platform* LocalClient::platform() const {
  return local_service_->backend().platform();
}

int LocalClient::device_count() const {
  return local_service_->backend().device_count();
}

bool LocalClient::device_ordinal_supported(int device_ordinal) const {
  return local_service_->backend().device_ordinal_supported(device_ordinal);
}

int LocalClient::default_device_ordinal() const {
  return local_service_->backend().default_device_ordinal();
}

const Backend& LocalClient::backend() const {
  return local_service_->backend();
}

Backend* LocalClient::mutable_backend() {
  return local_service_->mutable_backend();
}

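// Compiles `computation` for the device named in `options` (defaulting to the
// client's default device) and wraps each resulting Executable in a
// LocalExecutable bound to this client's backend.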
StatusOr<std::vector<std::unique_ptr<LocalExecutable>>> LocalClient::Compile(
    const XlaComputation& computation,
    const absl::Span<const Shape* const> argument_layouts,
    const ExecutableBuildOptions& options) {
  ExecutableBuildOptions updated_options = options;
  if (options.device_ordinal() == -1) {
    updated_options.set_device_ordinal(default_device_ordinal());
    VLOG(3) << "Set device ordinal to default value of: "
            << updated_options.device_ordinal();
  }
  TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
                      local_service_->CompileExecutables(
                          computation, argument_layouts, updated_options));

  std::vector<std::unique_ptr<LocalExecutable>> local_executables;
  local_executables.reserve(executables.size());

  for (auto& executable : executables) {
    local_executables.push_back(absl::make_unique<LocalExecutable>(
        std::move(executable), local_service_->mutable_backend(),
        updated_options));
  }

  return std::move(local_executables);
}

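// A minimal usage sketch of the compile-and-run flow above, written inside a
// Status-returning function and assuming a caller-provided XlaComputation
// `computation` and input ScopedShapedBuffer `arg_buffer` (illustrative
// names, not part of this file):
//
//   LocalClient* client = ClientLibrary::LocalClientOrDie();
//   TF_ASSIGN_OR_RETURN(
//       auto executables,
//       client->Compile(computation, {&arg_buffer.on_host_shape()},
//                       ExecutableBuildOptions()));
//   ExecutableRunOptions run_options;
//   run_options.set_allocator(client->backend().memory_allocator());
//   TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
//                       executables[0]->Run({&arg_buffer}, run_options));

// Copies `literal` into newly allocated device memory on `device_ordinal`,
// using the backend's default memory allocator when none is supplied.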
StatusOr<ScopedShapedBuffer> LocalClient::LiteralToShapedBuffer(
    const LiteralSlice& literal, int device_ordinal,
    se::DeviceMemoryAllocator* allocator) {
  if (allocator == nullptr) {
    allocator = backend().memory_allocator();
  }
  TF_ASSIGN_OR_RETURN(auto scoped_buffer,
                      backend().transfer_manager()->AllocateScopedShapedBuffer(
                          literal.shape(), allocator, device_ordinal));
  TF_ASSIGN_OR_RETURN(auto stream,
                      mutable_backend()->BorrowStream(device_ordinal));
  TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
      stream.get(), literal, scoped_buffer));
  return std::move(scoped_buffer);
}

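// Copies the device contents of `shaped_buffer` back into a host Literal.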
StatusOr<Literal> LocalClient::ShapedBufferToLiteral(
    const ShapedBuffer& shaped_buffer) {
  TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream(
                                       shaped_buffer.device_ordinal()));
  return backend().transfer_manager()->TransferLiteralFromDevice(stream.get(),
                                                                 shaped_buffer);
}

StatusOr<const ShapedBuffer*> LocalClient::GlobalDataToShapedBuffer(
    const GlobalDataHandle& data, int replica_number) {
  return local_service_->GlobalDataToShapedBuffer(data, replica_number);
}

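// Transfers `literal` to the infeed queue of the given device.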
Status LocalClient::TransferToInfeedLocal(const LiteralSlice& literal,
                                          int device_ordinal) {
  TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                      backend().stream_executor(device_ordinal));
  return backend().transfer_manager()->TransferLiteralToInfeed(executor,
                                                               literal);
}

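// Transfers a literal of the given shape from the outfeed of the given device.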
StatusOr<Literal> LocalClient::TransferFromOutfeedLocal(const Shape& shape,
                                                        int device_ordinal) {
  TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                      backend().stream_executor(device_ordinal));
  auto literal = Literal::CreateFromShape(shape);
  TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
      executor, shape, &literal));
  return std::move(literal);
}

StatusOr<int> LocalClient::ReplicaNumberToDeviceOrdinal(int replica_number) {
  return local_service_->ReplicaNumberToDeviceOrdinal(replica_number);
}

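// Copies `literal` to the device and registers the resulting buffer with the
// local service so it can later be referenced through the GlobalDataHandle in
// the returned response.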
StatusOr<TransferToServerResponse> LocalClient::TransferToLocalServer(
    const ::xla::BorrowingLiteral& literal, int device_ordinal) {
  const ::xla::Shape& shape = literal.shape();

  TF_ASSIGN_OR_RETURN(::xla::ScopedShapedBuffer shaped_buffer,
                      backend().transfer_manager()->AllocateScopedShapedBuffer(
                          shape, backend().memory_allocator(), device_ordinal));
  TF_ASSIGN_OR_RETURN(auto stream,
                      mutable_backend()->BorrowStream(device_ordinal));
  TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
      stream.get(), literal, shaped_buffer));
  std::vector<::xla::ScopedShapedBuffer> replicated_buffer;
  replicated_buffer.emplace_back(std::move(shaped_buffer));
  ::xla::TransferToServerResponse result;
  TF_ASSIGN_OR_RETURN(*result.mutable_data(),
                      local_service_->RegisterReplicatedBuffers(
                          std::move(replicated_buffer),
                          absl::StrCat("TransferToServer literal of shape ",
                                       ::xla::ShapeUtil::HumanString(shape))));

  return result;
}

}  // namespace xla