/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/service.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/service/source_map_util.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/ptr_util.h"

namespace xla {
namespace {

using absl::StrCat;
using absl::StrFormat;

// Argument used when calling DumpHloModuleIfEnabled before optimizations are
// performed on an HloModule.
constexpr char kBeforeOptimizationsDumpName[] = "before_optimizations";

// Records the arguments used to invoke a computation in an HloSnapshot proto.
Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments,
                       se::Stream* stream, TransferManager* transfer_manager,
                       HloSnapshot* module) {
  module->clear_arguments();
  for (const ShapedBuffer* argument : arguments) {
    TF_ASSIGN_OR_RETURN(
        Literal literal,
        transfer_manager->TransferLiteralFromDevice(stream, *argument));
    *module->add_arguments() = literal.ToProto();
  }
  return Status::OK();
}

// Records the result of a computation in an HloSnapshot proto.
Status RecordResult(const ShapedBuffer& result, se::Stream* stream,
                    TransferManager* transfer_manager, HloSnapshot* module) {
  module->clear_result();
  TF_ASSIGN_OR_RETURN(
      Literal literal,
      transfer_manager->TransferLiteralFromDevice(stream, result));
  *module->mutable_result() = literal.ToProto();
  return Status::OK();
}

}  // namespace

ServiceOptions& ServiceOptions::set_platform(se::Platform* platform) {
  platform_ = platform;
  return *this;
}

se::Platform* ServiceOptions::platform() const { return platform_; }

ServiceOptions& ServiceOptions::set_number_of_replicas(int number_of_replicas) {
  number_of_replicas_ = number_of_replicas;
  return *this;
}

int ServiceOptions::number_of_replicas() const { return number_of_replicas_; }

ServiceOptions& ServiceOptions::set_intra_op_parallelism_threads(
    int num_threads) {
  intra_op_parallelism_threads_ = num_threads;
  return *this;
}

int ServiceOptions::intra_op_parallelism_threads() const {
  return intra_op_parallelism_threads_;
}

ServiceOptions& ServiceOptions::set_allowed_devices(
    const absl::optional<std::set<int>>& allowed_devices) {
  allowed_devices_ = allowed_devices;
  return *this;
}

const absl::optional<std::set<int>>& ServiceOptions::allowed_devices() const {
  return allowed_devices_;
}

/* static */ StatusOr<std::unique_ptr<Service>> Service::NewService(
    se::Platform* platform) {
  ServiceOptions default_options;
  default_options.set_platform(platform);
  return NewService(default_options);
}

/* static */ StatusOr<std::unique_ptr<Service>> Service::NewService(
    const ServiceOptions& options) {
  se::Platform* platform = options.platform();
  std::unique_ptr<Backend> execute_backend;
  if (platform == nullptr) {
    TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
  }
  BackendOptions backend_options;
  backend_options.set_platform(platform);
  backend_options.set_allowed_devices(options.allowed_devices());
  TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options));

  std::unique_ptr<Service> service(
      new Service(options, std::move(execute_backend)));
  return std::move(service);
}

Service::Service(const ServiceOptions& options,
                 std::unique_ptr<Backend> execute_backend)
    : options_(options),
      allocation_tracker_(execute_backend.get()),
      execute_backend_(std::move(execute_backend)) {
  CHECK_GT(options_.number_of_replicas(), 0);
  if (execute_backend_) {
    if (execute_backend_->device_count() > 0) {
      CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas())
          << "Requested more replicas than there are devices.";
    }
    LOG(INFO) << StrFormat(
        "XLA service %p executing computations on platform %s. Devices:", this,
        execute_backend_->platform()->Name());
    auto stream_executors = execute_backend_->stream_executors();
    for (int i = 0; i < execute_backend_->device_count(); ++i) {
      se::StreamExecutor* executor = stream_executors.at(i);
      const auto& description = executor->GetDeviceDescription();
      LOG(INFO) << StrFormat("  StreamExecutor device (%d): %s, %s", i,
                             description.name(),
                             description.platform_version());
    }
  } else {
    VLOG(1) << "XLA compile-only service constructed";
  }
}

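// Creates a new channel via the channel tracker and returns its handle in the
// response.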
Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg,
                                    CreateChannelHandleResponse* result) {
  TF_ASSIGN_OR_RETURN(*result->mutable_channel(),
                      channel_tracker_.NewChannel(arg->channel_type()));
  return Status::OK();
}

Status Service::Unregister(const UnregisterRequest* arg,
                           UnregisterResponse* result) {
  Status status;
  for (auto& data : arg->data()) {
    Status unregister_status = allocation_tracker_.Unregister(data);
    if (!unregister_status.ok() && status.ok()) {
      status = unregister_status;
    }
  }
  return status;
}

// Deconstructs a previously-allocated global handle.
Status Service::DeconstructTuple(const DeconstructTupleRequest* arg,
                                 DeconstructTupleResponse* result) {
  TF_ASSIGN_OR_RETURN(
      std::vector<GlobalDataHandle> elements,
      allocation_tracker_.DeconstructTuple(arg->tuple_handle()));

  for (auto& element : elements) {
    *result->add_element_handles() = element;
  }
  return Status::OK();
}

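// Checks that the client-provided result shape (which may carry a layout) is
// valid and compatible with the computation's actual result shape.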
Status Service::ValidateResultShape(const Shape& client_shape,
                                    const Shape& result_shape) const {
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(client_shape));
  if (!ShapeUtil::Compatible(client_shape, result_shape)) {
    return InvalidArgument(
        "Shape used to set computation result layout %s is not compatible "
        "with result shape %s",
        ShapeUtil::HumanStringWithLayout(client_shape),
        ShapeUtil::HumanString(result_shape));
  }
  return Status::OK();
}

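// Resolves each argument handle to its per-replica ShapedBuffers and verifies
// that every buffer lives on the platform and device the computation will run
// on.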
StatusOr<std::vector<std::vector<const ShapedBuffer*>>>
Service::ResolveAndValidateArguments(
    absl::Span<const GlobalDataHandle* const> arguments,
    absl::Span<se::StreamExecutor* const> stream_executors) const {
  CHECK_EQ(options_.number_of_replicas(), stream_executors.size());
  std::vector<std::vector<const ShapedBuffer*>> replicated_arguments;
  replicated_arguments.resize(options_.number_of_replicas());
  for (size_t i = 0; i < arguments.size(); ++i) {
    auto buffer_status = allocation_tracker_.Resolve(*arguments[i]);
    if (!buffer_status.ok()) {
      return Status(buffer_status.status().code(),
                    StrCat(buffer_status.status().error_message(), ", ",
                           "failed to resolve allocation for parameter ", i));
    }
    auto replicated_buffers = buffer_status.ValueOrDie();
    CHECK_EQ(options_.number_of_replicas(), replicated_buffers.size());
    for (int replica = 0; replica < options_.number_of_replicas(); ++replica) {
      const ShapedBuffer* shaped_buffer = replicated_buffers[replica];
      int replica_device_ordinal = stream_executors[replica]->device_ordinal();
      // Verify allocation is same platform and device as the execution.
      if (shaped_buffer->platform() != execute_backend_->platform() ||
          shaped_buffer->device_ordinal() != replica_device_ordinal) {
        return InvalidArgument(
            "argument %lu is on device %s:%d but computation will be executed "
            "on device %s",
            i, shaped_buffer->platform()->Name(),
            shaped_buffer->device_ordinal(),
            execute_backend_->device_name(replica_device_ordinal));
      }
      replicated_arguments[replica].push_back(shaped_buffer);
    }
  }
  return replicated_arguments;
}

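// Builds an HloModuleConfig for the given program shape and argument shapes,
// taking layouts, replica count, seed, and debug options from the (optional)
// execution options.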
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
    const ProgramShape& program_shape,
    absl::Span<const Shape* const> argument_shapes,
    const ExecutionOptions* execution_options) {
  auto config = absl::make_unique<HloModuleConfig>(program_shape);
  ComputationLayout* computation_layout =
      config->mutable_entry_computation_layout();
  if (program_shape.parameters_size() != argument_shapes.size()) {
    return InvalidArgument("computation takes %d parameters, but %u given",
                           program_shape.parameters_size(),
                           argument_shapes.size());
  }
  for (int i = 0; i < argument_shapes.size(); ++i) {
    // Verify that the shape of each argument matches the corresponding
    // parameter shape in the ProgramShape.
    if (!ShapeUtil::Compatible(*argument_shapes[i],
                               program_shape.parameters(i))) {
      return InvalidArgument(
          "Argument does not match shape of computation parameter %d: want "
          "%s, got %s",
          i, ShapeUtil::HumanString(program_shape.parameters(i)),
          ShapeUtil::HumanString(*argument_shapes[i]));
    }
    TF_RETURN_IF_ERROR(
        computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
            *argument_shapes[i]));
  }
  if (execution_options != nullptr &&
      execution_options->has_shape_with_output_layout()) {
    const Shape shape_with_output_layout(
        execution_options->shape_with_output_layout());
    TF_RETURN_IF_ERROR(
        ValidateResultShape(shape_with_output_layout, program_shape.result()));
    TF_RETURN_IF_ERROR(
        computation_layout->mutable_result_layout()->CopyLayoutFromShape(
            shape_with_output_layout));
  } else {
    // If the result layout is not set, then choose the default.
    computation_layout->mutable_result_layout()->SetToDefaultLayout();
  }

  if (execution_options != nullptr) {
    if (execution_options->num_replicas() > 0) {
      config->set_replica_count(execution_options->num_replicas());
    } else {
      config->set_replica_count(options_.number_of_replicas());
    }
    config->set_seed(execution_options->seed());
    config->set_debug_options(execution_options->debug_options());
  } else {
    config->set_replica_count(options_.number_of_replicas());
    config->set_debug_options(GetDebugOptionsFromFlags());
  }

  if (execute_backend_ != nullptr &&
      execute_backend_->eigen_intra_op_thread_pool() != nullptr) {
    config->set_intra_op_parallelism_threads(
        execute_backend_->eigen_intra_op_thread_pool()->NumThreads());
  }

  if (execution_options != nullptr &&
      execution_options->has_device_assignment()) {
    TF_ASSIGN_OR_RETURN(
        auto device_assignment,
        DeviceAssignment::Deserialize(execution_options->device_assignment()));
    config->set_static_device_assignment(*device_assignment);
  }

  return std::move(config);
}

StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
    const ProgramShape& program_shape,
    absl::Span<const ShapedBuffer* const> arguments,
    const ExecutionOptions& execution_options) {
  std::vector<const Shape*> argument_shapes;
  for (const auto* arg : arguments) {
    argument_shapes.push_back(&arg->on_host_shape());
  }
  return CreateModuleConfig(program_shape, argument_shapes, &execution_options);
}

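// Compiles a group of HLO modules (one per module proto/config pair) into
// executables on the given backend, attaching HLO snapshots when dumping is
// enabled.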
StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
    const std::vector<const HloModuleProto*>& module_protos,
    std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
    Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
    DeviceMemoryAllocator* device_allocator) {
  VLOG(1) << StrFormat("BuildExecutable on service %p", this);

  // Dump computation proto state if flag is set.
  std::vector<std::unique_ptr<HloSnapshot>> hlo_snapshots;
  for (int64 i = 0; i < module_protos.size(); ++i) {
    auto hlo_snapshot = absl::make_unique<HloSnapshot>();
    *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = *module_protos[i];
    hlo_snapshots.push_back(std::move(hlo_snapshot));
  }

  VLOG(1) << "Computations:";
  for (const HloModuleProto* proto : module_protos) {
    VLOG(1) << proto->name();
  }

  CHECK_EQ(module_protos.size(), module_configs.size());
  auto module_group =
      absl::make_unique<HloModuleGroup>(module_protos[0]->name());
  for (int64 i = 0; i < module_protos.size(); ++i) {
    const HloModuleProto* proto = module_protos[i];
    const HloModuleConfig& config = *module_configs[i];
    TF_ASSIGN_OR_RETURN(auto module, CreateModuleFromProto(*proto, config));
    DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
    module_group->push_back(std::move(module));
  }

  TF_ASSIGN_OR_RETURN(
      std::vector<std::unique_ptr<Executable>> executables,
      backend->compiler()->Compile(std::move(module_group),
                                   std::move(executors), device_allocator));

  for (size_t i = 0; i < module_protos.size(); ++i) {
    const auto& debug_opts = module_configs[i]->debug_options();
    if (DumpingEnabledForHloModule(module_protos[i]->name(), debug_opts) &&
        debug_opts.xla_dump_hlo_snapshots()) {
      executables[i]->set_hlo_snapshot(std::move(hlo_snapshots[i]));
    }
  }

  return std::move(executables);
}

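// Runs the given executables in parallel across the requested devices and
// replicas, registers each result with the allocation tracker, and optionally
// fills in an execution profile measured on replica 0.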
StatusOr<std::vector<GlobalDataHandle>>
Service::ExecuteParallelAndRegisterResult(
    absl::Span<Executable* const> executables,
    absl::Span<const std::vector<std::vector<const ShapedBuffer*>>> arguments,
    Backend* backend, absl::Span<const DeviceHandle> device_handles,
    absl::Span<const string> result_tags, ExecutionProfile* profile) {
  // Streams where the computations are launched, so we can wait on the streams
  // to complete.
  std::vector<StreamPool::Ptr> streams;
  std::vector<std::unique_ptr<se::Timer>> timers;

  // Global data handles for the computation results, one for each computation.
  std::vector<GlobalDataHandle> result_handles;

  // Device ID to stream executor, populated only with devices that are being
  // profiled.
  std::map<int64, se::Stream*> index_to_profiled_streams;

  // Build DeviceAssignment for all cores based on the provided device handles.
  DeviceAssignment device_assignment(options_.number_of_replicas(),
                                     executables.size());
  for (int64 i = 0; i < executables.size(); i++) {
    TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
    CHECK_EQ(replicas.size(), arguments[i].size());
    for (int64 replica = 0; replica < replicas.size(); ++replica) {
      device_assignment(replica, i) = replicas[replica]->device_ordinal();
    }
  }

  for (int64 i = 0; i < executables.size(); i++) {
    // Stream executors for the replicas of the current computation.
    TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
    CHECK_EQ(replicas.size(), arguments[i].size());
    std::vector<ScopedShapedBuffer> result_buffers;
    for (int64 replica = 0; replica < replicas.size(); ++replica) {
      TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream,
                          backend->BorrowStream(replicas[replica]));
      streams.push_back(std::move(stream));

      if (replica == 0 && profile != nullptr) {
        timers.push_back(
            absl::make_unique<se::Timer>(streams.back()->parent()));
        streams.back()
            ->InitTimer(timers.back().get())
            .ThenStartTimer(timers.back().get());
        CHECK(timers.front() != nullptr);
      }

      if (replica == 0 &&
          executables[i]->module_config().debug_options().xla_hlo_profile() &&
          executables[i]->hlo_profiling_enabled()) {
        index_to_profiled_streams[i] = streams.back().get();
      }

      // Set up run options.
      ExecutableRunOptions options;
      options.set_stream(streams.back().get());
      options.set_allocator(backend->memory_allocator());
      options.set_intra_op_thread_pool(
          backend->eigen_intra_op_thread_pool_device());
      options.set_device_assignment(&device_assignment);
      ServiceExecutableRunOptions run_options(options,
                                              backend->StreamBorrower());

      // Asynchronously launch the computation.
      TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                          executables[i]->ExecuteAsyncOnStream(
                              &run_options, arguments[i][replica]));

      if (replica == 0 && profile != nullptr) {
        streams.back()->ThenStopTimer(timers.back().get());
      }

      result_buffers.push_back(std::move(result));
    }
    TF_ASSIGN_OR_RETURN(GlobalDataHandle handle,
                        allocation_tracker_.RegisterReplicatedBuffers(
                            std::move(result_buffers), result_tags[i]));
    result_handles.push_back(handle);
  }

  // Wait for all executions to complete.
  for (int64 i = 0; i < streams.size(); ++i) {
    Status block_status = streams[i]->BlockHostUntilDone();
    if (!block_status.ok()) {
      return InternalError("failed to complete execution for stream %d: %s", i,
                           block_status.error_message());
    }
  }

  if (profile != nullptr) {
    CHECK(!timers.empty());
    std::vector<uint64> timer_nanoseconds;
    timer_nanoseconds.reserve(timers.size());
    for (auto& timer : timers) {
      timer_nanoseconds.push_back(timer->Nanoseconds());
    }
    uint64 nanoseconds =
        *std::max_element(timer_nanoseconds.begin(), timer_nanoseconds.end());

    // Merge in run-time profile information from execution_profile on the
    // zeroth device.
    profile->MergeFrom(executables[0]->execution_profile());

    // Overall execution time (in nanoseconds) from the executor timer.
    profile->set_compute_and_transfer_time_ns(nanoseconds);

    // TODO(b/28123297): On GPU we end up including transfer time in
    // the compute time this way. Instead, we should get the correct
    // value by measuring it. Setting the field here at least lets
    // benchmarks provide *some* value for GPU computations.
    //
    // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually
    // the compute time without the transfer time, so this way we get the
    // correct compute time. We should instead have the correct value for
    // compute_and_transfer_time and set compute_time to the compute time.
    if (profile->compute_time_ns() == 0) {
      profile->set_compute_time_ns(profile->compute_and_transfer_time_ns());
    }
  }

  return result_handles;
}

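// Runs a single executable (replicated if the service is configured with more
// than one replica) and registers its result with the allocation tracker under
// the given tag.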
StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
    Executable* executable,
    absl::Span<const std::vector<const ShapedBuffer*>> arguments,
    Backend* backend, const DeviceHandle& device_handle,
    const string& result_tag, ExecutionProfile* profile) {
  // Set up streams.
  std::vector<StreamPool::Ptr> streams;

  TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handle));
  TF_RET_CHECK(!replicas.empty());
  for (se::StreamExecutor* executor : replicas) {
    TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream,
                        backend->BorrowStream(executor));
    streams.push_back(std::move(stream));
  }

  DeviceAssignment device_assignment(options_.number_of_replicas(),
                                     /*computation_count=*/1);
  for (int64 replica = 0; replica < replicas.size(); ++replica) {
    device_assignment(replica, 0) = replicas[replica]->device_ordinal();
  }

  // Set up run options.
  std::vector<ServiceExecutableRunOptions> run_options;
  for (const StreamPool::Ptr& stream : streams) {
    ExecutableRunOptions options;
    options.set_stream(stream.get());
    options.set_device_ordinal(stream->parent()->device_ordinal());
    options.set_allocator(backend->memory_allocator());
    options.set_intra_op_thread_pool(
        backend->eigen_intra_op_thread_pool_device());
    options.set_device_assignment(&device_assignment);
    run_options.emplace_back(options, backend->StreamBorrower());
  }

  if (options_.number_of_replicas() == 1) {
    TF_ASSIGN_OR_RETURN(
        auto result, executable->ExecuteOnStreamWrapper(&run_options[0],
                                                        profile, arguments[0]));
    return allocation_tracker_.Register(std::move(result), result_tag);
  }

  // TODO(b/69985541): Support profiling also on this path.

  std::vector<absl::Span<const ShapedBuffer* const>> replicated_arguments;
  for (const auto& arg : arguments) {
    replicated_arguments.push_back(arg);
  }

  TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams(
                                        run_options, replicated_arguments));
  TF_RET_CHECK(!results.empty());
  return allocation_tracker_.RegisterReplicatedBuffers(std::move(results),
                                                       result_tag);
}

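// Returns the stream executors that will run the computation described by the
// given execution options.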
StatusOr<std::vector<se::StreamExecutor*>> Service::GetExecutors(
    const ExecutionOptions& execution_options, int64 requests_size,
    int64 request_index) const {
  if (execution_options.device_handles().empty()) {
    return FailedPrecondition(
        "device handles must be given to execute parallel computations");
  }
  if (requests_size > 1 && execution_options.device_handles_size() > 1) {
    return InvalidArgument(
        "Parallel requests with multiple device handles is not supported. "
        "Found %d parallel requests, with request %d containing %d device "
        "handles.",
        requests_size, request_index, execution_options.device_handles_size());
  }
  std::vector<se::StreamExecutor*> executors;
  for (const auto& device_handle : execution_options.device_handles()) {
    TF_ASSIGN_OR_RETURN(auto replicas,
                        Replicas(*execute_backend_, device_handle));
    se::StreamExecutor* executor = replicas[0];
    CHECK(executor != nullptr);
    executors.push_back(executor);
  }
  return executors;
}

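// Resolves the argument handles in the request into per-replica ShapedBuffers.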
StatusOr<std::vector<std::vector<const ShapedBuffer*>>> Service::GetArguments(
    const ExecutionOptions& execution_options,
    absl::Span<const GlobalDataHandle* const> arguments) const {
  // Resolve the allocations for the arguments of the computation, and create
  // a vector of device memory offsets for the arguments from the allocations.
  // In the case of partitioned computations, assume all arguments go on the
  // zeroth core.
  TF_ASSIGN_OR_RETURN(
      auto replicas,
      Replicas(*execute_backend_, execution_options.device_handles(0)));
  TF_ASSIGN_OR_RETURN(
      std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
      ResolveAndValidateArguments(arguments, replicas));
  return replicated_arguments;
}

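// Compiles and runs several computation graphs in parallel, one per request
// entry, and returns one response per result handle.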
Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
                                     ExecuteParallelResponse* result) {
  VLOG(1) << "running execute-graph-parallel request";

  std::vector<std::vector<std::vector<const ShapedBuffer*>>> all_arguments;
  std::vector<std::vector<se::StreamExecutor*>> all_executors;
  std::vector<const HloModuleProto*> module_protos;
  std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
  std::vector<string> computation_names;
  std::vector<DeviceHandle> device_handles;

  int num_requested_devices =
      std::accumulate(arg->requests().begin(), arg->requests().end(), 0,
                      [](int a, const ExecuteGraphRequest& r) -> int {
                        return a + r.execution_options().device_handles_size();
                      });
  if (num_requested_devices * options_.number_of_replicas() >
      execute_backend_->device_count()) {
    return FailedPrecondition(
        "there are not enough stream executors to execute %d computations",
        num_requested_devices);
  }

  for (int64 i = 0; i < arg->requests_size(); ++i) {
    // Get the stream executor for the i'th computation. This stream executor
    // is one of the executors to run the replicated computation.
    const ExecutionOptions& execution_options =
        arg->requests(i).execution_options();
    const ExecuteGraphRequest& request = arg->requests(i);
    TF_RET_CHECK(request.has_computation()) << "computations may not be empty";
    TF_RET_CHECK(request.computation().has_host_program_shape())
        << "program shape may not be empty";

    // Get the executors.
    TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options,
                                                     arg->requests_size(), i));

    // Get the replicated arguments.
    TF_ASSIGN_OR_RETURN(auto replicated_arguments,
                        GetArguments(execution_options, request.arguments()));

    // Create an HloModuleConfig object for the computation, given the shape of
    // the program and the argument allocations. Here, we care only about the
    // shapes of the arguments, so it is sufficient to use the arguments of
    // replica 0.
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloModuleConfig> module_config,
        CreateModuleConfig(
            ProgramShape{request.computation().host_program_shape()},
            replicated_arguments.front(), request.execution_options()));
    VLOG(3)
        << "ExecuteGraphParallel created HloModuleConfig computation layout: "
        << module_config->entry_computation_layout().ToString();

    // Add to the vectors used to build and execute the computations after the
    // loop.
    all_arguments.push_back(replicated_arguments);
    all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}});
    module_protos.push_back(&request.computation());
    module_configs.push_back(std::move(module_config));
    computation_names.insert(computation_names.end(), executors.size(),
                             request.computation().name());
    all_executors.push_back(executors);
    device_handles.insert(device_handles.end(),
                          execution_options.device_handles().begin(),
                          execution_options.device_handles().end());
  }

  // Build the HloModules and compile to generate the executables.
  //
  // TODO(jlebar): There's currently no way to pass a device allocator to
  // ExecuteGraphParallel, so we have to pass a null device_allocator below.
  TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
                      BuildExecutables(module_protos, std::move(module_configs),
                                       execute_backend_.get(), all_executors,
                                       /*device_allocator=*/nullptr));
  std::vector<Executable*> executable_ptrs;
  executable_ptrs.reserve(executables.size());
  for (const auto& executable : executables) {
    executable_ptrs.push_back(executable.get());
  }

  for (int i = 0; i < executable_ptrs.size(); i++) {
    if (executable_ptrs[i]->dumping_snapshot()) {
      TF_ASSIGN_OR_RETURN(auto stream,
                          execute_backend_->BorrowStream(
                              all_executors[i][0]->device_ordinal()));
      TF_RETURN_IF_ERROR(RecordArguments(all_arguments[i].front(), stream.get(),
                                         execute_backend_->transfer_manager(),
                                         executable_ptrs[i]->hlo_snapshot()));
    }
  }

  // If we have multiple executables to run, execute them all in parallel. But
  // if we only have one executable, execute it using the vanilla, non-parallel
  // call.
  //
  // We do this because the Client API uses ExecuteGraphParallel when it wants
  // to compile and run one computation without caching the executable, but not
  // all backends support the async StreamExecutor API required by
  // ExecuteParallelAndRegisterResult.
  //
  // TODO(b/122731460): Consolidate Execute{,Parallel}AndRegisterResult; they do
  // basically the same thing.
  ExecutionProfile profile;
  std::vector<GlobalDataHandle> outputs;
  if (executable_ptrs.size() == 1) {
    TF_ASSIGN_OR_RETURN(
        auto output,
        ExecuteAndRegisterResult(executable_ptrs[0], all_arguments[0],
                                 execute_backend_.get(), device_handles[0],
                                 computation_names[0], &profile));
    outputs.push_back(std::move(output));
  } else {
    TF_ASSIGN_OR_RETURN(
        outputs, ExecuteParallelAndRegisterResult(
                     executable_ptrs, all_arguments, execute_backend_.get(),
                     device_handles, computation_names, &profile));
  }

  for (const GlobalDataHandle& output : outputs) {
    ExecuteResponse response;
    *response.mutable_output() = output;
    *response.mutable_profile() = profile;
    *result->add_responses() = response;
  }

  for (int i = 0; i < executable_ptrs.size(); i++) {
    Executable* executable = executable_ptrs[i];
    if (executable->dumping_snapshot()) {
      TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer,
                          allocation_tracker_.ResolveForReplica(outputs[i], 0));
      TF_ASSIGN_OR_RETURN(auto stream,
                          execute_backend_->BorrowStream(all_executors[i][0]));
      TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
                                      execute_backend_->transfer_manager(),
                                      executable->hlo_snapshot()));
      DumpHloSnapshotIfEnabled(executable->module(),
                               *executable->hlo_snapshot());
    }
  }

  VLOG(1) << "successfully completed 'execute-graph-parallel' request";
  return Status::OK();
}

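// Returns handles for the requested number of logical devices, after checking
// that enough physical devices are available for the configured replica count.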
Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
                                 GetDeviceHandlesResponse* result) {
  const int64 available_device_count = execute_backend_->device_count();
  const int64 replica_count = options_.number_of_replicas();
  if (replica_count <= 0) {
    return FailedPrecondition("Replica count must be a positive integer");
  }
  if (available_device_count < arg->device_count() * replica_count) {
    return ResourceExhausted(
        "Requested logical device count (%d) with replica count (%d) exceeds "
        "the number of available physical devices on the target (%d)",
        arg->device_count(), replica_count, available_device_count);
  }

  for (int64 i = 0; i < arg->device_count(); ++i) {
    DeviceHandle device_handle;
    device_handle.set_handle(i);
    device_handle.set_device_count(arg->device_count());
    *result->add_device_handles() = device_handle;
  }

  return Status::OK();
}

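// Compiles a single HLO module proto into an executable: runs HLO passes, then
// the backend, and attaches an HLO snapshot when dumping is enabled.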
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
    const HloModuleProto& module_proto,
    std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
    se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) {
  VLOG(1) << StrFormat(
      "BuildExecutable on service %p with serialized module proto: %s", this,
      module_proto.name());

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                      CreateModuleFromProto(module_proto, *module_config));
  DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);

  TF_ASSIGN_OR_RETURN(
      module, backend->compiler()->RunHloPasses(std::move(module), executor,
                                                device_allocator));

  TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                      backend->compiler()->RunBackend(
                          std::move(module), executor, device_allocator));

  const auto& debug_opts = module_config->debug_options();
  if (DumpingEnabledForHloModule(module_proto.name(), debug_opts) &&
      debug_opts.xla_dump_hlo_snapshots()) {
    auto hlo_snapshot = absl::make_unique<HloSnapshot>();
    *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = module_proto;
    executable->set_hlo_snapshot(std::move(hlo_snapshot));
  }

  return std::move(executable);
}

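// Handles a compile request: builds an executable for the computation and
// inserts it into the compilation cache, returning a handle that a later
// Execute request can use.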
Status Service::Compile(const CompileRequest* arg, CompileResponse* result) {
  VLOG(1) << "running compile request";
  if (!arg->has_computation()) {
    return InvalidArgument("computations may not be empty");
  }
  if (!arg->computation().has_host_program_shape()) {
    return InvalidArgument("program shape may not be empty");
  }

  if (arg->execution_options().device_handles_size() > 1) {
    return InvalidArgument(
        "The compile request does not support multiple device handles.");
  }

  std::vector<Shape> argument_shapes;
  argument_shapes.reserve(arg->input_shape_with_layout_size());
  std::vector<const Shape*> argument_shape_ptrs;
  for (const ShapeProto& shape_proto : arg->input_shape_with_layout()) {
    argument_shapes.push_back(Shape(shape_proto));
    argument_shape_ptrs.push_back(&argument_shapes.back());
  }
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModuleConfig> module_config,
      CreateModuleConfig(ProgramShape{arg->computation().host_program_shape()},
                         argument_shape_ptrs, &arg->execution_options()));
  VLOG(3) << "Compile created HloModuleConfig computation layout: "
          << module_config->entry_computation_layout().ToString();

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<Executable> executable,
      BuildExecutable(arg->computation(), std::move(module_config),
                      execute_backend_.get(),
                      execute_backend_->default_stream_executor(),
                      /*device_allocator=*/nullptr));

  *result->mutable_handle() = compilation_cache_.Insert(std::move(executable));

  VLOG(1) << "successfully completed 'compile' request";
  return Status::OK();
}

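// Handles an execute request: looks up the compiled executable, validates the
// arguments against the entry computation layout, runs it, and registers the
// result.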
Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) {
  VLOG(1) << "running execute request";
  if (!arg->has_handle()) {
    return InvalidArgument("execution handle should not be empty");
  }
  TF_ASSIGN_OR_RETURN(auto executable,
                      compilation_cache_.LookUp(arg->handle()));

  TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_,
                                              SingleComputationDeviceHandle()));
  TF_ASSIGN_OR_RETURN(
      std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
      ResolveAndValidateArguments(arg->arguments(), replicas));

  // Check that replicated_arguments has the same shapes and layouts as the
  // module config used when creating the executable.
  const int64 num_module_args =
      executable->module_config().entry_computation_layout().parameter_count();
  if (num_module_args != arg->arguments_size()) {
    return InvalidArgument(
        "The executable expects %lld arguments, but sees %lld.",
        num_module_args, arg->arguments_size());
  }
  for (int64 i = 0; i < num_module_args; i++) {
    const Shape& shape_module =
        executable->module_config().entry_computation_layout().parameter_shape(
            i);
    const Shape& shape_arg = replicated_arguments.front()[i]->on_host_shape();
    if (!ShapeUtil::Equal(shape_module, shape_arg)) {
      return InvalidArgumentStrCat(
          "The executable expects the ", i, "th argument in shape ",
          ShapeUtil::HumanStringWithLayout(shape_module), " but sees ",
          ShapeUtil::HumanStringWithLayout(shape_arg));
    }
  }

  TF_ASSIGN_OR_RETURN(auto stream,
                      execute_backend_->BorrowStream(
                          execute_backend_->default_stream_executor()));
  if (executable->dumping_snapshot()) {
    executable->hlo_snapshot()->set_execution_platform(
        execute_backend_->platform()->Name());
    TF_RETURN_IF_ERROR(RecordArguments(
        replicated_arguments.front(), stream.get(),
        execute_backend_->transfer_manager(), executable->hlo_snapshot()));
  }

  TF_ASSIGN_OR_RETURN(
      *result->mutable_output(),
      ExecuteAndRegisterResult(executable.get(), replicated_arguments,
                               execute_backend_.get(),
                               SingleComputationDeviceHandle(),
                               "result of " + executable->module().name(),
                               result->mutable_profile()));

  if (executable->dumping_snapshot()) {
    TF_ASSIGN_OR_RETURN(
        const ShapedBuffer* result_buffer,
        allocation_tracker_.ResolveForReplica(result->output(), 0));
    TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
                                    execute_backend_->transfer_manager(),
                                    executable->hlo_snapshot()));
    DumpHloSnapshotIfEnabled(executable->module(), *executable->hlo_snapshot());
  }

  VLOG(1) << "successfully completed 'execute' request";
  return Status::OK();
}

Status Service::WaitForExecution(const WaitForExecutionRequest* arg,
                                 WaitForExecutionResponse* result) {
  TF_ASSIGN_OR_RETURN(const auto execution,
                      execution_tracker_.Resolve(arg->execution()));

  TF_RETURN_IF_ERROR(execution->BlockUntilDone());

  *result->mutable_output() = execution->result();
  *result->mutable_profile() = execution->profile();

  TF_RETURN_IF_ERROR(execution_tracker_.Unregister(arg->execution()));
  VLOG(1) << "successfully completed 'wait-for-execution' request";
  return Status::OK();
}

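// Transfers the literal backing a global data handle from device to the
// client, relaying it out to the requested layout if one was given.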
Status Service::TransferToClient(const TransferToClientRequest* arg,
                                 TransferToClientResponse* result) {
  TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer,
                      allocation_tracker_.ResolveForReplica(arg->data(), 0));

  Shape return_shape;
  if (arg->has_shape_with_layout()) {
    return_shape = Shape(arg->shape_with_layout());
    if (!LayoutUtil::HasLayout(return_shape)) {
      return InvalidArgument("shape_with_layout must have layout if present.");
    }
  } else {
    return_shape = Shape(shaped_buffer->on_host_shape());
  }

  TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(
                                       shaped_buffer->device_ordinal()));

  TF_ASSIGN_OR_RETURN(
      Literal result_literal,
      execute_backend_->transfer_manager()->TransferLiteralFromDevice(
          stream.get(), *shaped_buffer));

  if (LayoutUtil::LayoutsInShapesEqual(return_shape, result_literal.shape())) {
    *result->mutable_literal() = result_literal.ToProto();
  } else {
    *result->mutable_literal() =
        result_literal.Relayout(return_shape).ToProto();
  }
  return Status::OK();
}

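// Transfers a literal from the client into newly allocated device buffers on
// every replica and registers the replicated buffers under a single handle.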
Status Service::TransferToServer(const TransferToServerRequest* arg,
                                 TransferToServerResponse* result) {
  TF_ASSIGN_OR_RETURN(Literal literal,
                      Literal::CreateFromProto(arg->literal()));
  const Shape& shape = literal.shape();

  std::vector<se::StreamExecutor*> replicas;
  if (arg->has_device_handle()) {
    TF_ASSIGN_OR_RETURN(replicas,
                        Replicas(*execute_backend_, arg->device_handle()));
  } else {
    TF_ASSIGN_OR_RETURN(
        replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle()));
  }

  // Allocate memory in each replica and transfer the data to all replicas.
  std::vector<ScopedShapedBuffer> replicated_buffers;
  for (se::StreamExecutor* executor : replicas) {
    TF_ASSIGN_OR_RETURN(
        ScopedShapedBuffer shaped_buffer,
        execute_backend_->transfer_manager()->AllocateScopedShapedBuffer(
            shape, execute_backend_->memory_allocator(),
            executor->device_ordinal()));
    TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor));
    TF_RETURN_IF_ERROR(
        execute_backend_->transfer_manager()->TransferLiteralToDevice(
            stream.get(), literal, shaped_buffer));
    replicated_buffers.emplace_back(std::move(shaped_buffer));
  }
  TF_ASSIGN_OR_RETURN(*result->mutable_data(),
                      allocation_tracker_.RegisterReplicatedBuffers(
                          std::move(replicated_buffers),
                          StrCat("TransferToServer literal of shape ",
                                 ShapeUtil::HumanString(shape))));

  return Status::OK();
}

Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
                                 TransferToInfeedResponse* result) {
  const int64 replica_count = options_.number_of_replicas();
  if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
    return FailedPrecondition(
        "%s",
        StrCat("The replica_id=", arg->replica_id(),
               " on TransferToInfeedRequest not in range [0, replica_count=",
               replica_count, ")."));
  }

  se::StreamExecutor* executor;
  if (arg->has_device_handle()) {
    TF_ASSIGN_OR_RETURN(auto replicas,
                        Replicas(*execute_backend_, arg->device_handle()));
    executor = replicas[arg->replica_id()];
  } else {
    TF_ASSIGN_OR_RETURN(
        auto replicas,
        Replicas(*execute_backend_, SingleComputationDeviceHandle()));
    executor = replicas[arg->replica_id()];
  }

  TF_ASSIGN_OR_RETURN(Literal literal,
                      Literal::CreateFromProto(arg->literal()));
  return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor,
                                                                       literal);
}

Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
                                    TransferFromOutfeedResponse* result) {
  const int64 replica_count = options_.number_of_replicas();
  if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
    return FailedPrecondition(
        "The replica_id=%d on TransferFromOutfeedRequest not in range [0, %d)",
        arg->replica_id(), replica_count);
  }

  se::StreamExecutor* executor;
  if (arg->has_device_handle()) {
    TF_ASSIGN_OR_RETURN(auto replicas,
                        Replicas(*execute_backend_, arg->device_handle()));
    executor = replicas[arg->replica_id()];
  } else {
    TF_ASSIGN_OR_RETURN(
        auto replicas,
        Replicas(*execute_backend_, SingleComputationDeviceHandle()));
    executor = replicas[arg->replica_id()];
  }

  auto literal = Literal::CreateFromShape(Shape(arg->shape_with_layout()));

  TF_RETURN_IF_ERROR(
      execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
          executor, Shape(arg->shape_with_layout()), literal));
  *result->mutable_literal() = literal.ToProto();
  return Status::OK();
}

Status Service::ResetDevice(const ResetDeviceRequest* arg,
                            ResetDeviceResponse* result) {
  return execute_backend_->ResetDevices();
}

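// Evaluates a parameter-free computation with the HLO evaluator and returns
// the resulting literal, relaid out if an output layout was requested.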
Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
                                     ComputeConstantResponse* result) {
  if (!arg->has_computation()) {
    return InvalidArgument("computations may not be empty");
  }
  if (!arg->computation().has_host_program_shape()) {
    return InvalidArgument("program shape may not be empty");
  }
  if (arg->computation().host_program_shape().parameters_size() != 0) {
    return InvalidArgument(
        "constant computation may not depend on any parameters.");
  }

  ProgramShape program_shape(arg->computation().host_program_shape());
  TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result()));
  absl::optional<Layout> output_layout;
  if (arg->has_output_layout()) {
    output_layout = Layout::CreateFromProto(arg->output_layout());
    TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(
        *output_layout, program_shape.result()));
  }

  HloModuleConfig config(program_shape);

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                      CreateModuleFromProto(arg->computation(), config));

  TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference,
                      DynamicDimensionInference::Run(module.get()));

  HloEvaluator evaluator;
  evaluator.set_dynamic_dimension_inference(&dynamic_dimension_inference);
  TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate(*module, {}));

  // The evaluator does not honor the requested result layout, so explicitly
  // relayout the result here.
  //
  // TODO(b/77824332): Make HloEvaluator take care of the re-layout.
  if (output_layout.has_value()) {
    result_literal = result_literal.Relayout(*output_layout);
  }
  *result->mutable_literal() = result_literal.ToProto();

  return Status::OK();
}

Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) {
  TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer,
                      allocation_tracker_.ResolveForReplica(arg->data(), 0));
  *result->mutable_shape() = buffer->on_host_shape().ToProto();
  return Status::OK();
}

Status Service::GetComputationGraphStats(
    const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) {
  if (!arg->has_computation()) {
    return InvalidArgument("Computations may not be empty.");
  }
  if (!arg->computation().has_host_program_shape()) {
    return InvalidArgument("Program shape may not be empty.");
  }

  HloModuleConfig config(ProgramShape{arg->computation().host_program_shape()});
  config.set_debug_options(arg->debug_options());
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                      CreateModuleFromProto(arg->computation(), config));
  DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);

  // Run HLO analysis to get the computation statistics.
  HloCostAnalysis analysis(
      execute_backend_->compiler()->ShapeSizeBytesFunction());

  TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis));

  ComputationStats stats;
  stats.set_flop_count(analysis.flop_count());
  stats.set_transcendental_count(analysis.transcendental_count());
  *result->mutable_stats() = stats;
  return Status::OK();
}

DeviceHandle Service::SingleComputationDeviceHandle() const {
  DeviceHandle device_handle;
  device_handle.set_handle(0);
  device_handle.set_device_count(1);
  return device_handle;
}

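// Returns the stream executors assigned to the replicas of the given device
// handle, as determined by the computation placer.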
StatusOr<std::vector<se::StreamExecutor*>> Service::Replicas(
    const Backend& backend, const DeviceHandle& device_handle) const {
  std::vector<se::StreamExecutor*> replicas;
  for (int replica = 0; replica < options_.number_of_replicas(); ++replica) {
    // From the computation placer, find out the device ids of the replicas for
    // the given device handle.
    TF_ASSIGN_OR_RETURN(
        int device_ordinal,
        backend.computation_placer()->DeviceId(replica, device_handle.handle(),
                                               options_.number_of_replicas(),
                                               device_handle.device_count()));
    TF_ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal));
    replicas.push_back(executor);
  }
  return replicas;
}

}  // namespace xla