/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/local_client.h"

#include <utility>

#include "absl/memory/memory.h"
#include "llvm/ADT/Triple.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
#include "tensorflow/compiler/xla/service/source_map_util.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/status_macros.h"

using xla::source_map_util::InvalidParameterArgument;

namespace xla {

namespace {

// Borrows a stream from the backend's pool for `device_ordinal`, falling
// back to the default device when the ordinal is negative.
StatusOr<StreamPool::Ptr> BorrowStreamForDevice(int device_ordinal,
                                                Backend* backend) {
  if (device_ordinal < 0) {
    device_ordinal = backend->default_device_ordinal();
  }
  return backend->BorrowStream(device_ordinal);
}
}  // namespace

LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
                                 Backend* backend,
                                 ExecutableBuildOptions build_options)
    : executable_(std::move(executable)),
      backend_(backend),
      build_options_(std::move(build_options)) {
  CHECK_GE(build_options_.device_ordinal(), 0)
      << "Must have a valid device ordinal that the executable was built for.";
}

Status LocalExecutable::ValidateExecutionOptions(
    const ExecutableRunOptions& run_options, const Backend& backend) {
  if (run_options.stream() != nullptr) {
    if (!run_options.stream()->ok()) {
      return InvalidArgument("stream is uninitialized or in an error state");
    }

    // Check that the stream targets the same platform as the service.
    const se::Platform* stream_platform =
        run_options.stream()->parent()->platform();
    if (stream_platform != backend_->platform()) {
      return InvalidArgument(
          "stream is for platform %s, but service targets platform %s",
          stream_platform->Name(), backend_->platform()->Name());
    }

    // A device ordinal cannot be specified together with a stream; the
    // stream determines the device ordinal.
    if (run_options.device_ordinal() != -1) {
      return InvalidArgument(
          "cannot set both device ordinal and stream options in "
          "ExecutableRunOptions; the stream determines the device ordinal");
    }
  }

  // Verify that the device the executable was built for is equivalent
  // to the device it will run on.
  int run_device_ordinal = run_options.device_ordinal();
  if (run_device_ordinal == -1) {
    run_device_ordinal = run_options.stream() != nullptr
                             ? run_options.stream()->parent()->device_ordinal()
                             : backend_->default_device_ordinal();
  }
  TF_ASSIGN_OR_RETURN(bool devices_equivalent,
                      backend_->devices_equivalent(
                          run_device_ordinal, build_options_.device_ordinal()));
  if (!devices_equivalent) {
    TF_ASSIGN_OR_RETURN(se::StreamExecutor * run_executor,
                        backend_->stream_executor(run_device_ordinal));
    TF_ASSIGN_OR_RETURN(se::StreamExecutor * build_executor,
                        backend_->stream_executor(build_device_ordinal()));
    return InvalidArgument(
        "executable is built for device %s of type \"%s\"; cannot run it on "
        "device %s of type \"%s\"",
        backend_->device_name(build_device_ordinal()),
        build_executor->GetDeviceDescription().name(),
        backend_->device_name(run_device_ordinal),
        run_executor->GetDeviceDescription().name());
  }

  if (!run_options.allocator()) {
    return InvalidArgument("an allocator must be provided to ExecuteLocally");
  }

  if (run_options.allocator()->platform() != backend.platform()) {
    return InvalidArgument(
        "allocator platform (%s) does not match service platform (%s)",
        run_options.allocator()->platform()->Name(),
        backend.platform()->Name());
  }

  return Status::OK();
}

StatusOr<std::pair<ServiceExecutableRunOptions, StreamPool::Ptr>>
LocalExecutable::RunHelper(const absl::Span<const Shape* const> argument_shapes,
                           ExecutableRunOptions run_options) {
  const ComputationLayout& computation_layout =
      executable_->module_config().entry_computation_layout();

  // Check argument number, shapes, and layouts.
  const int argument_shapes_size = argument_shapes.size();
  if (argument_shapes_size != computation_layout.parameter_count()) {
    return InvalidArgument(
        "invalid number of arguments for computation: expected %d, got %u",
        computation_layout.parameter_count(), argument_shapes.size());
  }
  for (int i = 0, end = argument_shapes.size(); i < end; ++i) {
    // TODO(b/187081154): Compare tiling info also.
    if (!computation_layout.parameter_layout(i).MatchesLayoutInShape(
            *argument_shapes[i], /*minor_to_major_only=*/false,
            /*ignore_fully_empty_tiling=*/true)) {
      return InvalidParameterArgument(
          executable_.get(), i,
          "Argument does not match host shape or layout of computation "
          "parameter %d: want %s, got %s",
          i,
          ShapeUtil::HumanStringWithLayout(
              computation_layout.parameter_layout(i).shape()),
          ShapeUtil::HumanStringWithLayout(*argument_shapes[i]));
    }
  }

  TF_RETURN_IF_ERROR(ValidateExecutionOptions(run_options, *backend_));

  StreamPool::Ptr stream;
  if (run_options.stream() == nullptr) {
    // NB! The lifetime of `stream` needs to match the lifetime of
    // `service_options` (otherwise we will end up using a returned stream in
    // ExecuteOnStreamWrapper), which is why it isn't declared in the inner
    // "if" scope.
    TF_ASSIGN_OR_RETURN(
        stream, BorrowStreamForDevice(run_options.device_ordinal(), backend_));
    run_options.set_stream(stream.get());
  }
  if (run_options.allocator() == nullptr) {
    run_options.set_allocator(backend_->memory_allocator());
  }

  // For local client execution on CPU backends:
  // *) The thread pool used for eigen CPU ops is from
  //    ExecutableRunOptions.eigen_intra_op_thread_pool.
  // *) The thread pool used for XLA CPU ops is from
  //    backend_->eigen_intra_op_thread_pool().
  ServiceExecutableRunOptions service_options(run_options,
                                              backend_->StreamBorrower());
  return std::make_pair(service_options, std::move(stream));
}

StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
    const absl::Span<const ShapedBuffer* const> arguments,
    ExecutableRunOptions run_options) {
  std::vector<const Shape*> argument_shapes;
  argument_shapes.reserve(arguments.size());
  for (const ShapedBuffer* const arg : arguments) {
    argument_shapes.push_back(&arg->on_device_shape());
  }
  return AsyncCallAndBlockHostUntilDone<xla::ScopedShapedBuffer>(
      argument_shapes, run_options, [&](const ExecutableRunOptions& options) {
        return RunAsync(arguments, options);
      });
}
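
// A minimal usage sketch for the synchronous Run() path above (not part of
// the library; `client`, `executable`, and `arg_buffer` are assumed to be
// provided by the caller, and error handling is elided):
//
//   xla::ExecutableRunOptions run_options;
//   run_options.set_allocator(client->backend().memory_allocator());
//   xla::StatusOr<xla::ScopedShapedBuffer> result =
//       executable->Run({&arg_buffer}, run_options);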

StatusOr<ExecutionOutput> LocalExecutable::Run(
    std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options) {
  std::vector<const Shape*> argument_shapes;
  argument_shapes.reserve(arguments.size());
  for (const ExecutionInput& arg : arguments) {
    argument_shapes.push_back(&arg.shape());
  }
  return AsyncCallAndBlockHostUntilDone<ExecutionOutput>(
      argument_shapes, run_options, [&](const ExecutableRunOptions& options) {
        return RunAsync(argument_shapes, std::move(arguments), options);
      });
}

// Asynchronously transfers the argument buffers back to the host and records
// them in a new HloSnapshot, which is returned so the result can be attached
// once execution finishes.
static std::shared_ptr<HloSnapshot> DumpArguments(
    const Backend* backend, const Executable* executable,
    const absl::Span<const ShapedBuffer* const> arguments, se::Stream* stream) {
  auto snapshot = std::make_shared<HloSnapshot>();
  snapshot->set_execution_platform(backend->platform()->Name());
  *snapshot->mutable_hlo() = *executable->hlo_proto();
  for (const ShapedBuffer* arg : arguments) {
    auto literal = std::make_shared<Literal>(arg->on_host_shape());
    backend->transfer_manager()->TransferLiteralFromDevice(
        stream, *arg, literal.get(), [snapshot, literal](Status status) {
          if (!status.ok()) {
            LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot inputs "
                          "failed: "
                       << status;
            return;
          }
          *snapshot->add_arguments() = literal->ToProto();
        });
  }
  return snapshot;
}

// Asynchronously transfers the result buffer back to the host, records it in
// `snapshot`, and dumps the completed snapshot to disk.
static void DumpOutputsAndSaveSnapshot(const Backend* backend,
                                       const ShapedBuffer& outputs,
                                       std::shared_ptr<HloSnapshot> snapshot,
                                       se::Stream* stream) {
  auto literal = std::make_shared<Literal>(outputs.on_host_shape());
  backend->transfer_manager()->TransferLiteralFromDevice(
      stream, outputs, literal.get(),
      [snapshot{std::move(snapshot)}, literal](Status status) {
        if (status.ok()) {
          *snapshot->mutable_result() = literal->ToProto();
        } else {
          LOG(ERROR)
              << "TransferLiteralFromDevice for HLO snapshot outputs failed: "
              << status;
        }
        DumpHloSnapshotIfEnabled(*snapshot, GetDebugOptionsFromFlags());
      });
}

StatusOr<ScopedShapedBuffer> LocalExecutable::RunAsync(
    const absl::Span<const ShapedBuffer* const> arguments,
    ExecutableRunOptions run_options) {
  std::vector<const Shape*> argument_shapes;
  argument_shapes.reserve(arguments.size());
  for (const ShapedBuffer* const arg : arguments) {
    argument_shapes.push_back(&arg->on_device_shape());
  }
  TF_ASSIGN_OR_RETURN(auto options_and_stream,
                      RunHelper(argument_shapes, run_options));
  se::Stream* stream = run_options.stream();

  std::shared_ptr<HloSnapshot> snapshot;
  if (executable_->dumping_snapshot()) {
    snapshot = DumpArguments(backend_, executable_.get(), arguments, stream);
  }

  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer outputs,
                      executable_->ExecuteAsyncOnStreamWrapper(
                          &options_and_stream.first, arguments));

  // Transfer the outputs and save the snapshot to disk.
  if (snapshot) {
    DumpOutputsAndSaveSnapshot(backend_, outputs, std::move(snapshot), stream);
  }

  return std::move(outputs);
}
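
// Async usage sketch (hedged: `client`, `executable`, and `arg_buffer` come
// from the caller; error handling elided). RunAsync() only enqueues work on
// the stream, so the caller owns synchronization:
//
//   TF_ASSIGN_OR_RETURN(auto stream,
//                       client->mutable_backend()->BorrowStream(
//                           client->default_device_ordinal()));
//   xla::ExecutableRunOptions options;
//   options.set_stream(stream.get());
//   options.set_allocator(client->backend().memory_allocator());
//   TF_ASSIGN_OR_RETURN(xla::ScopedShapedBuffer out,
//                       executable->RunAsync({&arg_buffer}, options));
//   TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());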

// Builds a non-owning ShapedBuffer view over the device memory held by
// `tree`.
static ShapedBuffer MaybeOwningShapeTreeToShapedBuffer(
    const ShapeTree<MaybeOwningDeviceMemory>& tree, int device_ordinal) {
  ShapedBuffer result(tree.shape(), device_ordinal);
  auto it = tree.begin();
  auto out_it = result.buffers().begin();
  for (; it != tree.end(); ++it, ++out_it) {
    out_it->second = it->second.AsDeviceMemoryBase();
  }
  return result;
}

StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
    absl::Span<Shape const* const> argument_host_shapes,
    std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options) {
  if (argument_host_shapes.size() != arguments.size()) {
    return InvalidArgument(
        "Number of argument host shapes not equal to number of arguments (%d "
        "vs %d)",
        argument_host_shapes.size(), arguments.size());
  }
  TF_ASSIGN_OR_RETURN(auto options_and_stream,
                      RunHelper(argument_host_shapes, run_options));
  se::Stream* stream = run_options.stream();

  std::shared_ptr<HloSnapshot> snapshot;
  if (executable_->dumping_snapshot()) {
    std::vector<ShapedBuffer> shaped_buffers;
    std::vector<const ShapedBuffer*> shaped_buffer_ptrs;
    shaped_buffers.reserve(arguments.size());
    shaped_buffer_ptrs.reserve(arguments.size());
    for (size_t i = 0; i < arguments.size(); ++i) {
      shaped_buffers.push_back(MaybeOwningShapeTreeToShapedBuffer(
          arguments[i].Buffers(), stream->parent()->device_ordinal()));
      shaped_buffer_ptrs.push_back(&shaped_buffers.back());
    }

    snapshot =
        DumpArguments(backend_, executable_.get(), shaped_buffer_ptrs, stream);
  }

  TF_ASSIGN_OR_RETURN(ExecutionOutput outputs,
                      executable_->ExecuteAsyncOnStreamWrapper(
                          &options_and_stream.first, std::move(arguments)));

  // Transfer the outputs and save the snapshot to disk.
  if (snapshot) {
    DumpOutputsAndSaveSnapshot(backend_, outputs.Result(), std::move(snapshot),
                               stream);
  }

  return std::move(outputs);
}

StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
    std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options) {
  std::vector<const Shape*> argument_shapes;
  argument_shapes.reserve(arguments.size());
  for (const ExecutionInput& arg : arguments) {
    argument_shapes.push_back(&arg.shape());
  }
  return RunAsync(argument_shapes, std::move(arguments), run_options);
}

se::Platform* LocalClient::platform() const {
  return local_service_->backend().platform();
}

int LocalClient::device_count() const {
  return local_service_->backend().device_count();
}

bool LocalClient::device_ordinal_supported(int device_ordinal) const {
  return local_service_->backend().device_ordinal_supported(device_ordinal);
}

int LocalClient::default_device_ordinal() const {
  return local_service_->backend().default_device_ordinal();
}

const Backend& LocalClient::backend() const {
  return local_service_->backend();
}

Backend* LocalClient::mutable_backend() {
  return local_service_->mutable_backend();
}

StatusOr<std::vector<std::unique_ptr<LocalExecutable>>> LocalClient::Compile(
    const XlaComputation& computation,
    const absl::Span<const Shape* const> argument_layouts,
    const ExecutableBuildOptions& options) {
  ExecutableBuildOptions updated_options = options;
  if (options.device_ordinal() == -1) {
    updated_options.set_device_ordinal(default_device_ordinal());
    VLOG(3) << "Set device ordinal to default value of: "
            << updated_options.device_ordinal();
  }
  if (options.has_device_assignment()) {
    if (options.device_assignment().replica_count() != options.num_replicas()) {
      return InvalidArgument(
          "Mismatched number of replicas for device "
          "assignment and computation (%d vs %d).\n%s",
          options.device_assignment().replica_count(), options.num_replicas(),
          options.device_assignment().ToString());
    }
    if (options.device_assignment().computation_count() !=
        options.num_partitions()) {
      return InvalidArgument(
          "Mismatched number of partitions for device "
          "assignment and computation (%d vs %d).\n%s",
          options.device_assignment().computation_count(),
          options.num_partitions(), options.device_assignment().ToString());
    }
  }
  TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
                      local_service_->CompileExecutables(
                          computation, argument_layouts, updated_options));

  std::vector<std::unique_ptr<LocalExecutable>> local_executables;
  local_executables.reserve(executables.size());

  for (auto& executable : executables) {
    local_executables.push_back(absl::make_unique<LocalExecutable>(
        std::move(executable), local_service_->mutable_backend(),
        updated_options));
  }

  return std::move(local_executables);
}
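
// Compile() usage sketch (a hedged example, not part of the library;
// `client`, `computation`, and `arg_shape` are assumed inputs). Compile()
// may return one LocalExecutable per partition, hence the vector:
//
//   xla::ExecutableBuildOptions build_options;
//   build_options.set_device_ordinal(client->default_device_ordinal());
//   TF_ASSIGN_OR_RETURN(
//       std::vector<std::unique_ptr<xla::LocalExecutable>> executables,
//       client->Compile(computation, {&arg_shape}, build_options));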

StatusOr<ScopedShapedBuffer> LocalClient::LiteralToShapedBuffer(
    const LiteralSlice& literal, int device_ordinal,
    se::DeviceMemoryAllocator* allocator) {
  if (allocator == nullptr) {
    allocator = backend().memory_allocator();
  }
  TF_ASSIGN_OR_RETURN(auto scoped_buffer,
                      backend().transfer_manager()->AllocateScopedShapedBuffer(
                          literal.shape(), allocator, device_ordinal));
  TF_ASSIGN_OR_RETURN(auto stream,
                      mutable_backend()->BorrowStream(device_ordinal));
  TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
      stream.get(), literal, scoped_buffer));
  return std::move(scoped_buffer);
}

StatusOr<Literal> LocalClient::ShapedBufferToLiteral(
    const ShapedBuffer& shaped_buffer) {
  TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream(
                                       shaped_buffer.device_ordinal()));
  return backend().transfer_manager()->TransferLiteralFromDevice(stream.get(),
                                                                 shaped_buffer);
}
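
// Host/device round-trip sketch using the two functions above (assumes
// `client` and a host `literal`; error handling elided):
//
//   TF_ASSIGN_OR_RETURN(
//       xla::ScopedShapedBuffer buffer,
//       client->LiteralToShapedBuffer(literal,
//                                     client->default_device_ordinal()));
//   TF_ASSIGN_OR_RETURN(xla::Literal round_tripped,
//                       client->ShapedBufferToLiteral(buffer));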

StatusOr<const ShapedBuffer*> LocalClient::GlobalDataToShapedBuffer(
    const GlobalDataHandle& data, int replica_number) {
  return local_service_->GlobalDataToShapedBuffer(data, replica_number);
}

Status LocalClient::TransferToInfeedLocal(const LiteralSlice& literal,
                                          int device_ordinal) {
  TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                      backend().stream_executor(device_ordinal));
  return backend().transfer_manager()->TransferLiteralToInfeed(executor,
                                                               literal);
}

Status LocalClient::TransferFromOutfeedLocal(int device_ordinal,
                                             MutableBorrowingLiteral literal) {
  TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                      backend().stream_executor(device_ordinal));
  return backend().transfer_manager()->TransferLiteralFromOutfeed(executor,
                                                                  literal);
}
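
// Infeed/outfeed sketch (hedged: `client`, `input_literal`, and
// `result_shape` are assumed, and the running computation must contain
// matching infeed/outfeed ops):
//
//   TF_RETURN_IF_ERROR(
//       client->TransferToInfeedLocal(input_literal, /*device_ordinal=*/0));
//   xla::Literal result(result_shape);
//   TF_RETURN_IF_ERROR(client->TransferFromOutfeedLocal(
//       /*device_ordinal=*/0, xla::MutableBorrowingLiteral(&result)));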

StatusOr<int> LocalClient::ReplicaNumberToDeviceOrdinal(int replica_number) {
  return local_service_->ReplicaNumberToDeviceOrdinal(replica_number);
}

StatusOr<TransferToServerResponse> LocalClient::TransferToLocalServer(
    const ::xla::BorrowingLiteral& literal, int device_ordinal) {
  const ::xla::Shape& shape = literal.shape();

  TF_ASSIGN_OR_RETURN(::xla::ScopedShapedBuffer shaped_buffer,
                      backend().transfer_manager()->AllocateScopedShapedBuffer(
                          shape, backend().memory_allocator(), device_ordinal));
  TF_ASSIGN_OR_RETURN(auto stream,
                      mutable_backend()->BorrowStream(device_ordinal));
  TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
      stream.get(), literal, shaped_buffer));
  std::vector<::xla::ScopedShapedBuffer> replicated_buffer;
  replicated_buffer.emplace_back(std::move(shaped_buffer));
  ::xla::TransferToServerResponse result;
  TF_ASSIGN_OR_RETURN(*result.mutable_data(),
                      local_service_->RegisterReplicatedBuffers(
                          std::move(replicated_buffer),
                          absl::StrCat("TransferToServer literal of shape ",
                                       ::xla::ShapeUtil::HumanString(shape))));

  return result;
}

}  // namespace xla