1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/service.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/memory/memory.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_format.h"
26 #include "tensorflow/compiler/xla/debug_options_flags.h"
27 #include "tensorflow/compiler/xla/execution_options_util.h"
28 #include "tensorflow/compiler/xla/layout_util.h"
29 #include "tensorflow/compiler/xla/service/compiler.h"
30 #include "tensorflow/compiler/xla/service/computation_layout.h"
31 #include "tensorflow/compiler/xla/service/computation_placer.h"
32 #include "tensorflow/compiler/xla/service/dump.h"
33 #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
34 #include "tensorflow/compiler/xla/service/dynamic_padder.h"
35 #include "tensorflow/compiler/xla/service/executable.h"
36 #include "tensorflow/compiler/xla/service/hlo_computation.h"
37 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
38 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
39 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
40 #include "tensorflow/compiler/xla/service/hlo_module.h"
41 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
42 #include "tensorflow/compiler/xla/service/hlo_module_util.h"
43 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
44 #include "tensorflow/compiler/xla/service/platform_util.h"
45 #include "tensorflow/compiler/xla/service/source_map_util.h"
46 #include "tensorflow/compiler/xla/service/stream_pool.h"
47 #include "tensorflow/compiler/xla/service/transfer_manager.h"
48 #include "tensorflow/compiler/xla/shape.h"
49 #include "tensorflow/compiler/xla/shape_layout.h"
50 #include "tensorflow/compiler/xla/shape_util.h"
51 #include "tensorflow/compiler/xla/status_macros.h"
52 #include "tensorflow/compiler/xla/types.h"
53 #include "tensorflow/compiler/xla/util.h"
54 #include "tensorflow/compiler/xla/xla_data.pb.h"
55 #include "tensorflow/core/lib/gtl/cleanup.h"
56 #include "tensorflow/core/platform/env.h"
57 #include "tensorflow/core/platform/errors.h"
58 #include "tensorflow/core/platform/logging.h"
59 #include "tensorflow/core/platform/protobuf.h"
60 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
61 #include "tensorflow/core/platform/types.h"
62 #include "tensorflow/core/util/ptr_util.h"
63 #include "tensorflow/stream_executor/device_memory_allocator.h"
64 
65 namespace xla {
66 namespace {
67 
68 using absl::StrCat;
69 using absl::StrFormat;
70 
71 // Argument used when calling DumpHloModuleIfEnabled before optimizations are
72 // performed on an HloModule.
73 constexpr char kBeforeOptimizationsDumpName[] = "before_optimizations";
74 
75 // Records the arguments used to invoke a computation in an HloSnapshot proto.
76 Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments,
77                        se::Stream* stream, TransferManager* transfer_manager,
78                        HloSnapshot* module) {
79   module->clear_arguments();
80   for (const ShapedBuffer* argument : arguments) {
81     TF_ASSIGN_OR_RETURN(
82         Literal literal,
83         transfer_manager->TransferLiteralFromDevice(stream, *argument));
84     *module->add_arguments() = literal.ToProto();
85   }
86   return Status::OK();
87 }
88 
89 // Records the result of a computation in an HloSnapshot proto.
90 Status RecordResult(const ShapedBuffer& result, se::Stream* stream,
91                     TransferManager* transfer_manager, HloSnapshot* module) {
92   module->clear_result();
93   TF_ASSIGN_OR_RETURN(
94       Literal literal,
95       transfer_manager->TransferLiteralFromDevice(stream, result));
96   *module->mutable_result() = literal.ToProto();
97   return Status::OK();
98 }
99 
100 }  // namespace
101 
102 ServiceOptions& ServiceOptions::set_platform(se::Platform* platform) {
103   platform_ = platform;
104   return *this;
105 }
106 
107 se::Platform* ServiceOptions::platform() const { return platform_; }
108 
109 ServiceOptions& ServiceOptions::set_number_of_replicas(int number_of_replicas) {
110   number_of_replicas_ = number_of_replicas;
111   return *this;
112 }
113 
114 int ServiceOptions::number_of_replicas() const { return number_of_replicas_; }
115 
116 ServiceOptions& ServiceOptions::set_intra_op_parallelism_threads(
117     int num_threads) {
118   intra_op_parallelism_threads_ = num_threads;
119   return *this;
120 }
121 
122 int ServiceOptions::intra_op_parallelism_threads() const {
123   return intra_op_parallelism_threads_;
124 }
125 
126 ServiceOptions& ServiceOptions::set_allowed_devices(
127     const absl::optional<std::set<int>>& allowed_devices) {
128   allowed_devices_ = allowed_devices;
129   return *this;
130 }
131 
132 const absl::optional<std::set<int>>& ServiceOptions::allowed_devices() const {
133   return allowed_devices_;
134 }
135 
136 /* static */ StatusOr<std::unique_ptr<Service>> Service::NewService(
137     se::Platform* platform) {
138   ServiceOptions default_options;
139   default_options.set_platform(platform);
140   return NewService(default_options);
141 }
142 
143 /* static */ StatusOr<std::unique_ptr<Service>> Service::NewService(
144     const ServiceOptions& options) {
145   se::Platform* platform = options.platform();
146   std::unique_ptr<Backend> execute_backend;
147   if (platform == nullptr) {
148     TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
149   }
150   BackendOptions backend_options;
151   backend_options.set_platform(platform);
152   backend_options.set_allowed_devices(options.allowed_devices());
153   TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options));
154 
155   std::unique_ptr<Service> service(
156       new Service(options, std::move(execute_backend)));
157   return std::move(service);
158 }
159 
160 Service::Service(const ServiceOptions& options,
161                  std::unique_ptr<Backend> execute_backend)
162     : options_(options),
163       allocation_tracker_(execute_backend.get()),
164       execute_backend_(std::move(execute_backend)) {
165   CHECK_GT(options_.number_of_replicas(), 0);
166   if (execute_backend_) {
167     if (execute_backend_->device_count() > 0) {
168       CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas())
169           << "Requested more replicas than there are devices.";
170     }
171     LOG(INFO) << StrFormat(
172         "XLA service %p initialized for platform %s (this does not guarantee "
173         "that XLA will be used). Devices:",
174         this, execute_backend_->platform()->Name());
175     auto stream_executors = execute_backend_->stream_executors();
176     for (int i = 0; i < execute_backend_->device_count(); ++i) {
177       se::StreamExecutor* executor = stream_executors.at(i);
178       const auto& description = executor->GetDeviceDescription();
179       LOG(INFO) << StrFormat("  StreamExecutor device (%d): %s, %s", i,
180                              description.name(),
181                              description.platform_version());
182     }
183   } else {
184     VLOG(1) << "XLA compile-only service constructed";
185   }
186 }
187 
188 Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg,
189                                     CreateChannelHandleResponse* result) {
190   TF_ASSIGN_OR_RETURN(*result->mutable_channel(),
191                       channel_tracker_.NewChannel(arg->channel_type()));
192   return Status::OK();
193 }
194 
195 Status Service::Unregister(const UnregisterRequest* arg,
196                            UnregisterResponse* result) {
197   Status status;
198   for (auto& data : arg->data()) {
199     Status unregister_status = allocation_tracker_.Unregister(data);
200     if (!unregister_status.ok() && status.ok()) {
201       status = unregister_status;
202     }
203   }
204   return status;
205 }
206 
207 // Deconstructs a previously-allocated global handle.
208 Status Service::DeconstructTuple(const DeconstructTupleRequest* arg,
209                                  DeconstructTupleResponse* result) {
210   TF_ASSIGN_OR_RETURN(
211       std::vector<GlobalDataHandle> elements,
212       allocation_tracker_.DeconstructTuple(arg->tuple_handle()));
213 
214   for (auto& element : elements) {
215     *result->add_element_handles() = element;
216   }
217   return Status::OK();
218 }
219 
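// Checks that the client-supplied result shape (which may omit layouts) is
// compatible with the shape actually produced by the computation.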
220 Status Service::ValidateResultShape(const Shape& client_shape,
221                                     const Shape& result_shape) const {
222   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(client_shape));
223   if (!ShapeUtil::Compatible(client_shape, result_shape)) {
224     return InvalidArgument(
225         "Shape used to set computation result layout %s is not compatible "
226         "with result shape %s",
227         ShapeUtil::HumanStringWithLayout(client_shape),
228         ShapeUtil::HumanString(result_shape));
229   }
230   return Status::OK();
231 }
232 
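// Resolves each argument handle through the allocation tracker and regroups
// the replicated buffers into one argument list per replica.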
233 StatusOr<std::vector<std::vector<const ShapedBuffer*>>>
234 Service::ResolveAndValidateArguments(
235     absl::Span<const GlobalDataHandle* const> arguments,
236     absl::Span<se::StreamExecutor* const> stream_executors) const {
237   CHECK_EQ(options_.number_of_replicas(), stream_executors.size());
238   std::vector<std::vector<const ShapedBuffer*>> replicated_arguments;
239   replicated_arguments.resize(options_.number_of_replicas());
240   for (size_t i = 0; i < arguments.size(); ++i) {
241     auto buffer_status = allocation_tracker_.Resolve(*arguments[i]);
242     if (!buffer_status.ok()) {
243       return Status(buffer_status.status().code(),
244                     StrCat(buffer_status.status().error_message(), ", ",
245                            "failed to resolve allocation for parameter ", i));
246     }
247     auto replicated_buffers = buffer_status.ValueOrDie();
248     CHECK_EQ(options_.number_of_replicas(), replicated_buffers.size());
249     for (int replica = 0; replica < options_.number_of_replicas(); ++replica) {
250       const ShapedBuffer* shaped_buffer = replicated_buffers[replica];
251       replicated_arguments[replica].push_back(shaped_buffer);
252     }
253   }
254   return replicated_arguments;
255 }
256 
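// Builds an HloModuleConfig for the program, filling in the service's default
// replica count and the backend's intra-op thread count where available.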
257 StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
258     const ProgramShape& program_shape,
259     absl::Span<const Shape* const> argument_shapes,
260     const ExecutionOptions* execution_options,
261     const AotCompilationOptions* aot_options) {
262   int default_num_replicas = options_.number_of_replicas();
263   absl::optional<int> num_threads;
264   if (execute_backend_ != nullptr &&
265       execute_backend_->eigen_intra_op_thread_pool() != nullptr) {
266     num_threads = execute_backend_->eigen_intra_op_thread_pool()->NumThreads();
267   }
268 
269   return xla::CreateModuleConfig(program_shape, argument_shapes,
270                                  execution_options, default_num_replicas,
271                                  num_threads, aot_options);
272 }
273 
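// Overload that reads the argument shapes from already-allocated
// ShapedBuffers (their on-host shapes).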
274 StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
275     const ProgramShape& program_shape,
276     absl::Span<const ShapedBuffer* const> arguments,
277     const ExecutionOptions& execution_options,
278     const AotCompilationOptions* aot_options) {
279   std::vector<const Shape*> argument_shapes;
280   for (const auto* arg : arguments) {
281     argument_shapes.push_back(&arg->on_host_shape());
282   }
283   return CreateModuleConfig(program_shape, argument_shapes, &execution_options,
284                             aot_options);
285 }
286 
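// Deserializes the module protos into an HloModuleGroup, dumps them if
// requested, and compiles them on the given backend; when run_backend_only is
// set, the HLO optimization passes are skipped and only the backend stage runs.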
287 StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
288     const std::vector<const HloModuleProto*>& module_protos,
289     std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
290     Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
291     const Compiler::CompileOptions& options, bool run_backend_only) {
292   VLOG(1) << StrFormat("BuildExecutable on service %p", this);
293 
294   VLOG(1) << "Computations:";
295   for (const HloModuleProto* proto : module_protos) {
296     VLOG(1) << proto->name();
297   }
298 
299   CHECK_EQ(module_protos.size(), module_configs.size());
300   auto module_group =
301       absl::make_unique<HloModuleGroup>(module_protos[0]->name());
302   for (int64_t i = 0, end = module_protos.size(); i < end; ++i) {
303     const HloModuleProto* proto = module_protos[i];
304     const HloModuleConfig& config = *module_configs[i];
305     TF_ASSIGN_OR_RETURN(
306         auto module, CreateModuleFromProto(*proto, config, run_backend_only));
307     DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
308     module_group->push_back(std::move(module));
309   }
310 
311   std::vector<std::unique_ptr<Executable>> executables;
312   if (!run_backend_only) {
313     TF_ASSIGN_OR_RETURN(executables, backend->compiler()->Compile(
314                                          std::move(module_group),
315                                          std::move(executors), options));
316   } else {
317     auto modules = module_group->ConsumeModules();
318     for (std::unique_ptr<HloModule>& module : modules) {
319       TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
320                           backend->compiler()->RunBackend(
321                               std::move(module), executors[0][0], options));
322       executables.push_back(std::move(executable));
323     }
324   }
325 
326   return std::move(executables);
327 }
328 
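// Launches each executable asynchronously on the streams of its replicas,
// blocks until all streams finish, registers each computation's replicated
// results under a global handle, and optionally fills in execution timing.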
329 StatusOr<std::vector<GlobalDataHandle>>
330 Service::ExecuteParallelAndRegisterResult(
331     absl::Span<Executable* const> executables,
332     absl::Span<const std::vector<std::vector<const ShapedBuffer*>>> arguments,
333     Backend* backend, absl::Span<const DeviceHandle> device_handles,
334     absl::Span<const string> result_tags, ExecutionProfile* profile) {
335   // Streams on which the computations are launched, so we can wait on them
336   // to complete.
337   std::vector<StreamPool::Ptr> streams;
338   std::vector<std::unique_ptr<se::Timer>> timers;
339 
340   // Global data handles for the computation results, one for each computation.
341   std::vector<GlobalDataHandle> result_handles;
342 
343   // Computation index to profiled stream; populated only for computations
344   // that are being profiled.
345   std::map<int64, se::Stream*> index_to_profiled_streams;
346 
347   // Build DeviceAssignment for all cores based on the provided device handles.
348   DeviceAssignment device_assignment(options_.number_of_replicas(),
349                                      executables.size());
350   for (int64_t i = 0; i < executables.size(); i++) {
351     TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
352     CHECK_EQ(replicas.size(), arguments[i].size());
353     for (int64_t replica = 0, end = replicas.size(); replica < end; ++replica) {
354       device_assignment(replica, i) = replicas[replica]->device_ordinal();
355     }
356   }
357 
358   for (int64_t i = 0, end = executables.size(); i < end; i++) {
359     // Stream executors for the replicas of the current computation.
360     TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
361     CHECK_EQ(replicas.size(), arguments[i].size());
362     std::vector<ScopedShapedBuffer> result_buffers;
363     for (int64_t replica = 0; replica < replicas.size(); ++replica) {
364       TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream,
365                           backend->BorrowStream(replicas[replica]));
366       streams.push_back(std::move(stream));
367 
368       if (replica == 0 && profile != nullptr) {
369         timers.push_back(
370             absl::make_unique<se::Timer>(streams.back()->parent()));
371         streams.back()
372             ->InitTimer(timers.back().get())
373             .ThenStartTimer(timers.back().get());
374         CHECK(timers.front() != nullptr);
375       }
376 
377       if (replica == 0 &&
378           executables[i]->module_config().debug_options().xla_hlo_profile() &&
379           executables[i]->hlo_profiling_enabled()) {
380         index_to_profiled_streams[i] = streams.back().get();
381       }
382 
383       // Set up run options.
384       ExecutableRunOptions options;
385       options.set_stream(streams.back().get());
386       options.set_allocator(backend->memory_allocator());
387       options.set_intra_op_thread_pool(
388           backend->eigen_intra_op_thread_pool_device());
389       options.set_device_assignment(&device_assignment);
390       // Use run-time profile information from execution_profile on the 0th
391       // device.
392       if (i == 0) {
393         options.set_execution_profile(profile);
394       }
395       ServiceExecutableRunOptions run_options(options,
396                                               backend->StreamBorrower());
397 
398       // Asynchronously launch the computation.
399       TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
400                           executables[i]->ExecuteAsyncOnStream(
401                               &run_options, arguments[i][replica],
402                               /*hlo_execution_profile=*/nullptr));
403 
404       if (replica == 0 && profile != nullptr) {
405         streams.back()->ThenStopTimer(timers.back().get());
406       }
407 
408       result_buffers.push_back(std::move(result));
409     }
410     TF_ASSIGN_OR_RETURN(GlobalDataHandle handle,
411                         allocation_tracker_.RegisterReplicatedBuffers(
412                             std::move(result_buffers), result_tags[i]));
413     result_handles.push_back(handle);
414   }
415 
416   // Wait for all executions to complete.
417   for (int64_t i = 0, end = streams.size(); i < end; ++i) {
418     Status block_status = streams[i]->BlockHostUntilDone();
419     if (!block_status.ok()) {
420       return InternalError("failed to complete execution for stream %d: %s", i,
421                            block_status.error_message());
422     }
423   }
424 
425   if (profile != nullptr) {
426     CHECK(!timers.empty());
427     std::vector<uint64> timer_nanoseconds;
428     timer_nanoseconds.reserve(timers.size());
429     for (auto& timer : timers) {
430       timer_nanoseconds.push_back(timer->Nanoseconds());
431     }
432     uint64 nanoseconds =
433         *std::max_element(timer_nanoseconds.begin(), timer_nanoseconds.end());
434 
435     // Overall execution time (in nanoseconds) from the executor timer.
436     profile->set_compute_and_transfer_time_ns(nanoseconds);
437 
438     // TODO(b/28123297): On GPU we end up including transfer time in
439     // the compute time this way. Instead, we should get the correct
440     // value by measuring it. Setting the field here at least lets
441     // benchmarks provide *some* value for GPU computations.
442     //
443     // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually
444     // the compute time without the transfer time, so this way we get the
445     // correct compute time. We should instead have the correct value for
446     // compute_and_transfer_time and set compute_time to the compute time.
447     if (profile->compute_time_ns() == 0) {
448       profile->set_compute_time_ns(profile->compute_and_transfer_time_ns());
449     }
450   }
451 
452   return result_handles;
453 }
454 
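// Runs one executable across all replicas of the given device handle and
// registers the resulting buffers under a single global data handle.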
455 StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
456     Executable* executable,
457     absl::Span<const std::vector<const ShapedBuffer*>> arguments,
458     Backend* backend, const DeviceHandle& device_handle,
459     const string& result_tag, ExecutionProfile* profile) {
460   // Set up streams.
461   std::vector<StreamPool::Ptr> streams;
462 
463   TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handle));
464   TF_RET_CHECK(!replicas.empty());
465   for (se::StreamExecutor* executor : replicas) {
466     TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream,
467                         backend->BorrowStream(executor));
468     streams.push_back(std::move(stream));
469   }
470 
471   DeviceAssignment device_assignment(options_.number_of_replicas(),
472                                      /*computation_count=*/1);
473   for (int64_t replica = 0; replica < replicas.size(); ++replica) {
474     device_assignment(replica, 0) = replicas[replica]->device_ordinal();
475   }
476 
477   // Set up run options.
478   std::vector<ServiceExecutableRunOptions> run_options;
479   for (const StreamPool::Ptr& stream : streams) {
480     ExecutableRunOptions options;
481     options.set_stream(stream.get());
482     options.set_device_ordinal(stream->parent()->device_ordinal());
483     options.set_allocator(backend->memory_allocator());
484     options.set_intra_op_thread_pool(
485         backend->eigen_intra_op_thread_pool_device());
486     options.set_device_assignment(&device_assignment);
487     options.set_execution_profile(profile);
488     run_options.emplace_back(options, backend->StreamBorrower());
489   }
490 
491   if (options_.number_of_replicas() == 1) {
492     TF_ASSIGN_OR_RETURN(auto result, executable->ExecuteOnStreamWrapper(
493                                          &run_options[0], arguments[0]));
494     return allocation_tracker_.Register(std::move(result), result_tag);
495   }
496 
497   // TODO(b/69985541): Support profiling also on this path.
498 
499   std::vector<absl::Span<const ShapedBuffer* const>> replicated_arguments;
500   for (const auto& arg : arguments) {
501     replicated_arguments.push_back(arg);
502   }
503 
504   TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams(
505                                         run_options, replicated_arguments));
506   TF_RET_CHECK(!results.empty());
507   return allocation_tracker_.RegisterReplicatedBuffers(std::move(results),
508                                                        result_tag);
509 }
510 
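// Returns the stream executor for replica 0 of each device handle named in
// the execution options, rejecting unsupported parallel-request combinations.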
511 StatusOr<std::vector<se::StreamExecutor*>> Service::GetExecutors(
512     const ExecutionOptions& execution_options, int64_t requests_size,
513     int64_t request_index) const {
514   if (execution_options.device_handles().empty()) {
515     return FailedPrecondition(
516         "device handles must be given to execute parallel computations");
517   }
518   if (requests_size > 1 && execution_options.device_handles_size() > 1) {
519     return InvalidArgument(
520         "Parallel requests with multiple device handles is not supported. "
521         "Found %d parallel requests, with request %d containing %d device "
522         "handles.",
523         requests_size, request_index, execution_options.device_handles_size());
524   }
525   std::vector<se::StreamExecutor*> executors;
526   for (const auto& device_handle : execution_options.device_handles()) {
527     TF_ASSIGN_OR_RETURN(auto replicas,
528                         Replicas(*execute_backend_, device_handle));
529     se::StreamExecutor* executor = replicas[0];
530     CHECK(executor != nullptr);
531     executors.push_back(executor);
532   }
533   return executors;
534 }
535 
536 StatusOr<std::vector<std::vector<const ShapedBuffer*>>> Service::GetArguments(
537     const ExecutionOptions& execution_options,
538     absl::Span<const GlobalDataHandle* const> arguments) const {
539   // Resolve the allocations for the arguments of the computation, and create
540   // a vector of device memory offsets for the arguments from the allocations.
541   // In the case of partitioned computations, assume all arguments go on the
542   // zeroth core.
543   TF_ASSIGN_OR_RETURN(
544       auto replicas,
545       Replicas(*execute_backend_, execution_options.device_handles(0)));
546   TF_ASSIGN_OR_RETURN(
547       std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
548       ResolveAndValidateArguments(arguments, replicas));
549   return replicated_arguments;
550 }
551 
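// Compiles and executes the computations in the request, one per set of
// requested device handles, and returns one ExecuteResponse (output handle
// plus profile) per computation.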
552 Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
553                                      ExecuteParallelResponse* result) {
554   VLOG(1) << "running execute-graph-parallel request";
555 
556   std::vector<std::vector<std::vector<const ShapedBuffer*>>> all_arguments;
557   std::vector<std::vector<se::StreamExecutor*>> all_executors;
558   std::vector<const HloModuleProto*> module_protos;
559   std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
560   std::vector<string> computation_names;
561   std::vector<DeviceHandle> device_handles;
562 
563   int num_requested_devices =
564       std::accumulate(arg->requests().begin(), arg->requests().end(), 0,
565                       [](int a, const ExecuteGraphRequest& r) -> int {
566                         return a + r.execution_options().device_handles_size();
567                       });
568   if (num_requested_devices * options_.number_of_replicas() >
569       execute_backend_->device_count()) {
570     return FailedPrecondition(
571         "there are not enough stream executors to execute %d computations",
572         num_requested_devices);
573   }
574 
575   for (int64_t i = 0; i < arg->requests_size(); ++i) {
576     // Get the stream executor for the i'th computation. This stream executor
577     // is one of the executors to run the replicated computation.
578     const ExecutionOptions& execution_options =
579         arg->requests(i).execution_options();
580     const ExecuteGraphRequest& request = arg->requests(i);
581     TF_RET_CHECK(request.has_computation()) << "computations may not be empty";
582     TF_RET_CHECK(request.computation().has_host_program_shape())
583         << "program shape may not be empty";
584 
585     // Get the executors.
586     TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options,
587                                                      arg->requests_size(), i));
588 
589     // Get the replicated arguments.
590     TF_ASSIGN_OR_RETURN(auto replicated_arguments,
591                         GetArguments(execution_options, request.arguments()));
592 
593     // Create an HloModuleConfig object for the computation, given the shape of
594     // the program and the argument allocations. Here, we care only about the
595     // shapes of the arguments, so it is sufficient to use the arguments of
596     // replica 0.
597     TF_ASSIGN_OR_RETURN(
598         std::unique_ptr<HloModuleConfig> module_config,
599         CreateModuleConfig(
600             ProgramShape{request.computation().host_program_shape()},
601             replicated_arguments.front(), request.execution_options()));
602     VLOG(3)
603         << "ExecuteGraphParallel created HloModuleConfig computation layout: "
604         << module_config->entry_computation_layout().ToString();
605 
606     // Add entries to the vectors used to build and execute the computations after the loop.
607     all_arguments.push_back(replicated_arguments);
608     all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}});
609     module_protos.push_back(&request.computation());
610     module_configs.push_back(std::move(module_config));
611     computation_names.insert(computation_names.end(), executors.size(),
612                              request.computation().name());
613     all_executors.push_back(executors);
614     device_handles.insert(device_handles.end(),
615                           execution_options.device_handles().begin(),
616                           execution_options.device_handles().end());
617   }
618 
619   // Build the HloModules and compile to generate the executables.
620   //
621   // TODO(jlebar): There's currently no way to pass a device allocator to
622   // ExecuteGraphParallel, so we have to pass a null device_allocator below.
623   TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
624                       BuildExecutables(module_protos, std::move(module_configs),
625                                        execute_backend_.get(), all_executors,
626                                        {/*device_allocator=*/nullptr}));
627   std::vector<Executable*> executable_ptrs;
628   executable_ptrs.reserve(executables.size());
629   for (const auto& executable : executables) {
630     executable_ptrs.push_back(executable.get());
631   }
632 
633   std::vector<HloSnapshot> snapshots;
634   snapshots.resize(executable_ptrs.size());
635   for (int i = 0, end = executable_ptrs.size(); i < end; i++) {
636     if (executable_ptrs[i]->dumping_snapshot()) {
637       *snapshots[i].mutable_hlo() = *executable_ptrs[i]->hlo_proto();
638       TF_ASSIGN_OR_RETURN(auto stream,
639                           execute_backend_->BorrowStream(
640                               all_executors[i][0]->device_ordinal()));
641       TF_RETURN_IF_ERROR(RecordArguments(all_arguments[i].front(), stream.get(),
642                                          execute_backend_->transfer_manager(),
643                                          &snapshots[i]));
644     }
645   }
646 
647   // If we have multiple executables to run, execute them all in parallel.  But
648   // if we only have one executable, execute it using the vanilla, non-parallel
649   // call.
650   //
651   // We do this because the Client API uses ExecuteGraphParallel when it wants
652   // to compile and run one computation without caching the executable, but not
653   // all backends support the async StreamExecutor API required by
654   // ExecuteParallelAndRegisterResult.
655   //
656   // TODO(b/122731460): Consolidate Execute{,Parallel}AndRegisterResult; they do
657   // basically the same thing.
658   ExecutionProfile profile;
659   std::vector<GlobalDataHandle> outputs;
660   Status execution_status = Status::OK();
661 
662   if (executable_ptrs.size() == 1) {
663     StatusOr<GlobalDataHandle> output_or_status = ExecuteAndRegisterResult(
664         executable_ptrs[0], all_arguments[0], execute_backend_.get(),
665         device_handles[0], computation_names[0], &profile);
666     if (output_or_status.ok()) {
667       outputs.push_back(std::move(output_or_status).ValueOrDie());
668     } else {
669       execution_status = output_or_status.status();
670     }
671   } else {
672     StatusOr<std::vector<GlobalDataHandle>> outputs_or_status =
673         ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments,
674                                          execute_backend_.get(), device_handles,
675                                          computation_names, &profile);
676     if (outputs_or_status.ok()) {
677       outputs = std::move(outputs_or_status).ValueOrDie();
678     } else {
679       execution_status = outputs_or_status.status();
680     }
681   }
682 
683   if (!execution_status.ok()) {
684     // Execution failed so we don't have the results.  Dump the HLO snapshot
685     // with just the program arguments.
686     for (int i = 0, end = executable_ptrs.size(); i < end; i++) {
687       DumpHloSnapshotIfEnabled(executable_ptrs[i]->module(), snapshots[i]);
688     }
689   }
690 
691   TF_RETURN_IF_ERROR(execution_status);
692 
693   for (const GlobalDataHandle& output : outputs) {
694     ExecuteResponse response;
695     *response.mutable_output() = output;
696     *response.mutable_profile() = profile;
697     *result->add_responses() = response;
698   }
699 
700   for (int i = 0, end = executable_ptrs.size(); i < end; i++) {
701     Executable* executable = executable_ptrs[i];
702     if (executable->dumping_snapshot()) {
703       TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer,
704                           allocation_tracker_.ResolveForReplica(outputs[i], 0));
705       TF_ASSIGN_OR_RETURN(auto stream,
706                           execute_backend_->BorrowStream(all_executors[i][0]));
707       TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
708                                       execute_backend_->transfer_manager(),
709                                       &snapshots[i]));
710       DumpHloSnapshotIfEnabled(executable->module(), snapshots[i]);
711     }
712   }
713 
714   VLOG(1) << "successfully completed 'execute-graph-parallel' request";
715   return Status::OK();
716 }
717 
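// Hands out arg->device_count() logical device handles, failing if that count
// multiplied by the replica count exceeds the available physical devices.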
718 Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
719                                  GetDeviceHandlesResponse* result) {
720   const int64_t available_device_count = execute_backend_->device_count();
721   const int64_t replica_count = options_.number_of_replicas();
722   if (replica_count <= 0) {
723     return FailedPrecondition("Replica count must be a positive integer");
724   }
725   if (available_device_count < arg->device_count() * replica_count) {
726     return ResourceExhausted(
727         "Requested logical device count (%d) with replica count (%d) exceeds "
728         "the number of available physical devices on the target (%d)",
729         arg->device_count(), replica_count, available_device_count);
730   }
731 
732   for (int64_t i = 0; i < arg->device_count(); ++i) {
733     DeviceHandle device_handle;
734     device_handle.set_handle(i);
735     device_handle.set_device_count(arg->device_count());
736     *result->add_device_handles() = device_handle;
737   }
738 
739   return Status::OK();
740 }
741 
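// Builds a single Executable from a module proto: deserializes the HloModule,
// optionally runs the HLO optimization passes, and then invokes the backend's
// code generation.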
742 StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
743     const HloModuleProto& module_proto,
744     std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
745     se::StreamExecutor* executor, const Compiler::CompileOptions& options,
746     bool run_backend_only) {
747   VLOG(1) << StrFormat(
748       "BuildExecutable on service %p with serialized module proto: %s", this,
749       module_proto.name());
750 
751   TF_ASSIGN_OR_RETURN(
752       std::unique_ptr<HloModule> module,
753       CreateModuleFromProto(module_proto, *module_config, run_backend_only));
754   DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
755 
756   if (!run_backend_only) {
757     TF_ASSIGN_OR_RETURN(module, backend->compiler()->RunHloPasses(
758                                     std::move(module), executor, options));
759   }
760 
761   TF_ASSIGN_OR_RETURN(
762       std::unique_ptr<Executable> executable,
763       backend->compiler()->RunBackend(std::move(module), executor, options));
764 
765   return std::move(executable);
766 }
767 
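// Compiles the computation in the request against the default stream executor
// and stores the result in the compilation cache, returning a handle that a
// later Execute request can refer to.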
768 Status Service::Compile(const CompileRequest* arg, CompileResponse* result) {
769   VLOG(1) << "running compile request";
770   if (!arg->has_computation()) {
771     return InvalidArgument("computations may not be empty");
772   }
773   if (!arg->computation().has_host_program_shape()) {
774     return InvalidArgument("program shape may not be empty");
775   }
776 
777   if (arg->execution_options().device_handles_size() > 1) {
778     return InvalidArgument(
779         "The compile request does not support multiple device handles.");
780   }
781 
782   std::vector<Shape> argument_shapes;
783   argument_shapes.reserve(arg->input_shape_with_layout_size());
784   std::vector<const Shape*> argument_shape_ptrs;
785   for (const ShapeProto& shape_proto : arg->input_shape_with_layout()) {
786     argument_shapes.push_back(Shape(shape_proto));
787     argument_shape_ptrs.push_back(&argument_shapes.back());
788   }
789   TF_ASSIGN_OR_RETURN(
790       std::unique_ptr<HloModuleConfig> module_config,
791       CreateModuleConfig(ProgramShape{arg->computation().host_program_shape()},
792                          argument_shape_ptrs, &arg->execution_options()));
793   VLOG(3) << "Compile created HloModuleConfig computation layout: "
794           << module_config->entry_computation_layout().ToString();
795 
796   TF_ASSIGN_OR_RETURN(
797       std::unique_ptr<Executable> executable,
798       BuildExecutable(arg->computation(), std::move(module_config),
799                       execute_backend_.get(),
800                       execute_backend_->default_stream_executor(),
801                       {/*device_allocator=*/nullptr}));
802 
803   *result->mutable_handle() = compilation_cache_.Insert(std::move(executable));
804 
805   VLOG(1) << "successfully completed 'compile' request";
806   return Status::OK();
807 }
808 
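// Looks up a cached executable by handle, checks that the provided arguments
// match the entry computation layout, runs the executable, and registers its
// result.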
809 Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) {
810   VLOG(1) << "running execute request";
811   if (!arg->has_handle()) {
812     return InvalidArgument("execution handle should not be empty");
813   }
814   TF_ASSIGN_OR_RETURN(auto executable,
815                       compilation_cache_.LookUp(arg->handle()));
816 
817   TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_,
818                                               SingleComputationDeviceHandle()));
819   TF_ASSIGN_OR_RETURN(
820       std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
821       ResolveAndValidateArguments(arg->arguments(), replicas));
822 
823   // Check that the replicated arguments have the same shapes and layouts as
824   // the parameters of the module config used when creating the executable.
825   const int64_t num_module_args =
826       executable->module_config().entry_computation_layout().parameter_count();
827   if (num_module_args != arg->arguments_size()) {
828     return InvalidArgument(
829         "The executable expects %lld arguments, but sees %lld.",
830         num_module_args, arg->arguments_size());
831   }
832   for (int64_t i = 0; i < num_module_args; i++) {
833     const Shape& shape_module =
834         executable->module_config().entry_computation_layout().parameter_shape(
835             i);
836     const Shape& shape_arg = replicated_arguments.front()[i]->on_host_shape();
837     if (!ShapeUtil::Equal(shape_module, shape_arg)) {
838       return InvalidArgumentStrCat(
839           "The executable expects the ", i, "th argument in shape ",
840           ShapeUtil::HumanStringWithLayout(shape_module), " but sees ",
841           ShapeUtil::HumanStringWithLayout(shape_arg));
842     }
843   }
844 
845   TF_ASSIGN_OR_RETURN(auto stream,
846                       execute_backend_->BorrowStream(
847                           execute_backend_->default_stream_executor()));
848   HloSnapshot snapshot;
849   if (executable->dumping_snapshot()) {
850     *snapshot.mutable_hlo() = *executable->hlo_proto();
851     snapshot.set_execution_platform(execute_backend_->platform()->Name());
852     TF_RETURN_IF_ERROR(
853         RecordArguments(replicated_arguments.front(), stream.get(),
854                         execute_backend_->transfer_manager(), &snapshot));
855   }
856 
857   TF_ASSIGN_OR_RETURN(
858       *result->mutable_output(),
859       ExecuteAndRegisterResult(executable.get(), replicated_arguments,
860                                execute_backend_.get(),
861                                SingleComputationDeviceHandle(),
862                                "result of " + executable->module().name(),
863                                result->mutable_profile()));
864 
865   if (executable->dumping_snapshot()) {
866     TF_ASSIGN_OR_RETURN(
867         const ShapedBuffer* result_buffer,
868         allocation_tracker_.ResolveForReplica(result->output(), 0));
869     TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
870                                     execute_backend_->transfer_manager(),
871                                     &snapshot));
872     DumpHloSnapshotIfEnabled(executable->module(), snapshot);
873   }
874 
875   VLOG(1) << "successfully completed 'execute' request";
876   return Status::OK();
877 }
878 
879 Status Service::WaitForExecution(const WaitForExecutionRequest* arg,
880                                  WaitForExecutionResponse* result) {
881   TF_ASSIGN_OR_RETURN(const auto execution,
882                       execution_tracker_.Resolve(arg->execution()));
883 
884   TF_RETURN_IF_ERROR(execution->BlockUntilDone());
885 
886   *result->mutable_output() = execution->result();
887   *result->mutable_profile() = execution->profile();
888 
889   TF_RETURN_IF_ERROR(execution_tracker_.Unregister(arg->execution()));
890   VLOG(1) << "successfully completed 'wait-for-execution' request";
891   return Status::OK();
892 }
893 
894 Status Service::TransferToClient(const TransferToClientRequest* arg,
895                                  TransferToClientResponse* result) {
896   TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer,
897                       allocation_tracker_.ResolveForReplica(arg->data(), 0));
898 
899   Shape return_shape;
900   if (arg->has_shape_with_layout()) {
901     return_shape = Shape(arg->shape_with_layout());
902     if (!LayoutUtil::HasLayout(return_shape)) {
903       return InvalidArgument("shape_with_layout must have layout if present.");
904     }
905   } else {
906     return_shape = Shape(shaped_buffer->on_host_shape());
907   }
908 
909   TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(
910                                        shaped_buffer->device_ordinal()));
911 
912   TF_ASSIGN_OR_RETURN(
913       Literal result_literal,
914       execute_backend_->transfer_manager()->TransferLiteralFromDevice(
915           stream.get(), *shaped_buffer));
916 
917   if (LayoutUtil::LayoutsInShapesEqual(return_shape, result_literal.shape())) {
918     *result->mutable_literal() = result_literal.ToProto();
919   } else {
920     *result->mutable_literal() =
921         result_literal.Relayout(return_shape).ToProto();
922   }
923   return Status::OK();
924 }
925 
926 Status Service::TransferToServer(const TransferToServerRequest* arg,
927                                  TransferToServerResponse* result) {
928   TF_ASSIGN_OR_RETURN(Literal literal,
929                       Literal::CreateFromProto(arg->literal()));
930   const Shape& shape = literal.shape();
931 
932   std::vector<se::StreamExecutor*> replicas;
933   if (arg->has_device_handle()) {
934     TF_ASSIGN_OR_RETURN(replicas,
935                         Replicas(*execute_backend_, arg->device_handle()));
936   } else {
937     TF_ASSIGN_OR_RETURN(
938         replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle()));
939   }
940 
941   // Allocate memory in each replica and transfer the data to all replicas.
942   std::vector<ScopedShapedBuffer> replicated_buffers;
943   for (se::StreamExecutor* executor : replicas) {
944     TF_ASSIGN_OR_RETURN(
945         ScopedShapedBuffer shaped_buffer,
946         execute_backend_->transfer_manager()->AllocateScopedShapedBuffer(
947             shape, execute_backend_->memory_allocator(),
948             executor->device_ordinal()));
949     TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor));
950     TF_RETURN_IF_ERROR(
951         execute_backend_->transfer_manager()->TransferLiteralToDevice(
952             stream.get(), literal, shaped_buffer));
953     replicated_buffers.emplace_back(std::move(shaped_buffer));
954   }
955   TF_ASSIGN_OR_RETURN(*result->mutable_data(),
956                       allocation_tracker_.RegisterReplicatedBuffers(
957                           std::move(replicated_buffers),
958                           StrCat("TransferToServer literal of shape ",
959                                  ShapeUtil::HumanString(shape))));
960 
961   return Status::OK();
962 }
963 
964 Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
965                                  TransferToInfeedResponse* result) {
966   const int64_t replica_count = options_.number_of_replicas();
967   if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
968     return FailedPrecondition(
969         "%s",
970         StrCat("The replica_id=", arg->replica_id(),
971                " on TransferToInfeedRequest not in range [0, replica_count=",
972                replica_count, ")."));
973   }
974 
975   se::StreamExecutor* executor;
976   if (arg->has_device_handle()) {
977     TF_ASSIGN_OR_RETURN(auto replicas,
978                         Replicas(*execute_backend_, arg->device_handle()));
979     executor = replicas[arg->replica_id()];
980   } else {
981     TF_ASSIGN_OR_RETURN(
982         auto replicas,
983         Replicas(*execute_backend_, SingleComputationDeviceHandle()));
984     executor = replicas[arg->replica_id()];
985   }
986 
987   TF_ASSIGN_OR_RETURN(Literal literal,
988                       Literal::CreateFromProto(arg->literal()));
989   return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor,
990                                                                        literal);
991 }
992 
993 Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
994                                     TransferFromOutfeedResponse* result) {
995   const int64_t replica_count = options_.number_of_replicas();
996   if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
997     return FailedPrecondition(
998         "The replica_id=%d on TransferFromOutfeedRequest not in range [0, %d)",
999         arg->replica_id(), replica_count);
1000   }
1001 
1002   se::StreamExecutor* executor;
1003   if (arg->has_device_handle()) {
1004     TF_ASSIGN_OR_RETURN(auto replicas,
1005                         Replicas(*execute_backend_, arg->device_handle()));
1006     executor = replicas[arg->replica_id()];
1007   } else {
1008     TF_ASSIGN_OR_RETURN(
1009         auto replicas,
1010         Replicas(*execute_backend_, SingleComputationDeviceHandle()));
1011     executor = replicas[arg->replica_id()];
1012   }
1013 
1014   auto literal = Literal::CreateFromShape(Shape(arg->shape_with_layout()));
1015 
1016   TF_RETURN_IF_ERROR(
1017       execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
1018           executor, &literal));
1019   *result->mutable_literal() = literal.ToProto();
1020   return Status::OK();
1021 }
1022 
1023 Status Service::ResetDevice(const ResetDeviceRequest* arg,
1024                             ResetDeviceResponse* result) {
1025   return execute_backend_->ResetDevices();
1026 }
1027 
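// Evaluates a parameter-free computation on the host: the module is
// dynamic-padded, run through HloEvaluator (with a handler for SliceToDynamic
// custom calls), and the resulting literal is optionally relaid out.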
1028 Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
1029                                      ComputeConstantResponse* result) {
1030   if (!arg->has_computation()) {
1031     return InvalidArgument("computations may not be empty");
1032   }
1033   if (!arg->computation().has_host_program_shape()) {
1034     return InvalidArgument("program shape may not be empty");
1035   }
1036   if (arg->computation().host_program_shape().parameters_size() != 0) {
1037     return InvalidArgument(
1038         "constant computation may not depend on any parameters.");
1039   }
1040 
1041   ProgramShape program_shape(arg->computation().host_program_shape());
1042   TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result()));
1043   absl::optional<Layout> output_layout;
1044   if (arg->has_output_layout()) {
1045     output_layout = Layout::CreateFromProto(arg->output_layout());
1046     TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(
1047         *output_layout, program_shape.result()));
1048   }
1049 
1050   HloModuleConfig config(program_shape);
1051 
1052   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
1053                       CreateModuleFromProto(arg->computation(), config));
1054   DynamicPadder dynamic_padder;
1055   TF_RETURN_IF_ERROR(dynamic_padder.Run(module.get()).status());
1056 
1057   TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference,
1058                       DynamicDimensionInference::Run(module.get()));
1059 
1060   HloEvaluator evaluator;
1061   evaluator.set_dynamic_dimension_inference(&dynamic_dimension_inference);
1062   evaluator.set_custom_call_handler(
1063       [](HloInstruction* custom_call,
1064          absl::Span<const Literal*> operands) -> StatusOr<Literal> {
1065         if (custom_call->custom_call_target() == "SliceToDynamic") {
1066           auto result = operands[0]->Clone();
1067           for (int64_t i = 0; i < result.shape().rank(); ++i) {
1068             result.SetDynamicSize(i, operands[1 + i]->Get<int32>({}));
1069           }
1070           return result.ToStatic();
1071         }
1072         return Unimplemented("Custom call %s is not supported: %s",
1073                              custom_call->custom_call_target(),
1074                              custom_call->ToString());
1075       });
1076   TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate(*module, {}));
1077 
1078   // The requested result layout has no effect on the Evaluator's output, so
1079   // explicitly relayout the result here.
1080   //
1081   // TODO(b/77824332): Make HloEvaluator take care of the re-layout.
1082   if (output_layout.has_value()) {
1083     result_literal = result_literal.Relayout(*output_layout);
1084   }
1085   *result->mutable_literal() = result_literal.ToProto();
1086 
1087   return Status::OK();
1088 }
1089 
1090 Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) {
1091   TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer,
1092                       allocation_tracker_.ResolveForReplica(arg->data(), 0));
1093   *result->mutable_shape() = buffer->on_host_shape().ToProto();
1094   return Status::OK();
1095 }
1096 
1097 Status Service::GetComputationGraphStats(
1098     const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) {
1099   if (!arg->has_computation()) {
1100     return InvalidArgument("Computations may not be empty.");
1101   }
1102   if (!arg->computation().has_host_program_shape()) {
1103     return InvalidArgument("Program shape may not be empty.");
1104   }
1105 
1106   HloModuleConfig config(ProgramShape{arg->computation().host_program_shape()});
1107   config.set_debug_options(arg->debug_options());
1108   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
1109                       CreateModuleFromProto(arg->computation(), config));
1110   DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
1111 
1112   // Run HLO analysis to get the computation statistics.
1113   HloCostAnalysis analysis(
1114       execute_backend_->compiler()->ShapeSizeBytesFunction());
1115 
1116   TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis));
1117 
1118   ComputationStats stats;
1119   stats.set_flop_count(analysis.flop_count());
1120   stats.set_transcendental_count(analysis.transcendental_count());
1121   *result->mutable_stats() = stats;
1122   return Status::OK();
1123 }
1124 
1125 DeviceHandle Service::SingleComputationDeviceHandle() const {
1126   DeviceHandle device_handle;
1127   device_handle.set_handle(0);
1128   device_handle.set_device_count(1);
1129   return device_handle;
1130 }
1131 
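// Returns the stream executors backing each replica of the given device
// handle, as chosen by the backend's computation placer.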
1132 StatusOr<std::vector<se::StreamExecutor*>> Service::Replicas(
1133     const Backend& backend, const DeviceHandle& device_handle) const {
1134   std::vector<se::StreamExecutor*> replicas;
1135   for (int replica = 0; replica < options_.number_of_replicas(); ++replica) {
1136     // From the computation placer, find out the device ids of the replicas for
1137     // the given device handle.
1138     TF_ASSIGN_OR_RETURN(
1139         int device_ordinal,
1140         backend.computation_placer()->DeviceId(replica, device_handle.handle(),
1141                                                options_.number_of_replicas(),
1142                                                device_handle.device_count()));
1143     TF_ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal));
1144     replicas.push_back(executor);
1145   }
1146   return replicas;
1147 }
1148 
1149 }  // namespace xla
1150