/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/service.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_util.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/service/source_map_util.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"

namespace xla {
namespace {

using absl::StrCat;
using absl::StrFormat;

// Argument used when calling DumpHloModuleIfEnabled before optimizations are
// performed on an HloModule.
constexpr char kBeforeOptimizationsDumpName[] = "before_optimizations";

// Records the arguments used to invoke a computation in an HloSnapshot proto.
Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments,
                       se::Stream* stream, TransferManager* transfer_manager,
                       HloSnapshot* module) {
  module->clear_arguments();
  for (const ShapedBuffer* argument : arguments) {
    TF_ASSIGN_OR_RETURN(
        Literal literal,
        transfer_manager->TransferLiteralFromDevice(stream, *argument));
    *module->add_arguments() = literal.ToProto();
  }
  return Status::OK();
}

// Records the result of a computation in an HloSnapshot proto.
Status RecordResult(const ShapedBuffer& result, se::Stream* stream,
                    TransferManager* transfer_manager, HloSnapshot* module) {
  module->clear_result();
  TF_ASSIGN_OR_RETURN(
      Literal literal,
      transfer_manager->TransferLiteralFromDevice(stream, result));
  *module->mutable_result() = literal.ToProto();
  return Status::OK();
}

}  // namespace

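// Setters on ServiceOptions return *this so that several options can be set
// in one chained expression.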
ServiceOptions& ServiceOptions::set_platform(se::Platform* platform) {
  platform_ = platform;
  return *this;
}

se::Platform* ServiceOptions::platform() const { return platform_; }

ServiceOptions& ServiceOptions::set_number_of_replicas(int number_of_replicas) {
  number_of_replicas_ = number_of_replicas;
  return *this;
}

int ServiceOptions::number_of_replicas() const { return number_of_replicas_; }

ServiceOptions& ServiceOptions::set_intra_op_parallelism_threads(
    int num_threads) {
  intra_op_parallelism_threads_ = num_threads;
  return *this;
}

int ServiceOptions::intra_op_parallelism_threads() const {
  return intra_op_parallelism_threads_;
}

ServiceOptions& ServiceOptions::set_allowed_devices(
    const absl::optional<std::set<int>>& allowed_devices) {
  allowed_devices_ = allowed_devices;
  return *this;
}

const absl::optional<std::set<int>>& ServiceOptions::allowed_devices() const {
  return allowed_devices_;
}

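// Creates a service for the given platform, or for the default platform when
// none is provided. A minimal usage sketch (assuming `platform` came from
// PlatformUtil::GetDefaultPlatform()):
//
//   ServiceOptions options;
//   options.set_platform(platform).set_number_of_replicas(1);
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<Service> service,
//                       Service::NewService(options));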
/* static */ StatusOr<std::unique_ptr<Service>> Service::NewService(
    se::Platform* platform) {
  ServiceOptions default_options;
  default_options.set_platform(platform);
  return NewService(default_options);
}

/* static */ StatusOr<std::unique_ptr<Service>> Service::NewService(
    const ServiceOptions& options) {
  se::Platform* platform = options.platform();
  std::unique_ptr<Backend> execute_backend;
  if (platform == nullptr) {
    TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
  }
  BackendOptions backend_options;
  backend_options.set_platform(platform);
  backend_options.set_allowed_devices(options.allowed_devices());
  TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options));

  std::unique_ptr<Service> service(
      new Service(options, std::move(execute_backend)));
  return std::move(service);
}

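// Constructs the service. When an execution backend is present, checks that
// there are at least as many devices as requested replicas and logs the
// available StreamExecutor devices; otherwise the service is compile-only.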
Service::Service(const ServiceOptions& options,
                 std::unique_ptr<Backend> execute_backend)
    : options_(options),
      allocation_tracker_(execute_backend.get()),
      execute_backend_(std::move(execute_backend)) {
  CHECK_GT(options_.number_of_replicas(), 0);
  if (execute_backend_) {
    if (execute_backend_->device_count() > 0) {
      CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas())
          << "Requested more replicas than there are devices.";
    }
    LOG(INFO) << StrFormat(
        "XLA service %p initialized for platform %s (this does not guarantee "
        "that XLA will be used). Devices:",
        this, execute_backend_->platform()->Name());
    auto stream_executors = execute_backend_->stream_executors();
    for (int i = 0; i < execute_backend_->device_count(); ++i) {
      se::StreamExecutor* executor = stream_executors.at(i);
      const auto& description = executor->GetDeviceDescription();
      LOG(INFO) << StrFormat("  StreamExecutor device (%d): %s, %s", i,
                             description.name(),
                             description.platform_version());
    }
  } else {
    VLOG(1) << "XLA compile-only service constructed";
  }
}

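// Creates a new channel handle of the requested type through the channel
// tracker.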
Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg,
                                    CreateChannelHandleResponse* result) {
  TF_ASSIGN_OR_RETURN(*result->mutable_channel(),
                      channel_tracker_.NewChannel(arg->channel_type()));
  return Status::OK();
}

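// Unregisters every global data handle in the request. All handles are
// attempted even on failure; the first error encountered is returned.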
Status Service::Unregister(const UnregisterRequest* arg,
                           UnregisterResponse* result) {
  Status status;
  for (auto& data : arg->data()) {
    Status unregister_status = allocation_tracker_.Unregister(data);
    if (!unregister_status.ok() && status.ok()) {
      status = unregister_status;
    }
  }
  return status;
}

// Deconstructs a previously-allocated global handle.
Status Service::DeconstructTuple(const DeconstructTupleRequest* arg,
                                 DeconstructTupleResponse* result) {
  TF_ASSIGN_OR_RETURN(
      std::vector<GlobalDataHandle> elements,
      allocation_tracker_.DeconstructTuple(arg->tuple_handle()));

  for (auto& element : elements) {
    *result->add_element_handles() = element;
  }
  return Status::OK();
}

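// Validates that the client-provided result shape (which may carry only
// partial layout information) is compatible with the shape the computation
// actually produces.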
Status Service::ValidateResultShape(const Shape& client_shape,
                                    const Shape& result_shape) const {
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(client_shape));
  if (!ShapeUtil::Compatible(client_shape, result_shape)) {
    return InvalidArgument(
        "Shape used to set computation result layout %s is not compatible "
        "with result shape %s",
        ShapeUtil::HumanStringWithLayout(client_shape),
        ShapeUtil::HumanString(result_shape));
  }
  return Status::OK();
}

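// Resolves each argument handle through the allocation tracker and groups
// the resulting buffers by replica; the returned vector is indexed as
// [replica][argument].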
StatusOr<std::vector<std::vector<const ShapedBuffer*>>>
Service::ResolveAndValidateArguments(
    absl::Span<const GlobalDataHandle* const> arguments,
    absl::Span<se::StreamExecutor* const> stream_executors) const {
  CHECK_EQ(options_.number_of_replicas(), stream_executors.size());
  std::vector<std::vector<const ShapedBuffer*>> replicated_arguments;
  replicated_arguments.resize(options_.number_of_replicas());
  for (size_t i = 0; i < arguments.size(); ++i) {
    auto buffer_status = allocation_tracker_.Resolve(*arguments[i]);
    if (!buffer_status.ok()) {
      return Status(buffer_status.status().code(),
                    StrCat(buffer_status.status().error_message(), ", ",
                           "failed to resolve allocation for parameter ", i));
    }
    auto replicated_buffers = buffer_status.ValueOrDie();
    CHECK_EQ(options_.number_of_replicas(), replicated_buffers.size());
    for (int replica = 0; replica < options_.number_of_replicas(); ++replica) {
      const ShapedBuffer* shaped_buffer = replicated_buffers[replica];
      replicated_arguments[replica].push_back(shaped_buffer);
    }
  }
  return replicated_arguments;
}

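// Builds an HloModuleConfig for the given program shape and argument shapes,
// taking the default replica count from the service options and the intra-op
// thread count from the backend's Eigen thread pool when one is available.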
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
    const ProgramShape& program_shape,
    absl::Span<const Shape* const> argument_shapes,
    const ExecutionOptions* execution_options,
    const AotCompilationOptions* aot_options) {
  int default_num_replicas = options_.number_of_replicas();
  absl::optional<int> num_threads;
  if (execute_backend_ != nullptr &&
      execute_backend_->eigen_intra_op_thread_pool() != nullptr) {
    num_threads = execute_backend_->eigen_intra_op_thread_pool()->NumThreads();
  }

  return xla::CreateModuleConfig(program_shape, argument_shapes,
                                 execution_options, default_num_replicas,
                                 num_threads, aot_options);
}

StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
    const ProgramShape& program_shape,
    absl::Span<const ShapedBuffer* const> arguments,
    const ExecutionOptions& execution_options,
    const AotCompilationOptions* aot_options) {
  std::vector<const Shape*> argument_shapes;
  for (const auto* arg : arguments) {
    argument_shapes.push_back(&arg->on_host_shape());
  }
  return CreateModuleConfig(program_shape, argument_shapes, &execution_options,
                            aot_options);
}

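// Compiles a group of HloModule protos into executables. When
// run_backend_only is true, HLO optimization passes are skipped and only the
// backend code-generation stage runs.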
StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
    const std::vector<const HloModuleProto*>& module_protos,
    std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
    Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
    const Compiler::CompileOptions& options, bool run_backend_only) {
  VLOG(1) << StrFormat("BuildExecutable on service %p", this);

  // Dump computation proto state if flag is set.
  std::vector<std::unique_ptr<HloProto>> hlo_protos;
  for (int64 i = 0, end = module_protos.size(); i < end; ++i) {
    auto hlo_proto = absl::make_unique<HloProto>();
    *hlo_proto->mutable_hlo_module() = *module_protos[i];
    hlo_protos.push_back(std::move(hlo_proto));
  }

  VLOG(1) << "Computations:";
  for (const HloModuleProto* proto : module_protos) {
    VLOG(1) << proto->name();
  }

  CHECK_EQ(module_protos.size(), module_configs.size());
  auto module_group =
      absl::make_unique<HloModuleGroup>(module_protos[0]->name());
  for (int64 i = 0, end = module_protos.size(); i < end; ++i) {
    const HloModuleProto* proto = module_protos[i];
    const HloModuleConfig& config = *module_configs[i];
    TF_ASSIGN_OR_RETURN(
        auto module, CreateModuleFromProto(*proto, config, run_backend_only));
    DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
    module_group->push_back(std::move(module));
  }

  std::vector<std::unique_ptr<Executable>> executables;
  if (!run_backend_only) {
    TF_ASSIGN_OR_RETURN(executables, backend->compiler()->Compile(
                                         std::move(module_group),
                                         std::move(executors), options));
  } else {
    auto modules = module_group->ConsumeModules();
    for (std::unique_ptr<HloModule>& module : modules) {
      TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                          backend->compiler()->RunBackend(
                              std::move(module), executors[0][0], options));
      executables.push_back(std::move(executable));
    }
  }

  for (size_t i = 0; i < module_protos.size(); ++i) {
    const auto& debug_opts = module_configs[i]->debug_options();
    if (DumpingEnabledForHloModule(module_protos[i]->name(), debug_opts) &&
        debug_opts.xla_dump_hlo_snapshots()) {
      executables[i]->set_hlo_proto(std::move(hlo_protos[i]));
    }
  }

  return std::move(executables);
}

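// Launches each executable asynchronously on the streams of its replicas,
// blocks until all streams complete, and registers the replicated results
// with the allocation tracker. When `profile` is non-null, overall execution
// time is measured with a timer on each computation's 0th replica.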
StatusOr<std::vector<GlobalDataHandle>>
Service::ExecuteParallelAndRegisterResult(
    absl::Span<Executable* const> executables,
    absl::Span<const std::vector<std::vector<const ShapedBuffer*>>> arguments,
    Backend* backend, absl::Span<const DeviceHandle> device_handles,
    absl::Span<const string> result_tags, ExecutionProfile* profile) {
  // Streams on which the computations are launched, so we can wait on the
  // streams to complete.
  std::vector<StreamPool::Ptr> streams;
  std::vector<std::unique_ptr<se::Timer>> timers;

  // Global data handles for the computation results, one for each computation.
  std::vector<GlobalDataHandle> result_handles;

  // Device ID to stream executor, populated only with devices that are being
  // profiled.
  std::map<int64, se::Stream*> index_to_profiled_streams;

  // Build DeviceAssignment for all cores based on the provided device handles.
  DeviceAssignment device_assignment(options_.number_of_replicas(),
                                     executables.size());
  for (int64 i = 0; i < executables.size(); i++) {
    TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
    CHECK_EQ(replicas.size(), arguments[i].size());
    for (int64 replica = 0, end = replicas.size(); replica < end; ++replica) {
      device_assignment(replica, i) = replicas[replica]->device_ordinal();
    }
  }

  for (int64 i = 0, end = executables.size(); i < end; i++) {
    // Stream executors for the replicas of the current computation.
    TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
    CHECK_EQ(replicas.size(), arguments[i].size());
    std::vector<ScopedShapedBuffer> result_buffers;
    for (int64 replica = 0; replica < replicas.size(); ++replica) {
      TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream,
                          backend->BorrowStream(replicas[replica]));
      streams.push_back(std::move(stream));

      if (replica == 0 && profile != nullptr) {
        timers.push_back(
            absl::make_unique<se::Timer>(streams.back()->parent()));
        streams.back()
            ->InitTimer(timers.back().get())
            .ThenStartTimer(timers.back().get());
        CHECK(timers.front() != nullptr);
      }

      if (replica == 0 &&
          executables[i]->module_config().debug_options().xla_hlo_profile() &&
          executables[i]->hlo_profiling_enabled()) {
        index_to_profiled_streams[i] = streams.back().get();
      }

      // Set up run options.
      ExecutableRunOptions options;
      options.set_stream(streams.back().get());
      options.set_allocator(backend->memory_allocator());
      options.set_intra_op_thread_pool(
          backend->eigen_intra_op_thread_pool_device());
      options.set_device_assignment(&device_assignment);
      // Use run-time profile information from execution_profile on the 0th
      // device.
      if (i == 0) {
        options.set_execution_profile(profile);
      }
      ServiceExecutableRunOptions run_options(options,
                                              backend->StreamBorrower());

      // Asynchronously launch the computation.
      TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                          executables[i]->ExecuteAsyncOnStream(
                              &run_options, arguments[i][replica],
                              /*hlo_execution_profile=*/nullptr));

      if (replica == 0 && profile != nullptr) {
        streams.back()->ThenStopTimer(timers.back().get());
      }

      result_buffers.push_back(std::move(result));
    }
    TF_ASSIGN_OR_RETURN(GlobalDataHandle handle,
                        allocation_tracker_.RegisterReplicatedBuffers(
                            std::move(result_buffers), result_tags[i]));
    result_handles.push_back(handle);
  }

  // Wait for all executions to complete.
  for (int64 i = 0, end = streams.size(); i < end; ++i) {
    Status block_status = streams[i]->BlockHostUntilDone();
    if (!block_status.ok()) {
      return InternalError("failed to complete execution for stream %d: %s", i,
                           block_status.error_message());
    }
  }

  if (profile != nullptr) {
    CHECK(!timers.empty());
    std::vector<uint64> timer_nanoseconds;
    timer_nanoseconds.reserve(timers.size());
    for (auto& timer : timers) {
      timer_nanoseconds.push_back(timer->Nanoseconds());
    }
    uint64 nanoseconds =
        *std::max_element(timer_nanoseconds.begin(), timer_nanoseconds.end());

    // Overall execution time (in nanoseconds) from the executor timer.
    profile->set_compute_and_transfer_time_ns(nanoseconds);

    // TODO(b/28123297): On GPU we end up including transfer time in
    // the compute time this way. Instead, we should get the correct
    // value by measuring it. Setting the field here at least lets
    // benchmarks provide *some* value for GPU computations.
    //
    // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually
    // the compute time without the transfer time, so this way we get the
    // correct compute time. We should instead have the correct value for
    // compute_and_transfer_time and set compute_time to the compute time.
    if (profile->compute_time_ns() == 0) {
      profile->set_compute_time_ns(profile->compute_and_transfer_time_ns());
    }
  }

  return result_handles;
}

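// Runs a single executable across the replicas of `device_handle` and
// registers the result. With one replica the executable runs via
// ExecuteOnStreamWrapper; with several, via ExecuteOnStreams with one set of
// run options per replica stream.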
StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
    Executable* executable,
    absl::Span<const std::vector<const ShapedBuffer*>> arguments,
    Backend* backend, const DeviceHandle& device_handle,
    const string& result_tag, ExecutionProfile* profile) {
  // Set up streams.
  std::vector<StreamPool::Ptr> streams;

  TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handle));
  TF_RET_CHECK(!replicas.empty());
  for (se::StreamExecutor* executor : replicas) {
    TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream,
                        backend->BorrowStream(executor));
    streams.push_back(std::move(stream));
  }

  DeviceAssignment device_assignment(options_.number_of_replicas(),
                                     /*computation_count=*/1);
  for (int64 replica = 0; replica < replicas.size(); ++replica) {
    device_assignment(replica, 0) = replicas[replica]->device_ordinal();
  }

  // Set up run options.
  std::vector<ServiceExecutableRunOptions> run_options;
  for (const StreamPool::Ptr& stream : streams) {
    ExecutableRunOptions options;
    options.set_stream(stream.get());
    options.set_device_ordinal(stream->parent()->device_ordinal());
    options.set_allocator(backend->memory_allocator());
    options.set_intra_op_thread_pool(
        backend->eigen_intra_op_thread_pool_device());
    options.set_device_assignment(&device_assignment);
    options.set_execution_profile(profile);
    run_options.emplace_back(options, backend->StreamBorrower());
  }

  if (options_.number_of_replicas() == 1) {
    TF_ASSIGN_OR_RETURN(auto result, executable->ExecuteOnStreamWrapper(
                                         &run_options[0], arguments[0]));
    return allocation_tracker_.Register(std::move(result), result_tag);
  }

  // TODO(b/69985541): Support profiling also on this path.

  std::vector<absl::Span<const ShapedBuffer* const>> replicated_arguments;
  for (const auto& arg : arguments) {
    replicated_arguments.push_back(arg);
  }

  TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams(
                                        run_options, replicated_arguments));
  TF_RET_CHECK(!results.empty());
  return allocation_tracker_.RegisterReplicatedBuffers(std::move(results),
                                                       result_tag);
}

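// Returns one stream executor (the replica-0 executor) per device handle in
// the execution options. Rejects requests that provide no device handles, or
// that combine multiple parallel requests with multiple device handles.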
StatusOr<std::vector<se::StreamExecutor*>> Service::GetExecutors(
    const ExecutionOptions& execution_options, int64 requests_size,
    int64 request_index) const {
  if (execution_options.device_handles().empty()) {
    return FailedPrecondition(
        "device handles must be given to execute parallel computations");
  }
  if (requests_size > 1 && execution_options.device_handles_size() > 1) {
    return InvalidArgument(
        "Parallel requests with multiple device handles is not supported. "
        "Found %d parallel requests, with request %d containing %d device "
        "handles.",
        requests_size, request_index, execution_options.device_handles_size());
  }
  std::vector<se::StreamExecutor*> executors;
  for (const auto& device_handle : execution_options.device_handles()) {
    TF_ASSIGN_OR_RETURN(auto replicas,
                        Replicas(*execute_backend_, device_handle));
    se::StreamExecutor* executor = replicas[0];
    CHECK(executor != nullptr);
    executors.push_back(executor);
  }
  return executors;
}

StatusOr<std::vector<std::vector<const ShapedBuffer*>>> Service::GetArguments(
    const ExecutionOptions& execution_options,
    absl::Span<const GlobalDataHandle* const> arguments) const {
  // Resolve the allocations for the arguments of the computation, and create
  // a vector of device memory offsets for the arguments from the allocations.
  // In the case of partitioned computations, assume all arguments go on the
  // zeroth core.
  TF_ASSIGN_OR_RETURN(
      auto replicas,
      Replicas(*execute_backend_, execution_options.device_handles(0)));
  TF_ASSIGN_OR_RETURN(
      std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
      ResolveAndValidateArguments(arguments, replicas));
  return replicated_arguments;
}

Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
                                     ExecuteParallelResponse* result) {
  VLOG(1) << "running execute-graph-parallel request";

  std::vector<std::vector<std::vector<const ShapedBuffer*>>> all_arguments;
  std::vector<std::vector<se::StreamExecutor*>> all_executors;
  std::vector<const HloModuleProto*> module_protos;
  std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
  std::vector<string> computation_names;
  std::vector<DeviceHandle> device_handles;

  int num_requested_devices =
      std::accumulate(arg->requests().begin(), arg->requests().end(), 0,
                      [](int a, const ExecuteGraphRequest& r) -> int {
                        return a + r.execution_options().device_handles_size();
                      });
  if (num_requested_devices * options_.number_of_replicas() >
      execute_backend_->device_count()) {
    return FailedPrecondition(
        "there are not enough stream executors to execute %d computations",
        num_requested_devices);
  }

  for (int64 i = 0; i < arg->requests_size(); ++i) {
    // Get the stream executor for the i'th computation. This stream executor
    // is one of the executors to run the replicated computation.
    const ExecutionOptions& execution_options =
        arg->requests(i).execution_options();
    const ExecuteGraphRequest& request = arg->requests(i);
    TF_RET_CHECK(request.has_computation()) << "computations may not be empty";
    TF_RET_CHECK(request.computation().has_host_program_shape())
        << "program shape may not be empty";

    // Get the executors.
    TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options,
                                                     arg->requests_size(), i));

    // Get the replicated arguments.
    TF_ASSIGN_OR_RETURN(auto replicated_arguments,
                        GetArguments(execution_options, request.arguments()));

    // Create an HloModuleConfig object for the computation, given the shape of
    // the program and the argument allocations. Here, we care only about the
    // shapes of the arguments, so it is sufficient to use the arguments of
    // replica 0.
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloModuleConfig> module_config,
        CreateModuleConfig(
            ProgramShape{request.computation().host_program_shape()},
            replicated_arguments.front(), request.execution_options()));
    VLOG(3)
        << "ExecuteGraphParallel created HloModuleConfig computation layout: "
        << module_config->entry_computation_layout().ToString();

    // Add to the vectors used to build and execute the computations after the
    // loop.
    all_arguments.push_back(replicated_arguments);
    all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}});
    module_protos.push_back(&request.computation());
    module_configs.push_back(std::move(module_config));
    computation_names.insert(computation_names.end(), executors.size(),
                             request.computation().name());
    all_executors.push_back(executors);
    device_handles.insert(device_handles.end(),
                          execution_options.device_handles().begin(),
                          execution_options.device_handles().end());
  }

  // Build the HloModules and compile to generate the executables.
  //
  // TODO(jlebar): There's currently no way to pass a device allocator to
  // ExecuteGraphParallel, so we have to pass a null device_allocator below.
  TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
                      BuildExecutables(module_protos, std::move(module_configs),
                                       execute_backend_.get(), all_executors,
                                       {/*device_allocator=*/nullptr}));
  std::vector<Executable*> executable_ptrs;
  executable_ptrs.reserve(executables.size());
  for (const auto& executable : executables) {
    executable_ptrs.push_back(executable.get());
  }

  std::vector<HloSnapshot> snapshots;
  snapshots.resize(executable_ptrs.size());
  for (int i = 0, end = executable_ptrs.size(); i < end; i++) {
    if (executable_ptrs[i]->dumping_snapshot()) {
      *snapshots[i].mutable_hlo() = *executable_ptrs[i]->hlo_proto();
      TF_ASSIGN_OR_RETURN(auto stream,
                          execute_backend_->BorrowStream(
                              all_executors[i][0]->device_ordinal()));
      TF_RETURN_IF_ERROR(RecordArguments(all_arguments[i].front(), stream.get(),
                                         execute_backend_->transfer_manager(),
                                         &snapshots[i]));
    }
  }

  // If we have multiple executables to run, execute them all in parallel.  But
  // if we only have one executable, execute it using the vanilla, non-parallel
  // call.
  //
  // We do this because the Client API uses ExecuteGraphParallel when it wants
  // to compile and run one computation without caching the executable, but not
  // all backends support the async StreamExecutor API required by
  // ExecuteParallelAndRegisterResult.
  //
  // TODO(b/122731460): Consolidate Execute{,Parallel}AndRegisterResult; they do
  // basically the same thing.
  ExecutionProfile profile;
  std::vector<GlobalDataHandle> outputs;
  Status execution_status = Status::OK();

  if (executable_ptrs.size() == 1) {
    StatusOr<GlobalDataHandle> output_or_status = ExecuteAndRegisterResult(
        executable_ptrs[0], all_arguments[0], execute_backend_.get(),
        device_handles[0], computation_names[0], &profile);
    if (output_or_status.ok()) {
      outputs.push_back(std::move(output_or_status).ValueOrDie());
    } else {
      execution_status = output_or_status.status();
    }
  } else {
    StatusOr<std::vector<GlobalDataHandle>> outputs_or_status =
        ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments,
                                         execute_backend_.get(), device_handles,
                                         computation_names, &profile);
    if (outputs_or_status.ok()) {
      outputs = std::move(outputs_or_status).ValueOrDie();
    } else {
      execution_status = outputs_or_status.status();
    }
  }

  if (!execution_status.ok()) {
    // Execution failed so we don't have the results.  Dump the HLO snapshot
    // with just the program arguments.
    for (int i = 0, end = executable_ptrs.size(); i < end; i++) {
      DumpHloSnapshotIfEnabled(executable_ptrs[i]->module(), snapshots[i]);
    }
  }

  TF_RETURN_IF_ERROR(execution_status);

  for (const GlobalDataHandle& output : outputs) {
    ExecuteResponse response;
    *response.mutable_output() = output;
    *response.mutable_profile() = profile;
    *result->add_responses() = response;
  }

  for (int i = 0, end = executable_ptrs.size(); i < end; i++) {
    Executable* executable = executable_ptrs[i];
    if (executable->dumping_snapshot()) {
      TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer,
                          allocation_tracker_.ResolveForReplica(outputs[i], 0));
      TF_ASSIGN_OR_RETURN(auto stream,
                          execute_backend_->BorrowStream(all_executors[i][0]));
      TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
                                      execute_backend_->transfer_manager(),
                                      &snapshots[i]));
      DumpHloSnapshotIfEnabled(executable->module(), snapshots[i]);
    }
  }

  VLOG(1) << "successfully completed 'execute-graph-parallel' request";
  return Status::OK();
}

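// Hands out arg->device_count() logical device handles after checking that
// enough physical devices exist for that many computations times the replica
// count.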
Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
                                 GetDeviceHandlesResponse* result) {
  const int64 available_device_count = execute_backend_->device_count();
  const int64 replica_count = options_.number_of_replicas();
  if (replica_count <= 0) {
    return FailedPrecondition("Replica count must be a positive integer");
  }
  if (available_device_count < arg->device_count() * replica_count) {
    return ResourceExhausted(
        "Requested logical device count (%d) with replica count (%d) exceeds "
        "the number of available physical devices on the target (%d)",
        arg->device_count(), replica_count, available_device_count);
  }

  for (int64 i = 0; i < arg->device_count(); ++i) {
    DeviceHandle device_handle;
    device_handle.set_handle(i);
    device_handle.set_device_count(arg->device_count());
    *result->add_device_handles() = device_handle;
  }

  return Status::OK();
}

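// Compiles a single HloModule proto into an Executable, running the HLO
// passes first unless run_backend_only is set, and attaching the HLO proto
// to the executable when snapshot dumping is enabled.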
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
    const HloModuleProto& module_proto,
    std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
    se::StreamExecutor* executor, const Compiler::CompileOptions& options,
    bool run_backend_only) {
  VLOG(1) << StrFormat(
      "BuildExecutable on service %p with serialized module proto: %s", this,
      module_proto.name());

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModule> module,
      CreateModuleFromProto(module_proto, *module_config, run_backend_only));
  DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);

  if (!run_backend_only) {
    TF_ASSIGN_OR_RETURN(module, backend->compiler()->RunHloPasses(
                                    std::move(module), executor, options));
  }

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<Executable> executable,
      backend->compiler()->RunBackend(std::move(module), executor, options));

  const auto& debug_opts = module_config->debug_options();
  if (DumpingEnabledForHloModule(module_proto.name(), debug_opts) &&
      debug_opts.xla_dump_hlo_snapshots()) {
    auto hlo_proto = absl::make_unique<HloProto>();
    *hlo_proto->mutable_hlo_module() = module_proto;
    executable->set_hlo_proto(std::move(hlo_proto));
  }

  return std::move(executable);
}

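// Handles a CompileRequest: builds a module config from the program shape and
// the argument layouts, compiles an executable on the default stream
// executor, and inserts it into the compilation cache, returning the cache
// handle to the client.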
Status Service::Compile(const CompileRequest* arg, CompileResponse* result) {
  VLOG(1) << "running compile request";
  if (!arg->has_computation()) {
    return InvalidArgument("computations may not be empty");
  }
  if (!arg->computation().has_host_program_shape()) {
    return InvalidArgument("program shape may not be empty");
  }

  if (arg->execution_options().device_handles_size() > 1) {
    return InvalidArgument(
        "The compile request does not support multiple device handles.");
  }

  std::vector<Shape> argument_shapes;
  argument_shapes.reserve(arg->input_shape_with_layout_size());
  std::vector<const Shape*> argument_shape_ptrs;
  for (const ShapeProto& shape_proto : arg->input_shape_with_layout()) {
    argument_shapes.push_back(Shape(shape_proto));
    argument_shape_ptrs.push_back(&argument_shapes.back());
  }
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModuleConfig> module_config,
      CreateModuleConfig(ProgramShape{arg->computation().host_program_shape()},
                         argument_shape_ptrs, &arg->execution_options()));
  VLOG(3) << "Compile created HloModuleConfig computation layout: "
          << module_config->entry_computation_layout().ToString();

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<Executable> executable,
      BuildExecutable(arg->computation(), std::move(module_config),
                      execute_backend_.get(),
                      execute_backend_->default_stream_executor(),
                      {/*device_allocator=*/nullptr}));

  *result->mutable_handle() = compilation_cache_.Insert(std::move(executable));

  VLOG(1) << "successfully completed 'compile' request";
  return Status::OK();
}

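// Handles an ExecuteRequest: looks up the executable in the compilation
// cache, checks the argument count and shapes against the module's entry
// computation layout, runs it, and optionally records an HLO snapshot of the
// arguments and result.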
Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) {
  VLOG(1) << "running execute request";
  if (!arg->has_handle()) {
    return InvalidArgument("execution handle should not be empty");
  }
  TF_ASSIGN_OR_RETURN(auto executable,
                      compilation_cache_.LookUp(arg->handle()));

  TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_,
                                              SingleComputationDeviceHandle()));
  TF_ASSIGN_OR_RETURN(
      std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
      ResolveAndValidateArguments(arg->arguments(), replicas));

  // Check that replicated_arguments has the same shapes and layouts as the
  // module config used when creating the executable.
  const int64 num_module_args =
      executable->module_config().entry_computation_layout().parameter_count();
  if (num_module_args != arg->arguments_size()) {
    return InvalidArgument(
        "The executable expects %lld arguments, but sees %lld.",
        num_module_args, arg->arguments_size());
  }
  for (int64 i = 0; i < num_module_args; i++) {
    const Shape& shape_module =
        executable->module_config().entry_computation_layout().parameter_shape(
            i);
    const Shape& shape_arg = replicated_arguments.front()[i]->on_host_shape();
    if (!ShapeUtil::Equal(shape_module, shape_arg)) {
      return InvalidArgumentStrCat(
          "The executable expects the ", i, "th argument in shape ",
          ShapeUtil::HumanStringWithLayout(shape_module), " but sees ",
          ShapeUtil::HumanStringWithLayout(shape_arg));
    }
  }

  TF_ASSIGN_OR_RETURN(auto stream,
                      execute_backend_->BorrowStream(
                          execute_backend_->default_stream_executor()));
  HloSnapshot snapshot;
  if (executable->dumping_snapshot()) {
    *snapshot.mutable_hlo() = *executable->hlo_proto();
    snapshot.set_execution_platform(execute_backend_->platform()->Name());
    TF_RETURN_IF_ERROR(
        RecordArguments(replicated_arguments.front(), stream.get(),
                        execute_backend_->transfer_manager(), &snapshot));
  }

  TF_ASSIGN_OR_RETURN(
      *result->mutable_output(),
      ExecuteAndRegisterResult(executable.get(), replicated_arguments,
                               execute_backend_.get(),
                               SingleComputationDeviceHandle(),
                               "result of " + executable->module().name(),
                               result->mutable_profile()));

  if (executable->dumping_snapshot()) {
    TF_ASSIGN_OR_RETURN(
        const ShapedBuffer* result_buffer,
        allocation_tracker_.ResolveForReplica(result->output(), 0));
    TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
                                    execute_backend_->transfer_manager(),
                                    &snapshot));
    DumpHloSnapshotIfEnabled(executable->module(), snapshot);
  }

  VLOG(1) << "successfully completed 'execute' request";
  return Status::OK();
}

Status Service::WaitForExecution(const WaitForExecutionRequest* arg,
                                 WaitForExecutionResponse* result) {
  TF_ASSIGN_OR_RETURN(const auto execution,
                      execution_tracker_.Resolve(arg->execution()));

  TF_RETURN_IF_ERROR(execution->BlockUntilDone());

  *result->mutable_output() = execution->result();
  *result->mutable_profile() = execution->profile();

  TF_RETURN_IF_ERROR(execution_tracker_.Unregister(arg->execution()));
  VLOG(1) << "successfully completed 'wait-for-execution' request";
  return Status::OK();
}

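// Transfers the literal backing a global data handle from device to the
// client, re-laying it out to the requested layout if that differs from the
// on-device layout.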
Status Service::TransferToClient(const TransferToClientRequest* arg,
                                 TransferToClientResponse* result) {
  TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer,
                      allocation_tracker_.ResolveForReplica(arg->data(), 0));

  Shape return_shape;
  if (arg->has_shape_with_layout()) {
    return_shape = Shape(arg->shape_with_layout());
    if (!LayoutUtil::HasLayout(return_shape)) {
      return InvalidArgument("shape_with_layout must have layout if present.");
    }
  } else {
    return_shape = Shape(shaped_buffer->on_host_shape());
  }

  TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(
                                       shaped_buffer->device_ordinal()));

  TF_ASSIGN_OR_RETURN(
      Literal result_literal,
      execute_backend_->transfer_manager()->TransferLiteralFromDevice(
          stream.get(), *shaped_buffer));

  if (LayoutUtil::LayoutsInShapesEqual(return_shape, result_literal.shape())) {
    *result->mutable_literal() = result_literal.ToProto();
  } else {
    *result->mutable_literal() =
        result_literal.Relayout(return_shape).ToProto();
  }
  return Status::OK();
}

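// Allocates a device buffer on every replica, transfers the client-provided
// literal into each of them, and registers the replicated buffers under a
// single global data handle.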
Status Service::TransferToServer(const TransferToServerRequest* arg,
                                 TransferToServerResponse* result) {
  TF_ASSIGN_OR_RETURN(Literal literal,
                      Literal::CreateFromProto(arg->literal()));
  const Shape& shape = literal.shape();

  std::vector<se::StreamExecutor*> replicas;
  if (arg->has_device_handle()) {
    TF_ASSIGN_OR_RETURN(replicas,
                        Replicas(*execute_backend_, arg->device_handle()));
  } else {
    TF_ASSIGN_OR_RETURN(
        replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle()));
  }

  // Allocate memory in each replica and transfer the data to all replicas.
  std::vector<ScopedShapedBuffer> replicated_buffers;
  for (se::StreamExecutor* executor : replicas) {
    TF_ASSIGN_OR_RETURN(
        ScopedShapedBuffer shaped_buffer,
        execute_backend_->transfer_manager()->AllocateScopedShapedBuffer(
            shape, execute_backend_->memory_allocator(),
            executor->device_ordinal()));
    TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor));
    TF_RETURN_IF_ERROR(
        execute_backend_->transfer_manager()->TransferLiteralToDevice(
            stream.get(), literal, shaped_buffer));
    replicated_buffers.emplace_back(std::move(shaped_buffer));
  }
  TF_ASSIGN_OR_RETURN(*result->mutable_data(),
                      allocation_tracker_.RegisterReplicatedBuffers(
                          std::move(replicated_buffers),
                          StrCat("TransferToServer literal of shape ",
                                 ShapeUtil::HumanString(shape))));

  return Status::OK();
}

Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
                                 TransferToInfeedResponse* result) {
  const int64 replica_count = options_.number_of_replicas();
  if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
    return FailedPrecondition(
        "%s",
        StrCat("The replica_id=", arg->replica_id(),
               " on TransferToInfeedRequest not in range [0, replica_count=",
               replica_count, ")."));
  }

  se::StreamExecutor* executor;
  if (arg->has_device_handle()) {
    TF_ASSIGN_OR_RETURN(auto replicas,
                        Replicas(*execute_backend_, arg->device_handle()));
    executor = replicas[arg->replica_id()];
  } else {
    TF_ASSIGN_OR_RETURN(
        auto replicas,
        Replicas(*execute_backend_, SingleComputationDeviceHandle()));
    executor = replicas[arg->replica_id()];
  }

  TF_ASSIGN_OR_RETURN(Literal literal,
                      Literal::CreateFromProto(arg->literal()));
  return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor,
                                                                       literal);
}

Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
                                    TransferFromOutfeedResponse* result) {
  const int64 replica_count = options_.number_of_replicas();
  if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
    return FailedPrecondition(
        "The replica_id=%d on TransferFromOutfeedRequest not in range [0, %d)",
        arg->replica_id(), replica_count);
  }

  se::StreamExecutor* executor;
  if (arg->has_device_handle()) {
    TF_ASSIGN_OR_RETURN(auto replicas,
                        Replicas(*execute_backend_, arg->device_handle()));
    executor = replicas[arg->replica_id()];
  } else {
    TF_ASSIGN_OR_RETURN(
        auto replicas,
        Replicas(*execute_backend_, SingleComputationDeviceHandle()));
    executor = replicas[arg->replica_id()];
  }

  auto literal = Literal::CreateFromShape(Shape(arg->shape_with_layout()));

  TF_RETURN_IF_ERROR(
      execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
          executor, &literal));
  *result->mutable_literal() = literal.ToProto();
  return Status::OK();
}

Status Service::ResetDevice(const ResetDeviceRequest* arg,
                            ResetDeviceResponse* result) {
  return execute_backend_->ResetDevices();
}

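// Evaluates a parameter-free computation on the host with the HloEvaluator
// and returns the resulting literal, re-laid-out when an output layout was
// requested.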
Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
                                     ComputeConstantResponse* result) {
  if (!arg->has_computation()) {
    return InvalidArgument("computations may not be empty");
  }
  if (!arg->computation().has_host_program_shape()) {
    return InvalidArgument("program shape may not be empty");
  }
  if (arg->computation().host_program_shape().parameters_size() != 0) {
    return InvalidArgument(
        "constant computation may not depend on any parameters.");
  }

  ProgramShape program_shape(arg->computation().host_program_shape());
  TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result()));
  absl::optional<Layout> output_layout;
  if (arg->has_output_layout()) {
    output_layout = Layout::CreateFromProto(arg->output_layout());
    TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(
        *output_layout, program_shape.result()));
  }

  HloModuleConfig config(program_shape);

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                      CreateModuleFromProto(arg->computation(), config));

  TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference,
                      DynamicDimensionInference::Run(module.get()));

  HloEvaluator evaluator;
  evaluator.set_dynamic_dimension_inference(&dynamic_dimension_inference);
  TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate(*module, {}));

  // The requested output layout has no effect on the Evaluator's result, so
  // relayout explicitly here.
  //
  // TODO(b/77824332): Make HloEvaluator take care of the re-layout.
  if (output_layout.has_value()) {
    result_literal = result_literal.Relayout(*output_layout);
  }
  *result->mutable_literal() = result_literal.ToProto();

  return Status::OK();
}

Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) {
  TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer,
                      allocation_tracker_.ResolveForReplica(arg->data(), 0));
  *result->mutable_shape() = buffer->on_host_shape().ToProto();
  return Status::OK();
}

Status Service::GetComputationGraphStats(
    const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) {
  if (!arg->has_computation()) {
    return InvalidArgument("Computations may not be empty.");
  }
  if (!arg->computation().has_host_program_shape()) {
    return InvalidArgument("Program shape may not be empty.");
  }

  HloModuleConfig config(ProgramShape{arg->computation().host_program_shape()});
  config.set_debug_options(arg->debug_options());
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                      CreateModuleFromProto(arg->computation(), config));
  DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);

  // Run HLO analysis to get the computation statistics.
  HloCostAnalysis analysis(
      execute_backend_->compiler()->ShapeSizeBytesFunction());

  TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis));

  ComputationStats stats;
  stats.set_flop_count(analysis.flop_count());
  stats.set_transcendental_count(analysis.transcendental_count());
  *result->mutable_stats() = stats;
  return Status::OK();
}

DeviceHandle Service::SingleComputationDeviceHandle() const {
  DeviceHandle device_handle;
  device_handle.set_handle(0);
  device_handle.set_device_count(1);
  return device_handle;
}

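// Returns the stream executors backing each replica of the given device
// handle, as assigned by the backend's computation placer.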
StatusOr<std::vector<se::StreamExecutor*>> Service::Replicas(
    const Backend& backend, const DeviceHandle& device_handle) const {
  std::vector<se::StreamExecutor*> replicas;
  for (int replica = 0; replica < options_.number_of_replicas(); ++replica) {
    // From the computation placer, find out the device ids of the replicas for
    // the given device handle.
    TF_ASSIGN_OR_RETURN(
        int device_ordinal,
        backend.computation_placer()->DeviceId(replica, device_handle.handle(),
                                               options_.number_of_replicas(),
                                               device_handle.device_count()));
    TF_ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal));
    replicas.push_back(executor);
  }
  return replicas;
}

}  // namespace xla