1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/local_client.h"
17
18 #include <utility>
19
20 #include "absl/memory/memory.h"
21 #include "llvm/ADT/Triple.h"
22 #include "tensorflow/compiler/xla/client/xla_computation.h"
23 #include "tensorflow/compiler/xla/service/backend.h"
24 #include "tensorflow/compiler/xla/service/dump.h"
25 #include "tensorflow/compiler/xla/service/service_executable_run_options.h"
26 #include "tensorflow/compiler/xla/service/source_map_util.h"
27 #include "tensorflow/compiler/xla/service/stream_pool.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29
30 using xla::source_map_util::InvalidParameterArgument;
31
32 namespace xla {
33
34 namespace {
BorrowStreamForDevice(int device_ordinal,Backend * backend)35 StatusOr<StreamPool::Ptr> BorrowStreamForDevice(int device_ordinal,
36 Backend* backend) {
37 if (device_ordinal < 0) {
38 device_ordinal = backend->default_device_ordinal();
39 }
40 return backend->BorrowStream(device_ordinal);
41 }
42 } // namespace
43
// Wraps a compiled `executable` for local execution. Takes ownership of the
// executable; `backend` is non-owning and must outlive this object. The build
// options must carry the concrete device ordinal the executable was compiled
// for (enforced by the CHECK below) so later runs can verify device
// equivalence.
LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
                                 Backend* backend,
                                 ExecutableBuildOptions build_options)
    : executable_(std::move(executable)),
      backend_(backend),
      build_options_(std::move(build_options)) {
  CHECK_GE(build_options_.device_ordinal(), 0)
      << "Must have a valid device ordinal that the executable was built for.";
}
53
// Validates that `run_options` can be used to run this executable on
// `backend`: any caller-supplied stream must be healthy, belong to the
// service platform, and not be combined with an explicit device ordinal; the
// device the run targets must be equivalent to the device the executable was
// built for; and an allocator on the service platform must be present.
// Returns OK when all checks pass, InvalidArgument otherwise.
//
// NOTE(review): the body mixes the `backend` parameter with the `backend_`
// member. The only caller in this file passes `*backend_`, so both currently
// refer to the same object — confirm before invoking with any other backend.
Status LocalExecutable::ValidateExecutionOptions(
    const ExecutableRunOptions& run_options, const Backend& backend) {
  if (run_options.stream() != nullptr) {
    if (!run_options.stream()->ok()) {
      return InvalidArgument("stream is uninitialized or in an error state");
    }

    // Check stream matches service platform.
    const se::Platform* stream_platform =
        run_options.stream()->parent()->platform();
    if (stream_platform != backend_->platform()) {
      return InvalidArgument(
          "stream is for platform %s, but service targets platform %s",
          stream_platform->Name(), backend_->platform()->Name());
    }

    // Cannot specify device_ordinal with a stream. The stream determines these
    // values.
    if (run_options.device_ordinal() != -1) {
      return InvalidArgument(
          "cannot set both device ordinal and stream options in "
          "ExecutableRunOptions; the stream determines the device ordinal");
    }
  }

  // Verify that the device the executable was built for is equivalent
  // to the device it will run on.
  int run_device_ordinal = run_options.device_ordinal();
  if (run_device_ordinal == -1) {
    // No explicit ordinal: derive it from the stream if present, otherwise
    // fall back to the backend's default device.
    run_device_ordinal = run_options.stream() != nullptr
                             ? run_options.stream()->parent()->device_ordinal()
                             : backend_->default_device_ordinal();
  }
  TF_ASSIGN_OR_RETURN(bool devices_equivalent,
                      backend_->devices_equivalent(
                          run_device_ordinal, build_options_.device_ordinal()));
  if (!devices_equivalent) {
    // Fetch both executors only to produce a descriptive error message.
    TF_ASSIGN_OR_RETURN(se::StreamExecutor * run_executor,
                        backend_->stream_executor(run_device_ordinal));
    TF_ASSIGN_OR_RETURN(se::StreamExecutor * build_executor,
                        backend_->stream_executor(build_device_ordinal()));
    return InvalidArgument(
        "executable is built for device %s of type \"%s\"; cannot run it on "
        "device %s of type \"%s\"",
        backend_->device_name(build_device_ordinal()),
        build_executor->GetDeviceDescription().name(),
        backend_->device_name(run_device_ordinal),
        run_executor->GetDeviceDescription().name());
  }

  if (!run_options.allocator()) {
    return InvalidArgument("an allocator must be provided to ExecuteLocally");
  }

  if (run_options.allocator()->platform() != backend.platform()) {
    return InvalidArgument(
        "allocator platform (%s) does not match service platform (%s)",
        run_options.allocator()->platform()->Name(),
        backend.platform()->Name());
  }

  return Status::OK();
}
117
// Validates `argument_shapes` against the executable's entry computation
// layout and `run_options` against the backend, then returns fully-populated
// run options (stream and allocator defaulted when the caller left them
// unset) packaged as ServiceExecutableRunOptions, paired with the borrowed
// stream (null Ptr when the caller supplied its own stream). Note the stream
// is recorded only in the returned options — the caller's `run_options` copy
// is not updated.
StatusOr<std::pair<ServiceExecutableRunOptions, StreamPool::Ptr>>
LocalExecutable::RunHelper(const absl::Span<const Shape* const> argument_shapes,
                           ExecutableRunOptions run_options) {
  const ComputationLayout& computation_layout =
      executable_->module_config().entry_computation_layout();

  // Check argument number, shapes, and layouts.
  if (argument_shapes.size() != computation_layout.parameter_count()) {
    return InvalidArgument(
        "invalid number of arguments for computation: expected %d, got %u",
        computation_layout.parameter_count(), argument_shapes.size());
  }
  for (int i = 0; i < argument_shapes.size(); ++i) {
    // Each argument must match the parameter's expected shape AND layout.
    if (!computation_layout.parameter_layout(i).MatchesLayoutInShape(
            *argument_shapes[i])) {
      return InvalidParameterArgument(
          executable_.get(), i,
          "Argument does not match host shape or layout of computation "
          "parameter "
          "%d: want %s, got %s",
          i,
          ShapeUtil::HumanStringWithLayout(
              computation_layout.parameter_layout(i).shape()),
          ShapeUtil::HumanStringWithLayout(*argument_shapes[i]));
    }
  }

  TF_RETURN_IF_ERROR(ValidateExecutionOptions(run_options, *backend_));

  StreamPool::Ptr stream;
  if (run_options.stream() == nullptr) {
    // NB! The lifetime of `stream` needs to match the lifetime of
    // `service_options` (otherwise we will end up using a returned stream in
    // ExecuteOnStreamWrapper), which is why it isn't declared in the inner "if"
    // scope.
    TF_ASSIGN_OR_RETURN(
        stream, BorrowStreamForDevice(run_options.device_ordinal(), backend_));
    run_options.set_stream(stream.get());
  }
  if (run_options.allocator() == nullptr) {
    run_options.set_allocator(backend_->memory_allocator());
  }

  // For local client execution on CPU backends:
  // *) The thread pool used for eigen CPU ops is from
  //    ExecutableRunOptions.eigen_intra_op_thread_pool.
  // *) The thread pool used for XLA CPU ops is from
  //    backend_->eigen_intra_op_thread_pool().
  ServiceExecutableRunOptions service_options(run_options,
                                              backend_->StreamBorrower());
  return std::make_pair(service_options, std::move(stream));
}
170
Run(const absl::Span<const ShapedBuffer * const> arguments,ExecutableRunOptions run_options)171 StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
172 const absl::Span<const ShapedBuffer* const> arguments,
173 ExecutableRunOptions run_options) {
174 std::vector<const Shape*> argument_shapes;
175 argument_shapes.reserve(arguments.size());
176 for (const ShapedBuffer* const arg : arguments) {
177 argument_shapes.push_back(&arg->on_host_shape());
178 }
179 TF_ASSIGN_OR_RETURN(auto options_and_stream,
180 RunHelper(argument_shapes, run_options));
181 ExecutableRunOptions options = options_and_stream.first.run_options();
182 options.set_device_ordinal(-1);
183 auto result = RunAsync(arguments, options);
184 Status block_status = options.stream()->BlockHostUntilDone();
185 TF_RETURN_IF_ERROR(result.status());
186 TF_RETURN_IF_ERROR(block_status);
187 return result;
188 }
189
// Creates an HloSnapshot seeded with the executable's HLO proto and kicks off
// asynchronous device-to-host transfers of every argument on `stream`; each
// completed transfer is appended to the snapshot's arguments. The completion
// callbacks capture the returned shared_ptr (and each literal) by value, so
// the snapshot and literals stay alive until all transfers finish. Transfer
// failures are logged, not propagated — snapshot dumping is best-effort.
static std::shared_ptr<HloSnapshot> DumpArguments(
    const Backend* backend, const Executable* executable,
    const absl::Span<const ShapedBuffer* const> arguments, se::Stream* stream) {
  auto snapshot = std::make_shared<HloSnapshot>();
  snapshot->set_execution_platform(backend->platform()->Name());
  *snapshot->mutable_hlo() = *executable->hlo_proto();
  for (const ShapedBuffer* arg : arguments) {
    auto literal = std::make_shared<Literal>(arg->on_host_shape());
    backend->transfer_manager()->TransferLiteralFromDevice(
        stream, *arg, literal.get(), [snapshot, literal](Status status) {
          if (!status.ok()) {
            LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot inputs "
                          "failed: "
                       << status;
            return;
          }
          *snapshot->add_arguments() = literal->ToProto();
        });
  }
  return snapshot;
}
211
// Asynchronously transfers `outputs` to host on `stream`, records the result
// into `snapshot`, and dumps the completed snapshot to disk (per the debug
// flags) from the transfer-completion callback. Ownership of `snapshot` moves
// into the callback so it survives until the transfer finishes; on transfer
// failure the snapshot is still dumped, just without a result.
static void DumpOutputsAndSaveSnapshot(const Backend* backend,
                                       const ShapedBuffer& outputs,
                                       std::shared_ptr<HloSnapshot> snapshot,
                                       se::Stream* stream) {
  auto literal = std::make_shared<Literal>(outputs.on_host_shape());
  backend->transfer_manager()->TransferLiteralFromDevice(
      stream, outputs, literal.get(),
      [snapshot{std::move(snapshot)}, literal](Status status) {
        if (status.ok()) {
          *snapshot->mutable_result() = literal->ToProto();
        } else {
          LOG(ERROR)
              << "TransferLiteralFromDevice for HLO snapshot outputs failed: "
              << status;
        }
        DumpHloSnapshotIfEnabled(*snapshot, GetDebugOptionsFromFlags());
      });
}
230
RunAsync(const absl::Span<const ShapedBuffer * const> arguments,ExecutableRunOptions run_options)231 StatusOr<ScopedShapedBuffer> LocalExecutable::RunAsync(
232 const absl::Span<const ShapedBuffer* const> arguments,
233 ExecutableRunOptions run_options) {
234 std::vector<const Shape*> argument_shapes;
235 argument_shapes.reserve(arguments.size());
236 for (const ShapedBuffer* const arg : arguments) {
237 argument_shapes.push_back(&arg->on_host_shape());
238 }
239 TF_ASSIGN_OR_RETURN(auto options_and_stream,
240 RunHelper(argument_shapes, run_options));
241 se::Stream* stream = run_options.stream();
242
243 std::shared_ptr<HloSnapshot> snapshot;
244 if (executable_->dumping_snapshot()) {
245 snapshot = DumpArguments(backend_, executable_.get(), arguments, stream);
246 }
247
248 TF_ASSIGN_OR_RETURN(ScopedShapedBuffer outputs,
249 executable_->ExecuteAsyncOnStreamWrapper(
250 &options_and_stream.first, arguments));
251
252 // Transfer the outputs and save the snapshot to disk.
253 if (snapshot) {
254 DumpOutputsAndSaveSnapshot(backend_, outputs, std::move(snapshot), stream);
255 }
256
257 return std::move(outputs);
258 }
259
MaybeOwningShapeTreeToShapedBuffer(Shape const & on_host_shape,const ShapeTree<MaybeOwningDeviceMemory> & tree,se::Platform * platform,int device_ordinal)260 static ShapedBuffer MaybeOwningShapeTreeToShapedBuffer(
261 Shape const& on_host_shape, const ShapeTree<MaybeOwningDeviceMemory>& tree,
262 se::Platform* platform, int device_ordinal) {
263 ShapedBuffer result(on_host_shape, tree.shape(), platform, device_ordinal);
264 auto it = tree.begin();
265 auto out_it = result.buffers().begin();
266 for (; it != tree.end(); ++it, ++out_it) {
267 out_it->second = it->second.AsDeviceMemoryBase();
268 }
269 return result;
270 }
271
RunAsync(absl::Span<Shape const * const> argument_host_shapes,std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,ExecutableRunOptions run_options)272 StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
273 absl::Span<Shape const* const> argument_host_shapes,
274 std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
275 ExecutableRunOptions run_options) {
276 if (argument_host_shapes.size() != arguments.size()) {
277 return InvalidArgument(
278 "Number of argument host shapes not equal to number of arguments (%d "
279 "vs %d)",
280 argument_host_shapes.size(), arguments.size());
281 }
282 TF_ASSIGN_OR_RETURN(auto options_and_stream,
283 RunHelper(argument_host_shapes, run_options));
284 se::Stream* stream = run_options.stream();
285
286 std::shared_ptr<HloSnapshot> snapshot;
287 if (executable_->dumping_snapshot()) {
288 std::vector<ShapedBuffer> shaped_buffers;
289 std::vector<const ShapedBuffer*> shaped_buffer_ptrs;
290 shaped_buffers.reserve(arguments.size());
291 shaped_buffer_ptrs.reserve(arguments.size());
292 for (size_t i = 0; i < arguments.size(); ++i) {
293 shaped_buffers.push_back(MaybeOwningShapeTreeToShapedBuffer(
294 *argument_host_shapes[i], arguments[i], backend_->platform(),
295 stream->parent()->device_ordinal()));
296 shaped_buffer_ptrs.push_back(&shaped_buffers.back());
297 }
298
299 snapshot =
300 DumpArguments(backend_, executable_.get(), shaped_buffer_ptrs, stream);
301 }
302
303 TF_ASSIGN_OR_RETURN(ExecutionOutput outputs,
304 executable_->ExecuteAsyncOnStreamWrapper(
305 &options_and_stream.first, std::move(arguments)));
306
307 // Transfer the outputs and save the snapshot to disk.
308 if (snapshot) {
309 DumpOutputsAndSaveSnapshot(backend_, outputs.Result(), std::move(snapshot),
310 stream);
311 }
312
313 return std::move(outputs);
314 }
315
// Returns the platform of the local service's backend.
se::Platform* LocalClient::platform() const {
  return local_service_->backend().platform();
}
319
// Returns the number of devices known to the local service's backend.
int LocalClient::device_count() const {
  return local_service_->backend().device_count();
}
323
// Returns true if `device_ordinal` names a device the backend can use.
bool LocalClient::device_ordinal_supported(int device_ordinal) const {
  return local_service_->backend().device_ordinal_supported(device_ordinal);
}
327
// Returns the backend's default device ordinal (used when callers leave the
// ordinal unset).
int LocalClient::default_device_ordinal() const {
  return local_service_->backend().default_device_ordinal();
}
331
// Read-only access to the local service's backend.
const Backend& LocalClient::backend() const {
  return local_service_->backend();
}
335
// Mutable access to the local service's backend (e.g. for borrowing streams).
Backend* LocalClient::mutable_backend() {
  return local_service_->mutable_backend();
}
339
Compile(const XlaComputation & computation,const absl::Span<const Shape * const> argument_layouts,const ExecutableBuildOptions & options)340 StatusOr<std::vector<std::unique_ptr<LocalExecutable>>> LocalClient::Compile(
341 const XlaComputation& computation,
342 const absl::Span<const Shape* const> argument_layouts,
343 const ExecutableBuildOptions& options) {
344 ExecutableBuildOptions updated_options = options;
345 if (options.device_ordinal() == -1) {
346 updated_options.set_device_ordinal(default_device_ordinal());
347 VLOG(3) << "Set device ordinal to default value of: "
348 << updated_options.device_ordinal();
349 }
350 TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
351 local_service_->CompileExecutables(
352 computation, argument_layouts, updated_options));
353
354 std::vector<std::unique_ptr<LocalExecutable>> local_executables;
355 local_executables.reserve(executables.size());
356
357 for (auto& executable : executables) {
358 local_executables.push_back(absl::make_unique<LocalExecutable>(
359 std::move(executable), local_service_->mutable_backend(),
360 updated_options));
361 }
362
363 return std::move(local_executables);
364 }
365
// Allocates device memory for `literal`'s shape on `device_ordinal` and
// copies the literal's contents into it, returning the owning buffer. Falls
// back to the backend's default allocator when `allocator` is null. The
// transfer is issued on a stream borrowed for the duration of this call.
StatusOr<ScopedShapedBuffer> LocalClient::LiteralToShapedBuffer(
    const LiteralSlice& literal, int device_ordinal,
    se::DeviceMemoryAllocator* allocator) {
  if (allocator == nullptr) {
    allocator = backend().memory_allocator();
  }
  TF_ASSIGN_OR_RETURN(auto scoped_buffer,
                      backend().transfer_manager()->AllocateScopedShapedBuffer(
                          literal.shape(), allocator, device_ordinal));
  TF_ASSIGN_OR_RETURN(auto stream,
                      mutable_backend()->BorrowStream(device_ordinal));
  TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
      stream.get(), literal, scoped_buffer));
  return std::move(scoped_buffer);
}
381
// Copies `shaped_buffer`'s contents from the device back to host as a
// Literal, using a stream borrowed on the buffer's own device ordinal.
StatusOr<Literal> LocalClient::ShapedBufferToLiteral(
    const ShapedBuffer& shaped_buffer) {
  TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream(
                                       shaped_buffer.device_ordinal()));
  return backend().transfer_manager()->TransferLiteralFromDevice(stream.get(),
                                                                 shaped_buffer);
}
389
// Resolves a global data handle to the (non-owning) device buffer backing it
// for the given replica; forwards to the local service.
StatusOr<const ShapedBuffer*> LocalClient::GlobalDataToShapedBuffer(
    const GlobalDataHandle& data, int replica_number) {
  return local_service_->GlobalDataToShapedBuffer(data, replica_number);
}
394
// Transfers `literal` to the infeed queue of the device at `device_ordinal`,
// going directly through that device's stream executor.
Status LocalClient::TransferToInfeedLocal(const LiteralSlice& literal,
                                          int device_ordinal) {
  TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                      backend().stream_executor(device_ordinal));
  return backend().transfer_manager()->TransferLiteralToInfeed(executor,
                                                               literal);
}
402
// Pulls a literal of `shape` from the outfeed queue of the device at
// `device_ordinal`, going directly through that device's stream executor.
StatusOr<Literal> LocalClient::TransferFromOutfeedLocal(const Shape& shape,
                                                        int device_ordinal) {
  TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                      backend().stream_executor(device_ordinal));
  // Allocate a host literal of the expected shape for the transfer to fill.
  auto literal = Literal::CreateFromShape(shape);
  TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
      executor, shape, &literal));
  return std::move(literal);
}
412
// Maps a replica number to the device ordinal it runs on; forwards to the
// local service.
StatusOr<int> LocalClient::ReplicaNumberToDeviceOrdinal(int replica_number) {
  return local_service_->ReplicaNumberToDeviceOrdinal(replica_number);
}
416
// Copies `literal` into freshly allocated device memory on `device_ordinal`
// and registers the buffer with the local service as a (single-replica)
// global data handle, returning the registration response.
StatusOr<TransferToServerResponse> LocalClient::TransferToLocalServer(
    const ::xla::BorrowingLiteral& literal, int device_ordinal) {
  const ::xla::Shape& shape = literal.shape();

  TF_ASSIGN_OR_RETURN(::xla::ScopedShapedBuffer shaped_buffer,
                      backend().transfer_manager()->AllocateScopedShapedBuffer(
                          shape, backend().memory_allocator(), device_ordinal));
  TF_ASSIGN_OR_RETURN(auto stream,
                      mutable_backend()->BorrowStream(device_ordinal));
  TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
      stream.get(), literal, shaped_buffer));
  // The service expects one buffer per replica; here there is exactly one.
  std::vector<::xla::ScopedShapedBuffer> replicated_buffer;
  replicated_buffer.emplace_back(std::move(shaped_buffer));
  ::xla::TransferToServerResponse result;
  TF_ASSIGN_OR_RETURN(*result.mutable_data(),
                      local_service_->RegisterReplicatedBuffers(
                          std::move(replicated_buffer),
                          absl::StrCat("TransferToServer literal of shape ",
                                       ::xla::ShapeUtil::HumanString(shape))));

  return result;
}
439
440 } // namespace xla
441