1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/service.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/memory/memory.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_format.h"
26 #include "tensorflow/compiler/xla/debug_options_flags.h"
27 #include "tensorflow/compiler/xla/execution_options_util.h"
28 #include "tensorflow/compiler/xla/layout_util.h"
29 #include "tensorflow/compiler/xla/service/compiler.h"
30 #include "tensorflow/compiler/xla/service/computation_layout.h"
31 #include "tensorflow/compiler/xla/service/computation_placer.h"
32 #include "tensorflow/compiler/xla/service/device_memory_allocator.h"
33 #include "tensorflow/compiler/xla/service/dump.h"
34 #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
35 #include "tensorflow/compiler/xla/service/executable.h"
36 #include "tensorflow/compiler/xla/service/hlo_computation.h"
37 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
38 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
39 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
40 #include "tensorflow/compiler/xla/service/hlo_module.h"
41 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
42 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
43 #include "tensorflow/compiler/xla/service/platform_util.h"
44 #include "tensorflow/compiler/xla/service/source_map_util.h"
45 #include "tensorflow/compiler/xla/service/stream_pool.h"
46 #include "tensorflow/compiler/xla/service/transfer_manager.h"
47 #include "tensorflow/compiler/xla/shape.h"
48 #include "tensorflow/compiler/xla/shape_layout.h"
49 #include "tensorflow/compiler/xla/shape_util.h"
50 #include "tensorflow/compiler/xla/status_macros.h"
51 #include "tensorflow/compiler/xla/types.h"
52 #include "tensorflow/compiler/xla/util.h"
53 #include "tensorflow/compiler/xla/xla_data.pb.h"
54 #include "tensorflow/core/lib/gtl/cleanup.h"
55 #include "tensorflow/core/platform/env.h"
56 #include "tensorflow/core/platform/logging.h"
57 #include "tensorflow/core/platform/protobuf.h"
58 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
59 #include "tensorflow/core/platform/types.h"
60 #include "tensorflow/core/util/ptr_util.h"
61 
62 namespace xla {
63 namespace {
64 
65 using absl::StrCat;
66 using absl::StrFormat;
67 
68 // Argument used when calling DumpHloModuleIfEnabled before optimizations are
69 // performed on an HloModule.
70 constexpr char kBeforeOptimizationsDumpName[] = "before_optimizations";
71 
72 // Records the arguments used to invoke a computation in an HloSnapshot proto.
73 Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments,
74                        se::Stream* stream, TransferManager* transfer_manager,
75                        HloSnapshot* module) {
76   module->clear_arguments();
77   for (const ShapedBuffer* argument : arguments) {
78     TF_ASSIGN_OR_RETURN(
79         Literal literal,
80         transfer_manager->TransferLiteralFromDevice(stream, *argument));
81     *module->add_arguments() = literal.ToProto();
82   }
83   return Status::OK();
84 }
85 
86 // Records the result of a computation in an HloSnapshot proto.
87 Status RecordResult(const ShapedBuffer& result, se::Stream* stream,
88                     TransferManager* transfer_manager, HloSnapshot* module) {
89   module->clear_result();
90   TF_ASSIGN_OR_RETURN(
91       Literal literal,
92       transfer_manager->TransferLiteralFromDevice(stream, result));
93   *module->mutable_result() = literal.ToProto();
94   return Status::OK();
95 }
96 
97 }  // namespace
98 
99 ServiceOptions& ServiceOptions::set_platform(se::Platform* platform) {
100   platform_ = platform;
101   return *this;
102 }
103 
104 se::Platform* ServiceOptions::platform() const { return platform_; }
105 
106 ServiceOptions& ServiceOptions::set_number_of_replicas(int number_of_replicas) {
107   number_of_replicas_ = number_of_replicas;
108   return *this;
109 }
110 
111 int ServiceOptions::number_of_replicas() const { return number_of_replicas_; }
112 
113 ServiceOptions& ServiceOptions::set_intra_op_parallelism_threads(
114     int num_threads) {
115   intra_op_parallelism_threads_ = num_threads;
116   return *this;
117 }
118 
119 int ServiceOptions::intra_op_parallelism_threads() const {
120   return intra_op_parallelism_threads_;
121 }
122 
123 ServiceOptions& ServiceOptions::set_allowed_devices(
124     const absl::optional<std::set<int>>& allowed_devices) {
125   allowed_devices_ = allowed_devices;
126   return *this;
127 }
128 
129 const absl::optional<std::set<int>>& ServiceOptions::allowed_devices() const {
130   return allowed_devices_;
131 }
132 
133 /* static */ StatusOr<std::unique_ptr<Service>> Service::NewService(
134     se::Platform* platform) {
135   ServiceOptions default_options;
136   default_options.set_platform(platform);
137   return NewService(default_options);
138 }
139 
140 /* static */ StatusOr<std::unique_ptr<Service>> Service::NewService(
141     const ServiceOptions& options) {
142   se::Platform* platform = options.platform();
143   std::unique_ptr<Backend> execute_backend;
144   if (platform == nullptr) {
145     TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
146   }
147   BackendOptions backend_options;
148   backend_options.set_platform(platform);
149   backend_options.set_allowed_devices(options.allowed_devices());
150   TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options));
151 
152   std::unique_ptr<Service> service(
153       new Service(options, std::move(execute_backend)));
154   return std::move(service);
155 }
156 
157 Service::Service(const ServiceOptions& options,
158                  std::unique_ptr<Backend> execute_backend)
159     : options_(options),
160       allocation_tracker_(execute_backend.get()),
161       execute_backend_(std::move(execute_backend)) {
162   CHECK_GT(options_.number_of_replicas(), 0);
163   if (execute_backend_) {
164     if (execute_backend_->device_count() > 0) {
165       CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas())
166           << "Requested more replicas than there are devices.";
167     }
168     LOG(INFO) << StrFormat(
169         "XLA service %p executing computations on platform %s. Devices:", this,
170         execute_backend_->platform()->Name());
171     auto stream_executors = execute_backend_->stream_executors();
172     for (int i = 0; i < execute_backend_->device_count(); ++i) {
173       se::StreamExecutor* executor = stream_executors.at(i);
174       const auto& description = executor->GetDeviceDescription();
175       LOG(INFO) << StrFormat("  StreamExecutor device (%d): %s, %s", i,
176                              description.name(),
177                              description.platform_version());
178     }
179   } else {
180     VLOG(1) << "XLA compile-only service constructed";
181   }
182 }
183 
184 Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg,
185                                     CreateChannelHandleResponse* result) {
186   TF_ASSIGN_OR_RETURN(*result->mutable_channel(),
187                       channel_tracker_.NewChannel(arg->channel_type()));
188   return Status::OK();
189 }
190 
191 Status Service::Unregister(const UnregisterRequest* arg,
192                            UnregisterResponse* result) {
193   Status status;
194   for (auto& data : arg->data()) {
195     Status unregister_status = allocation_tracker_.Unregister(data);
196     if (!unregister_status.ok() && status.ok()) {
197       status = unregister_status;
198     }
199   }
200   return status;
201 }
202 
203 // Deconstructs a previously-allocated global handle.
204 Status Service::DeconstructTuple(const DeconstructTupleRequest* arg,
205                                  DeconstructTupleResponse* result) {
206   TF_ASSIGN_OR_RETURN(
207       std::vector<GlobalDataHandle> elements,
208       allocation_tracker_.DeconstructTuple(arg->tuple_handle()));
209 
210   for (auto& element : elements) {
211     *result->add_element_handles() = element;
212   }
213   return Status::OK();
214 }
215 
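// Checks that a client-provided result shape (which may omit layouts) is
// compatible with the shape actually produced by the computation.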
216 Status Service::ValidateResultShape(const Shape& client_shape,
217                                     const Shape& result_shape) const {
218   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(client_shape));
219   if (!ShapeUtil::Compatible(client_shape, result_shape)) {
220     return InvalidArgument(
221         "Shape used to set computation result layout %s is not compatible "
222         "with result shape %s",
223         ShapeUtil::HumanStringWithLayout(client_shape),
224         ShapeUtil::HumanString(result_shape));
225   }
226   return Status::OK();
227 }
228 
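// Resolves each argument handle to its per-replica ShapedBuffers and verifies
// that every buffer lives on the same platform and device ordinal as the
// replica that will execute the computation.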
229 StatusOr<std::vector<std::vector<const ShapedBuffer*>>>
230 Service::ResolveAndValidateArguments(
231     absl::Span<const GlobalDataHandle* const> arguments,
232     absl::Span<se::StreamExecutor* const> stream_executors) const {
233   CHECK_EQ(options_.number_of_replicas(), stream_executors.size());
234   std::vector<std::vector<const ShapedBuffer*>> replicated_arguments;
235   replicated_arguments.resize(options_.number_of_replicas());
236   for (size_t i = 0; i < arguments.size(); ++i) {
237     auto buffer_status = allocation_tracker_.Resolve(*arguments[i]);
238     if (!buffer_status.ok()) {
239       return Status(buffer_status.status().code(),
240                     StrCat(buffer_status.status().error_message(), ", ",
241                            "failed to resolve allocation for parameter ", i));
242     }
243     auto replicated_buffers = buffer_status.ValueOrDie();
244     CHECK_EQ(options_.number_of_replicas(), replicated_buffers.size());
245     for (int replica = 0; replica < options_.number_of_replicas(); ++replica) {
246       const ShapedBuffer* shaped_buffer = replicated_buffers[replica];
247       int replica_device_ordinal = stream_executors[replica]->device_ordinal();
248       // Verify allocation is same platform and device as the execution.
249       if (shaped_buffer->platform() != execute_backend_->platform() ||
250           shaped_buffer->device_ordinal() != replica_device_ordinal) {
251         return InvalidArgument(
252             "argument %lu is on device %s:%d but computation will be executed "
253             "on device %s",
254             i, shaped_buffer->platform()->Name(),
255             shaped_buffer->device_ordinal(),
256             execute_backend_->device_name(replica_device_ordinal));
257       }
258       replicated_arguments[replica].push_back(shaped_buffer);
259     }
260   }
261   return replicated_arguments;
262 }
263 
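// Builds an HloModuleConfig for the given program shape and argument shapes,
// validating the arguments against the computation's parameters and applying
// layout, replica-count, and debug settings from the execution options.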
264 StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
265     const ProgramShape& program_shape,
266     absl::Span<const Shape* const> argument_shapes,
267     const ExecutionOptions* execution_options) {
268   auto config = absl::make_unique<HloModuleConfig>(program_shape);
269   ComputationLayout* computation_layout =
270       config->mutable_entry_computation_layout();
271   if (program_shape.parameters_size() != argument_shapes.size()) {
272     return InvalidArgument("computation takes %d parameters, but %u given",
273                            program_shape.parameters_size(),
274                            argument_shapes.size());
275   }
276   for (int i = 0; i < argument_shapes.size(); ++i) {
277     // Verify that the shape of each argument matches the shape of the
278     // corresponding parameter in the ProgramShape.
279     if (!ShapeUtil::Compatible(*argument_shapes[i],
280                                program_shape.parameters(i))) {
281       return InvalidArgument(
282           "Argument does not match shape of computation parameter %d: want "
283           "%s, got %s",
284           i, ShapeUtil::HumanString(program_shape.parameters(i)),
285           ShapeUtil::HumanString(*argument_shapes[i]));
286     }
287     TF_RETURN_IF_ERROR(
288         computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
289             *argument_shapes[i]));
290   }
291   if (execution_options != nullptr &&
292       execution_options->has_shape_with_output_layout()) {
293     const Shape shape_with_output_layout(
294         execution_options->shape_with_output_layout());
295     TF_RETURN_IF_ERROR(
296         ValidateResultShape(shape_with_output_layout, program_shape.result()));
297     TF_RETURN_IF_ERROR(
298         computation_layout->mutable_result_layout()->CopyLayoutFromShape(
299             shape_with_output_layout));
300   } else {
301     // If the result layout is not set, then choose the default.
302     computation_layout->mutable_result_layout()->SetToDefaultLayout();
303   }
304 
305   if (execution_options != nullptr) {
306     if (execution_options->num_replicas() > 0) {
307       config->set_replica_count(execution_options->num_replicas());
308     } else {
309       config->set_replica_count(options_.number_of_replicas());
310     }
311     config->set_seed(execution_options->seed());
312     config->set_debug_options(execution_options->debug_options());
313   } else {
314     config->set_replica_count(options_.number_of_replicas());
315     config->set_debug_options(GetDebugOptionsFromFlags());
316   }
317 
318   if (execute_backend_ != nullptr &&
319       execute_backend_->eigen_intra_op_thread_pool() != nullptr) {
320     config->set_intra_op_parallelism_threads(
321         execute_backend_->eigen_intra_op_thread_pool()->NumThreads());
322   }
323 
324   if (execution_options != nullptr &&
325       execution_options->has_device_assignment()) {
326     TF_ASSIGN_OR_RETURN(
327         auto device_assignment,
328         DeviceAssignment::Deserialize(execution_options->device_assignment()));
329     config->set_static_device_assignment(*device_assignment);
330   }
331 
332   return std::move(config);
333 }
334 
335 StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
336     const ProgramShape& program_shape,
337     absl::Span<const ShapedBuffer* const> arguments,
338     const ExecutionOptions& execution_options) {
339   std::vector<const Shape*> argument_shapes;
340   for (const auto* arg : arguments) {
341     argument_shapes.push_back(&arg->on_host_shape());
342   }
343   return CreateModuleConfig(program_shape, argument_shapes, &execution_options);
344 }
345 
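// Compiles a group of HLO module protos into executables on the given backend,
// attaching HLO snapshots to the executables when snapshot dumping is enabled.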
346 StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
347     const std::vector<const HloModuleProto*>& module_protos,
348     std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
349     Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
350     DeviceMemoryAllocator* device_allocator) {
351   VLOG(1) << StrFormat("BuildExecutable on service %p", this);
352 
353   // Dump computation proto state if flag is set.
354   std::vector<std::unique_ptr<HloSnapshot>> hlo_snapshots;
355   for (int64 i = 0; i < module_protos.size(); ++i) {
356     auto hlo_snapshot = absl::make_unique<HloSnapshot>();
357     *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = *module_protos[i];
358     hlo_snapshots.push_back(std::move(hlo_snapshot));
359   }
360 
361   VLOG(1) << "Computations:";
362   for (const HloModuleProto* proto : module_protos) {
363     VLOG(1) << proto->name();
364   }
365 
366   CHECK_EQ(module_protos.size(), module_configs.size());
367   auto module_group =
368       absl::make_unique<HloModuleGroup>(module_protos[0]->name());
369   for (int64 i = 0; i < module_protos.size(); ++i) {
370     const HloModuleProto* proto = module_protos[i];
371     const HloModuleConfig& config = *module_configs[i];
372     TF_ASSIGN_OR_RETURN(auto module, CreateModuleFromProto(*proto, config));
373     DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
374     module_group->push_back(std::move(module));
375   }
376 
377   TF_ASSIGN_OR_RETURN(
378       std::vector<std::unique_ptr<Executable>> executables,
379       backend->compiler()->Compile(std::move(module_group),
380                                    std::move(executors), device_allocator));
381 
382   for (size_t i = 0; i < module_protos.size(); ++i) {
383     const auto& debug_opts = module_configs[i]->debug_options();
384     if (DumpingEnabledForHloModule(module_protos[i]->name(), debug_opts) &&
385         debug_opts.xla_dump_hlo_snapshots()) {
386       executables[i]->set_hlo_snapshot(std::move(hlo_snapshots[i]));
387     }
388   }
389 
390   return std::move(executables);
391 }
392 
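// Asynchronously launches each executable on its assigned replicas, blocks
// until all streams finish, registers the replicated results with the
// allocation tracker, and optionally records timing into the profile.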
393 StatusOr<std::vector<GlobalDataHandle>>
394 Service::ExecuteParallelAndRegisterResult(
395     absl::Span<Executable* const> executables,
396     absl::Span<const std::vector<std::vector<const ShapedBuffer*>>> arguments,
397     Backend* backend, absl::Span<const DeviceHandle> device_handles,
398     absl::Span<const string> result_tags, ExecutionProfile* profile) {
399   // Streams where the computations are launched, so we can wait on the streams
400   // to complete.
401   std::vector<StreamPool::Ptr> streams;
402   std::vector<std::unique_ptr<se::Timer>> timers;
403 
404   // Global data handles for the computation results, one for each computation.
405   std::vector<GlobalDataHandle> result_handles;
406 
407   // Device ID to stream executor, populated only with devices that are being
408   // profiled.
409   std::map<int64, se::Stream*> index_to_profiled_streams;
410 
411   // Build DeviceAssignment for all cores based on the provided device handles.
412   DeviceAssignment device_assignment(options_.number_of_replicas(),
413                                      executables.size());
414   for (int64 i = 0; i < executables.size(); i++) {
415     TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
416     CHECK_EQ(replicas.size(), arguments[i].size());
417     for (int64 replica = 0; replica < replicas.size(); ++replica) {
418       device_assignment(replica, i) = replicas[replica]->device_ordinal();
419     }
420   }
421 
422   for (int64 i = 0; i < executables.size(); i++) {
423     // Stream executors for the replicas of the current computation.
424     TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
425     CHECK_EQ(replicas.size(), arguments[i].size());
426     std::vector<ScopedShapedBuffer> result_buffers;
427     for (int64 replica = 0; replica < replicas.size(); ++replica) {
428       TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream,
429                           backend->BorrowStream(replicas[replica]));
430       streams.push_back(std::move(stream));
431 
432       if (replica == 0 && profile != nullptr) {
433         timers.push_back(
434             absl::make_unique<se::Timer>(streams.back()->parent()));
435         streams.back()
436             ->InitTimer(timers.back().get())
437             .ThenStartTimer(timers.back().get());
438         CHECK(timers.front() != nullptr);
439       }
440 
441       if (replica == 0 &&
442           executables[i]->module_config().debug_options().xla_hlo_profile() &&
443           executables[i]->hlo_profiling_enabled()) {
444         index_to_profiled_streams[i] = streams.back().get();
445       }
446 
447       // Set up run options.
448       ExecutableRunOptions options;
449       options.set_stream(streams.back().get());
450       options.set_allocator(backend->memory_allocator());
451       options.set_intra_op_thread_pool(
452           backend->eigen_intra_op_thread_pool_device());
453       options.set_device_assignment(&device_assignment);
454       ServiceExecutableRunOptions run_options(options,
455                                               backend->StreamBorrower());
456 
457       // Asynchronously launch the computation.
458       TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
459                           executables[i]->ExecuteAsyncOnStream(
460                               &run_options, arguments[i][replica]));
461 
462       if (replica == 0 && profile != nullptr) {
463         streams.back()->ThenStopTimer(timers.back().get());
464       }
465 
466       result_buffers.push_back(std::move(result));
467     }
468     TF_ASSIGN_OR_RETURN(GlobalDataHandle handle,
469                         allocation_tracker_.RegisterReplicatedBuffers(
470                             std::move(result_buffers), result_tags[i]));
471     result_handles.push_back(handle);
472   }
473 
474   // Wait for all executions to complete.
475   for (int64 i = 0; i < streams.size(); ++i) {
476     Status block_status = streams[i]->BlockHostUntilDone();
477     if (!block_status.ok()) {
478       return InternalError("failed to complete execution for stream %d: %s", i,
479                            block_status.error_message());
480     }
481   }
482 
483   if (profile != nullptr) {
484     CHECK(!timers.empty());
485     std::vector<uint64> timer_nanoseconds;
486     timer_nanoseconds.reserve(timers.size());
487     for (auto& timer : timers) {
488       timer_nanoseconds.push_back(timer->Nanoseconds());
489     }
490     uint64 nanoseconds =
491         *std::max_element(timer_nanoseconds.begin(), timer_nanoseconds.end());
492 
493     // Merge in run-time profile information from execution_profile on the
494     // zeroth device.
495     profile->MergeFrom(executables[0]->execution_profile());
496 
497     // Overall execution time (in nanoseconds) from the executor timer.
498     profile->set_compute_and_transfer_time_ns(nanoseconds);
499 
500     // TODO(b/28123297): On GPU we end up including transfer time in
501     // the compute time this way. Instead, we should get the correct
502     // value by measuring it. Setting the field here at least lets
503     // benchmarks provide *some* value for GPU computations.
504     //
505     // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually
506     // the compute time without the transfer time, so this way we get the
507     // correct compute time. We should instead have the correct value for
508     // compute_and_transfer_time and set compute_time to the compute time.
509     if (profile->compute_time_ns() == 0) {
510       profile->set_compute_time_ns(profile->compute_and_transfer_time_ns());
511     }
512   }
513 
514   return result_handles;
515 }
516 
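// Runs a single executable (possibly replicated across devices) and registers
// its result, returning the handle under which the result was registered.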
517 StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
518     Executable* executable,
519     absl::Span<const std::vector<const ShapedBuffer*>> arguments,
520     Backend* backend, const DeviceHandle& device_handle,
521     const string& result_tag, ExecutionProfile* profile) {
522   // Set up streams.
523   std::vector<StreamPool::Ptr> streams;
524 
525   TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handle));
526   TF_RET_CHECK(!replicas.empty());
527   for (se::StreamExecutor* executor : replicas) {
528     TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream,
529                         backend->BorrowStream(executor));
530     streams.push_back(std::move(stream));
531   }
532 
533   DeviceAssignment device_assignment(options_.number_of_replicas(),
534                                      /*computation_count=*/1);
535   for (int64 replica = 0; replica < replicas.size(); ++replica) {
536     device_assignment(replica, 0) = replicas[replica]->device_ordinal();
537   }
538 
539   // Set up run options.
540   std::vector<ServiceExecutableRunOptions> run_options;
541   for (const StreamPool::Ptr& stream : streams) {
542     ExecutableRunOptions options;
543     options.set_stream(stream.get());
544     options.set_device_ordinal(stream->parent()->device_ordinal());
545     options.set_allocator(backend->memory_allocator());
546     options.set_intra_op_thread_pool(
547         backend->eigen_intra_op_thread_pool_device());
548     options.set_device_assignment(&device_assignment);
549     run_options.emplace_back(options, backend->StreamBorrower());
550   }
551 
552   if (options_.number_of_replicas() == 1) {
553     TF_ASSIGN_OR_RETURN(
554         auto result, executable->ExecuteOnStreamWrapper(&run_options[0],
555                                                         profile, arguments[0]));
556     return allocation_tracker_.Register(std::move(result), result_tag);
557   }
558 
559   // TODO(b/69985541): Support profiling also on this path.
560 
561   std::vector<absl::Span<const ShapedBuffer* const>> replicated_arguments;
562   for (const auto& arg : arguments) {
563     replicated_arguments.push_back(arg);
564   }
565 
566   TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams(
567                                         run_options, replicated_arguments));
568   TF_RET_CHECK(!results.empty());
569   return allocation_tracker_.RegisterReplicatedBuffers(std::move(results),
570                                                        result_tag);
571 }
572 
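// Returns the first-replica stream executor for each device handle in the
// execution options; these are the executors on which the parallel
// computations will be launched.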
573 StatusOr<std::vector<se::StreamExecutor*>> Service::GetExecutors(
574     const ExecutionOptions& execution_options, int64 requests_size,
575     int64 request_index) const {
576   if (execution_options.device_handles().empty()) {
577     return FailedPrecondition(
578         "device handles must be given to execute parallel computations");
579   }
580   if (requests_size > 1 && execution_options.device_handles_size() > 1) {
581     return InvalidArgument(
582         "Parallel requests with multiple device handles is not supported. "
583         "Found %d parallel requests, with request %d containing %d device "
584         "handles.",
585         requests_size, request_index, execution_options.device_handles_size());
586   }
587   std::vector<se::StreamExecutor*> executors;
588   for (const auto& device_handle : execution_options.device_handles()) {
589     TF_ASSIGN_OR_RETURN(auto replicas,
590                         Replicas(*execute_backend_, device_handle));
591     se::StreamExecutor* executor = replicas[0];
592     CHECK(executor != nullptr);
593     executors.push_back(executor);
594   }
595   return executors;
596 }
597 
598 StatusOr<std::vector<std::vector<const ShapedBuffer*>>> Service::GetArguments(
599     const ExecutionOptions& execution_options,
600     absl::Span<const GlobalDataHandle* const> arguments) const {
601   // Resolve the allocations for the arguments of the computation, and create
602   // a vector of device memory offsets for the arguments from the allocations.
603   // In the case of partitioned computations, assume all arguments go on the
604   // zeroth core.
605   TF_ASSIGN_OR_RETURN(
606       auto replicas,
607       Replicas(*execute_backend_, execution_options.device_handles(0)));
608   TF_ASSIGN_OR_RETURN(
609       std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
610       ResolveAndValidateArguments(arguments, replicas));
611   return replicated_arguments;
612 }
613 
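// Compiles and runs one or more computation graphs in parallel, one per
// requested device handle, and returns a result handle and execution profile
// for each computation.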
614 Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
615                                      ExecuteParallelResponse* result) {
616   VLOG(1) << "running execute-graph-parallel request";
617 
618   std::vector<std::vector<std::vector<const ShapedBuffer*>>> all_arguments;
619   std::vector<std::vector<se::StreamExecutor*>> all_executors;
620   std::vector<const HloModuleProto*> module_protos;
621   std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
622   std::vector<string> computation_names;
623   std::vector<DeviceHandle> device_handles;
624 
625   int num_requested_devices =
626       std::accumulate(arg->requests().begin(), arg->requests().end(), 0,
627                       [](int a, const ExecuteGraphRequest& r) -> int {
628                         return a + r.execution_options().device_handles_size();
629                       });
630   if (num_requested_devices * options_.number_of_replicas() >
631       execute_backend_->device_count()) {
632     return FailedPrecondition(
633         "there are not enough stream executors to execute %d computations",
634         num_requested_devices);
635   }
636 
637   for (int64 i = 0; i < arg->requests_size(); ++i) {
638     // Get the stream executor for the i'th computation. This stream executor
639     // is one of the executors to run the replicated computation.
640     const ExecutionOptions& execution_options =
641         arg->requests(i).execution_options();
642     const ExecuteGraphRequest& request = arg->requests(i);
643     TF_RET_CHECK(request.has_computation()) << "computations may not be empty";
644     TF_RET_CHECK(request.computation().has_host_program_shape())
645         << "programe shape may not be empty";
646 
647     // Get the executors.
648     TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options,
649                                                      arg->requests_size(), i));
650 
651     // Get the replicated arguments.
652     TF_ASSIGN_OR_RETURN(auto replicated_arguments,
653                         GetArguments(execution_options, request.arguments()));
654 
655     // Create an HloModuleConfig object for the computation, given the shape of
656     // the program and the argument allocations. Here, we care only about the
657     // shapes of the arguments, so, it is sufficient to use the arguments of
658     // replica 0.
659     TF_ASSIGN_OR_RETURN(
660         std::unique_ptr<HloModuleConfig> module_config,
661         CreateModuleConfig(
662             ProgramShape{request.computation().host_program_shape()},
663             replicated_arguments.front(), request.execution_options()));
664     VLOG(3)
665         << "ExecuteGraphParallel created HloModuleConfig computation layout: "
666         << module_config->entry_computation_layout().ToString();
667 
668     // Adds to the vectors to build and execute the computations after the loop.
669     all_arguments.push_back(replicated_arguments);
670     all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}});
671     module_protos.push_back(&request.computation());
672     module_configs.push_back(std::move(module_config));
673     computation_names.insert(computation_names.end(), executors.size(),
674                              request.computation().name());
675     all_executors.push_back(executors);
676     device_handles.insert(device_handles.end(),
677                           execution_options.device_handles().begin(),
678                           execution_options.device_handles().end());
679   }
680 
681   // Build the HloModules and compile to generate the executables.
682   //
683   // TODO(jlebar): There's currently no way to pass a device allocator to
684   // ExecuteGraphParallel, so we have to pass a null device_allocator below.
685   TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
686                       BuildExecutables(module_protos, std::move(module_configs),
687                                        execute_backend_.get(), all_executors,
688                                        /*device_allocator=*/nullptr));
689   std::vector<Executable*> executable_ptrs;
690   executable_ptrs.reserve(executables.size());
691   for (const auto& executable : executables) {
692     executable_ptrs.push_back(executable.get());
693   }
694 
695   for (int i = 0; i < executable_ptrs.size(); i++) {
696     if (executable_ptrs[i]->dumping_snapshot()) {
697       TF_ASSIGN_OR_RETURN(auto stream,
698                           execute_backend_->BorrowStream(
699                               all_executors[i][0]->device_ordinal()));
700       TF_RETURN_IF_ERROR(RecordArguments(all_arguments[i].front(), stream.get(),
701                                          execute_backend_->transfer_manager(),
702                                          executable_ptrs[i]->hlo_snapshot()));
703     }
704   }
705 
706   // If we have multiple executables to run, execute them all in parallel.  But
707   // if we only have one executable, execute it using the vanilla, non-parallel
708   // call.
709   //
710   // We do this because the Client API uses ExecuteGraphParallel when it wants
711   // to compile and run one computation without caching the executable, but not
712   // all backends support the async StreamExecutor API required by
713   // ExecuteParallelAndRegisterResult.
714   //
715   // TODO(b/122731460): Consolidate Execute{,Parallel}AndRegisterResult; they do
716   // basically the same thing.
717   ExecutionProfile profile;
718   std::vector<GlobalDataHandle> outputs;
719   if (executable_ptrs.size() == 1) {
720     TF_ASSIGN_OR_RETURN(
721         auto output,
722         ExecuteAndRegisterResult(executable_ptrs[0], all_arguments[0],
723                                  execute_backend_.get(), device_handles[0],
724                                  computation_names[0], &profile));
725     outputs.push_back(std::move(output));
726   } else {
727     TF_ASSIGN_OR_RETURN(
728         outputs, ExecuteParallelAndRegisterResult(
729                      executable_ptrs, all_arguments, execute_backend_.get(),
730                      device_handles, computation_names, &profile));
731   }
732 
733   for (const GlobalDataHandle& output : outputs) {
734     ExecuteResponse response;
735     *response.mutable_output() = output;
736     *response.mutable_profile() = profile;
737     *result->add_responses() = response;
738   }
739 
740   for (int i = 0; i < executable_ptrs.size(); i++) {
741     Executable* executable = executable_ptrs[i];
742     if (executable->dumping_snapshot()) {
743       TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer,
744                           allocation_tracker_.ResolveForReplica(outputs[i], 0));
745       TF_ASSIGN_OR_RETURN(auto stream,
746                           execute_backend_->BorrowStream(all_executors[i][0]));
747       TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
748                                       execute_backend_->transfer_manager(),
749                                       executable->hlo_snapshot()));
750       DumpHloSnapshotIfEnabled(executable->module(),
751                                *executable->hlo_snapshot());
752     }
753   }
754 
755   VLOG(1) << "successfully completed 'execute-graph-parallel' request";
756   return Status::OK();
757 }
758 
759 Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
760                                  GetDeviceHandlesResponse* result) {
761   const int64 available_device_count = execute_backend_->device_count();
762   const int64 replica_count = options_.number_of_replicas();
763   if (replica_count <= 0) {
764     return FailedPrecondition("Replica count must be a positive integer");
765   }
766   if (available_device_count < arg->device_count() * replica_count) {
767     return ResourceExhausted(
768         "Requested logical device count (%d) with replica count (%d) exceeds "
769         "the number of available physical devices on the target (%d)",
770         arg->device_count(), replica_count, available_device_count);
771   }
772 
773   for (int64 i = 0; i < arg->device_count(); ++i) {
774     DeviceHandle device_handle;
775     device_handle.set_handle(i);
776     device_handle.set_device_count(arg->device_count());
777     *result->add_device_handles() = device_handle;
778   }
779 
780   return Status::OK();
781 }
782 
783 StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
784     const HloModuleProto& module_proto,
785     std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
786     se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) {
787   VLOG(1) << StrFormat(
788       "BuildExecutable on service %p with serialized module proto: %s", this,
789       module_proto.name());
790 
791   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
792                       CreateModuleFromProto(module_proto, *module_config));
793   DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
794 
795   TF_ASSIGN_OR_RETURN(
796       module, backend->compiler()->RunHloPasses(std::move(module), executor,
797                                                 device_allocator));
798 
799   TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
800                       backend->compiler()->RunBackend(
801                           std::move(module), executor, device_allocator));
802 
803   const auto& debug_opts = module_config->debug_options();
804   if (DumpingEnabledForHloModule(module_proto.name(), debug_opts) &&
805       debug_opts.xla_dump_hlo_snapshots()) {
806     auto hlo_snapshot = absl::make_unique<HloSnapshot>();
807     *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = module_proto;
808     executable->set_hlo_snapshot(std::move(hlo_snapshot));
809   }
810 
811   return std::move(executable);
812 }
813 
814 Status Service::Compile(const CompileRequest* arg, CompileResponse* result) {
815   VLOG(1) << "running compile request";
816   if (!arg->has_computation()) {
817     return InvalidArgument("computations may not be empty");
818   }
819   if (!arg->computation().has_host_program_shape()) {
820     return InvalidArgument("programe shape may not be empty");
821   }
822 
823   if (arg->execution_options().device_handles_size() > 1) {
824     return InvalidArgument(
825         "The compile request does not support multiple device handles.");
826   }
827 
828   std::vector<Shape> argument_shapes;
829   argument_shapes.reserve(arg->input_shape_with_layout_size());
830   std::vector<const Shape*> argument_shape_ptrs;
831   for (const ShapeProto& shape_proto : arg->input_shape_with_layout()) {
832     argument_shapes.push_back(Shape(shape_proto));
833     argument_shape_ptrs.push_back(&argument_shapes.back());
834   }
835   TF_ASSIGN_OR_RETURN(
836       std::unique_ptr<HloModuleConfig> module_config,
837       CreateModuleConfig(ProgramShape{arg->computation().host_program_shape()},
838                          argument_shape_ptrs, &arg->execution_options()));
839   VLOG(3) << "Compile created HloModuleConfig computation layout: "
840           << module_config->entry_computation_layout().ToString();
841 
842   TF_ASSIGN_OR_RETURN(
843       std::unique_ptr<Executable> executable,
844       BuildExecutable(arg->computation(), std::move(module_config),
845                       execute_backend_.get(),
846                       execute_backend_->default_stream_executor(),
847                       /*device_allocator=*/nullptr));
848 
849   *result->mutable_handle() = compilation_cache_.Insert(std::move(executable));
850 
851   VLOG(1) << "successfully completed 'compile' request";
852   return Status::OK();
853 }
854 
855 Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) {
856   VLOG(1) << "running execute request";
857   if (!arg->has_handle()) {
858     return InvalidArgument("execution handle should not be empty");
859   }
860   TF_ASSIGN_OR_RETURN(auto executable,
861                       compilation_cache_.LookUp(arg->handle()));
862 
863   TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_,
864                                               SingleComputationDeviceHandle()));
865   TF_ASSIGN_OR_RETURN(
866       std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
867       ResolveAndValidateArguments(arg->arguments(), replicas));
868 
869   // Check that the replicated_arguments has the same shape and layout as the
870   // module config used when creating the executable.
871   const int64 num_module_args =
872       executable->module_config().entry_computation_layout().parameter_count();
873   if (num_module_args != arg->arguments_size()) {
874     return InvalidArgument(
875         "The executable expects %lld arguments, but sees %lld.",
876         num_module_args, arg->arguments_size());
877   }
878   for (int64 i = 0; i < num_module_args; i++) {
879     const Shape& shape_module =
880         executable->module_config().entry_computation_layout().parameter_shape(
881             i);
882     const Shape& shape_arg = replicated_arguments.front()[i]->on_host_shape();
883     if (!ShapeUtil::Equal(shape_module, shape_arg)) {
884       return InvalidArgumentStrCat(
885           "The executable exepcts the ", i, "th argument in shape ",
886           ShapeUtil::HumanStringWithLayout(shape_module), " but sees ",
887           ShapeUtil::HumanStringWithLayout(shape_arg));
888     }
889   }
890 
891   TF_ASSIGN_OR_RETURN(auto stream,
892                       execute_backend_->BorrowStream(
893                           execute_backend_->default_stream_executor()));
894   if (executable->dumping_snapshot()) {
895     executable->hlo_snapshot()->set_execution_platform(
896         execute_backend_->platform()->Name());
897     TF_RETURN_IF_ERROR(RecordArguments(
898         replicated_arguments.front(), stream.get(),
899         execute_backend_->transfer_manager(), executable->hlo_snapshot()));
900   }
901 
902   TF_ASSIGN_OR_RETURN(
903       *result->mutable_output(),
904       ExecuteAndRegisterResult(executable.get(), replicated_arguments,
905                                execute_backend_.get(),
906                                SingleComputationDeviceHandle(),
907                                "result of " + executable->module().name(),
908                                result->mutable_profile()));
909 
910   if (executable->dumping_snapshot()) {
911     TF_ASSIGN_OR_RETURN(
912         const ShapedBuffer* result_buffer,
913         allocation_tracker_.ResolveForReplica(result->output(), 0));
914     TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
915                                     execute_backend_->transfer_manager(),
916                                     executable->hlo_snapshot()));
917     DumpHloSnapshotIfEnabled(executable->module(), *executable->hlo_snapshot());
918   }
919 
920   VLOG(1) << "successfully completed 'execute' request";
921   return Status::OK();
922 }
923 
924 Status Service::WaitForExecution(const WaitForExecutionRequest* arg,
925                                  WaitForExecutionResponse* result) {
926   TF_ASSIGN_OR_RETURN(const auto execution,
927                       execution_tracker_.Resolve(arg->execution()));
928 
929   TF_RETURN_IF_ERROR(execution->BlockUntilDone());
930 
931   *result->mutable_output() = execution->result();
932   *result->mutable_profile() = execution->profile();
933 
934   TF_RETURN_IF_ERROR(execution_tracker_.Unregister(arg->execution()));
935   VLOG(1) << "successfully completed 'wait-for-execution' request";
936   return Status::OK();
937 }
938 
939 Status Service::TransferToClient(const TransferToClientRequest* arg,
940                                  TransferToClientResponse* result) {
941   TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer,
942                       allocation_tracker_.ResolveForReplica(arg->data(), 0));
943 
944   Shape return_shape;
945   if (arg->has_shape_with_layout()) {
946     return_shape = Shape(arg->shape_with_layout());
947     if (!LayoutUtil::HasLayout(return_shape)) {
948       return InvalidArgument("shape_with_layout must have layout if present.");
949     }
950   } else {
951     return_shape = Shape(shaped_buffer->on_host_shape());
952   }
953 
954   TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(
955                                        shaped_buffer->device_ordinal()));
956 
957   TF_ASSIGN_OR_RETURN(
958       Literal result_literal,
959       execute_backend_->transfer_manager()->TransferLiteralFromDevice(
960           stream.get(), *shaped_buffer));
961 
962   if (LayoutUtil::LayoutsInShapesEqual(return_shape, result_literal.shape())) {
963     *result->mutable_literal() = result_literal.ToProto();
964   } else {
965     *result->mutable_literal() =
966         result_literal.Relayout(return_shape).ToProto();
967   }
968   return Status::OK();
969 }
970 
971 Status Service::TransferToServer(const TransferToServerRequest* arg,
972                                  TransferToServerResponse* result) {
973   TF_ASSIGN_OR_RETURN(Literal literal,
974                       Literal::CreateFromProto(arg->literal()));
975   const Shape& shape = literal.shape();
976 
977   std::vector<se::StreamExecutor*> replicas;
978   if (arg->has_device_handle()) {
979     TF_ASSIGN_OR_RETURN(replicas,
980                         Replicas(*execute_backend_, arg->device_handle()));
981   } else {
982     TF_ASSIGN_OR_RETURN(
983         replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle()));
984   }
985 
986   // Allocate memory in each replica and transfer the data to all replicas.
987   std::vector<ScopedShapedBuffer> replicated_buffers;
988   for (se::StreamExecutor* executor : replicas) {
989     TF_ASSIGN_OR_RETURN(
990         ScopedShapedBuffer shaped_buffer,
991         execute_backend_->transfer_manager()->AllocateScopedShapedBuffer(
992             shape, execute_backend_->memory_allocator(),
993             executor->device_ordinal()));
994     TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor));
995     TF_RETURN_IF_ERROR(
996         execute_backend_->transfer_manager()->TransferLiteralToDevice(
997             stream.get(), literal, shaped_buffer));
998     replicated_buffers.emplace_back(std::move(shaped_buffer));
999   }
1000   TF_ASSIGN_OR_RETURN(*result->mutable_data(),
1001                       allocation_tracker_.RegisterReplicatedBuffers(
1002                           std::move(replicated_buffers),
1003                           StrCat("TransferToServer literal of shape ",
1004                                  ShapeUtil::HumanString(shape))));
1005 
1006   return Status::OK();
1007 }
1008 
1009 Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
1010                                  TransferToInfeedResponse* result) {
1011   const int64 replica_count = options_.number_of_replicas();
1012   if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
1013     return FailedPrecondition(
1014         "%s",
1015         StrCat("The replica_id=", arg->replica_id(),
1016                " on TransferToInfeedRequest not in range [0, replica_count=",
1017                replica_count, ")."));
1018   }
1019 
1020   se::StreamExecutor* executor;
1021   if (arg->has_device_handle()) {
1022     TF_ASSIGN_OR_RETURN(auto replicas,
1023                         Replicas(*execute_backend_, arg->device_handle()));
1024     executor = replicas[arg->replica_id()];
1025   } else {
1026     TF_ASSIGN_OR_RETURN(
1027         auto replicas,
1028         Replicas(*execute_backend_, SingleComputationDeviceHandle()));
1029     executor = replicas[arg->replica_id()];
1030   }
1031 
1032   TF_ASSIGN_OR_RETURN(Literal literal,
1033                       Literal::CreateFromProto(arg->literal()));
1034   return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor,
1035                                                                        literal);
1036 }
1037 
1038 Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
1039                                     TransferFromOutfeedResponse* result) {
1040   const int64 replica_count = options_.number_of_replicas();
1041   if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
1042     return FailedPrecondition(
1043         "The replica_id=%d on TransferFromOutfeedRequest not in range [0, %d)",
1044         arg->replica_id(), replica_count);
1045   }
1046 
1047   se::StreamExecutor* executor;
1048   if (arg->has_device_handle()) {
1049     TF_ASSIGN_OR_RETURN(auto replicas,
1050                         Replicas(*execute_backend_, arg->device_handle()));
1051     executor = replicas[arg->replica_id()];
1052   } else {
1053     TF_ASSIGN_OR_RETURN(
1054         auto replicas,
1055         Replicas(*execute_backend_, SingleComputationDeviceHandle()));
1056     executor = replicas[arg->replica_id()];
1057   }
1058 
1059   auto literal = Literal::CreateFromShape(Shape(arg->shape_with_layout()));
1060 
1061   TF_RETURN_IF_ERROR(
1062       execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
1063           executor, Shape(arg->shape_with_layout()), literal));
1064   *result->mutable_literal() = literal.ToProto();
1065   return Status::OK();
1066 }
1067 
1068 Status Service::ResetDevice(const ResetDeviceRequest* arg,
1069                             ResetDeviceResponse* result) {
1070   return execute_backend_->ResetDevices();
1071 }
1072 
1073 Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
1074                                      ComputeConstantResponse* result) {
1075   if (!arg->has_computation()) {
1076     return InvalidArgument("computations may not be empty");
1077   }
1078   if (!arg->computation().has_host_program_shape()) {
1079     return InvalidArgument("program shape may not be empty");
1080   }
1081   if (arg->computation().host_program_shape().parameters_size() != 0) {
1082     return InvalidArgument(
1083         "constant computation may not depend on any parameters.");
1084   }
1085 
1086   ProgramShape program_shape(arg->computation().host_program_shape());
1087   TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result()));
1088   absl::optional<Layout> output_layout;
1089   if (arg->has_output_layout()) {
1090     output_layout = Layout::CreateFromProto(arg->output_layout());
1091     TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(
1092         *output_layout, program_shape.result()));
1093   }
1094 
1095   HloModuleConfig config(program_shape);
1096 
1097   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
1098                       CreateModuleFromProto(arg->computation(), config));
1099 
1100   TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference,
1101                       DynamicDimensionInference::Run(module.get()));
1102 
1103   HloEvaluator evaluator;
1104   evaluator.set_dynamic_dimension_inference(&dynamic_dimension_inference);
1105   TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate(*module, {}));
1106 
1107   // The evaluator does not apply the requested result layout, so relayout the
1108   // result explicitly here.
1109   //
1110   // TODO(b/77824332): Make HloEvaluator take care of the re-layout.
1111   if (output_layout.has_value()) {
1112     result_literal = result_literal.Relayout(*output_layout);
1113   }
1114   *result->mutable_literal() = result_literal.ToProto();
1115 
1116   return Status::OK();
1117 }
1118 
1119 Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) {
1120   TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer,
1121                       allocation_tracker_.ResolveForReplica(arg->data(), 0));
1122   *result->mutable_shape() = buffer->on_host_shape().ToProto();
1123   return Status::OK();
1124 }
1125 
1126 Status Service::GetComputationGraphStats(
1127     const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) {
1128   if (!arg->has_computation()) {
1129     return InvalidArgument("Computations may not be empty.");
1130   }
1131   if (!arg->computation().has_host_program_shape()) {
1132     return InvalidArgument("Program shape may not be empty.");
1133   }
1134 
1135   HloModuleConfig config(ProgramShape{arg->computation().host_program_shape()});
1136   config.set_debug_options(arg->debug_options());
1137   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
1138                       CreateModuleFromProto(arg->computation(), config));
1139   DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
1140 
1141   // Run HLO analysis to get the computation statistics.
1142   HloCostAnalysis analysis(
1143       execute_backend_->compiler()->ShapeSizeBytesFunction());
1144 
1145   TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis));
1146 
1147   ComputationStats stats;
1148   stats.set_flop_count(analysis.flop_count());
1149   stats.set_transcendental_count(analysis.transcendental_count());
1150   *result->mutable_stats() = stats;
1151   return Status::OK();
1152 }
1153 
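// Returns the device handle used when a computation runs on a single device
// (handle 0 with a device count of 1).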
1154 DeviceHandle Service::SingleComputationDeviceHandle() const {
1155   DeviceHandle device_handle;
1156   device_handle.set_handle(0);
1157   device_handle.set_device_count(1);
1158   return device_handle;
1159 }
1160 
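// Returns the stream executors for every replica of the device identified by
// device_handle, using the backend's computation placer to map replica ids to
// device ordinals.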
1161 StatusOr<std::vector<se::StreamExecutor*>> Service::Replicas(
1162     const Backend& backend, const DeviceHandle& device_handle) const {
1163   std::vector<se::StreamExecutor*> replicas;
1164   for (int replica = 0; replica < options_.number_of_replicas(); ++replica) {
1165     // From the computation placer, find out the device ids of the replicas for
1166     // the given device handle.
1167     TF_ASSIGN_OR_RETURN(
1168         int device_ordinal,
1169         backend.computation_placer()->DeviceId(replica, device_handle.handle(),
1170                                                options_.number_of_replicas(),
1171                                                device_handle.device_count()));
1172     TF_ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal));
1173     replicas.push_back(executor);
1174   }
1175   return replicas;
1176 }
1177 
1178 }  // namespace xla
1179