/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/service.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_util.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/service/source_map_util.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"

namespace xla {
namespace {

using absl::StrCat;
using absl::StrFormat;

// Argument used when calling DumpHloModuleIfEnabled before optimizations are
// performed on an HloModule.
constexpr char kBeforeOptimizationsDumpName[] = "before_optimizations";

// Records the arguments used to invoke a computation in an HloSnapshot proto.
Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments,
                       se::Stream* stream, TransferManager* transfer_manager,
                       HloSnapshot* module) {
  module->clear_arguments();
  for (const ShapedBuffer* argument : arguments) {
    TF_ASSIGN_OR_RETURN(
        Literal literal,
        transfer_manager->TransferLiteralFromDevice(stream, *argument));
    *module->add_arguments() = literal.ToProto();
  }
  return Status::OK();
}

// Records the result of a computation in an HloSnapshot proto.
Status RecordResult(const ShapedBuffer& result, se::Stream* stream,
                    TransferManager* transfer_manager, HloSnapshot* module) {
  module->clear_result();
  TF_ASSIGN_OR_RETURN(
      Literal literal,
      transfer_manager->TransferLiteralFromDevice(stream, result));
  *module->mutable_result() = literal.ToProto();
  return Status::OK();
}

}  // namespace

ServiceOptions& ServiceOptions::set_platform(se::Platform* platform) {
  platform_ = platform;
  return *this;
}

se::Platform* ServiceOptions::platform() const { return platform_; }

ServiceOptions& ServiceOptions::set_number_of_replicas(int number_of_replicas) {
  number_of_replicas_ = number_of_replicas;
  return *this;
}

int ServiceOptions::number_of_replicas() const { return number_of_replicas_; }

ServiceOptions& ServiceOptions::set_intra_op_parallelism_threads(
    int num_threads) {
  intra_op_parallelism_threads_ = num_threads;
  return *this;
}

int ServiceOptions::intra_op_parallelism_threads() const {
  return intra_op_parallelism_threads_;
}

ServiceOptions& ServiceOptions::set_allowed_devices(
    const absl::optional<std::set<int>>& allowed_devices) {
  allowed_devices_ = allowed_devices;
  return *this;
}

const absl::optional<std::set<int>>& ServiceOptions::allowed_devices() const {
  return allowed_devices_;
}

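// Creates a new XLA service for the given platform using default
// ServiceOptions. A null `platform` selects the default platform.
//
// Minimal usage sketch (caller code, assuming a Status-returning context):
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<Service> service,
//                       Service::NewService(/*platform=*/nullptr));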
/* static */ StatusOr<std::unique_ptr<Service>> Service::NewService(
    se::Platform* platform) {
  ServiceOptions default_options;
  default_options.set_platform(platform);
  return NewService(default_options);
}

/* static */ StatusOr<std::unique_ptr<Service>> Service::NewService(
    const ServiceOptions& options) {
  se::Platform* platform = options.platform();
  std::unique_ptr<Backend> execute_backend;
  if (platform == nullptr) {
    TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
  }
  BackendOptions backend_options;
  backend_options.set_platform(platform);
  backend_options.set_allowed_devices(options.allowed_devices());
  TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options));

  std::unique_ptr<Service> service(
      new Service(options, std::move(execute_backend)));
  return std::move(service);
}

Service::Service(const ServiceOptions& options,
                 std::unique_ptr<Backend> execute_backend)
    : options_(options),
      allocation_tracker_(execute_backend.get()),
      execute_backend_(std::move(execute_backend)) {
  CHECK_GT(options_.number_of_replicas(), 0);
  if (execute_backend_) {
    if (execute_backend_->device_count() > 0) {
      CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas())
          << "Requested more replicas than there are devices.";
    }
    LOG(INFO) << StrFormat(
        "XLA service %p initialized for platform %s (this does not guarantee "
        "that XLA will be used). Devices:",
        this, execute_backend_->platform()->Name());
    auto stream_executors = execute_backend_->stream_executors();
    for (int i = 0; i < execute_backend_->device_count(); ++i) {
      se::StreamExecutor* executor = stream_executors.at(i);
      const auto& description = executor->GetDeviceDescription();
      LOG(INFO) << StrFormat("  StreamExecutor device (%d): %s, %s", i,
                             description.name(),
                             description.platform_version());
    }
  } else {
    VLOG(1) << "XLA compile-only service constructed";
  }
}

Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg,
                                    CreateChannelHandleResponse* result) {
  TF_ASSIGN_OR_RETURN(*result->mutable_channel(),
                      channel_tracker_.NewChannel(arg->channel_type()));
  return Status::OK();
}

Status Service::Unregister(const UnregisterRequest* arg,
                           UnregisterResponse* result) {
  Status status;
  for (auto& data : arg->data()) {
    Status unregister_status = allocation_tracker_.Unregister(data);
    if (!unregister_status.ok() && status.ok()) {
      status = unregister_status;
    }
  }
  return status;
}

// Deconstructs a previously-allocated global handle.
Status Service::DeconstructTuple(const DeconstructTupleRequest* arg,
                                 DeconstructTupleResponse* result) {
  TF_ASSIGN_OR_RETURN(
      std::vector<GlobalDataHandle> elements,
      allocation_tracker_.DeconstructTuple(arg->tuple_handle()));

  for (auto& element : elements) {
    *result->add_element_handles() = element;
  }
  return Status::OK();
}

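// Checks that the client-supplied result shape (which may carry only a
// partial layout) is compatible with the shape actually produced by the
// computation.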
Status Service::ValidateResultShape(const Shape& client_shape,
                                    const Shape& result_shape) const {
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(client_shape));
  if (!ShapeUtil::Compatible(client_shape, result_shape)) {
    return InvalidArgument(
        "Shape used to set computation result layout %s is not compatible "
        "with result shape %s",
        ShapeUtil::HumanStringWithLayout(client_shape),
        ShapeUtil::HumanString(result_shape));
  }
  return Status::OK();
}

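// Resolves each argument handle through the allocation tracker into the
// per-replica ShapedBuffers it refers to. The returned vector is indexed by
// replica first and argument second.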
StatusOr<std::vector<std::vector<const ShapedBuffer*>>>
Service::ResolveAndValidateArguments(
    absl::Span<const GlobalDataHandle* const> arguments,
    absl::Span<se::StreamExecutor* const> stream_executors) const {
  CHECK_EQ(options_.number_of_replicas(), stream_executors.size());
  std::vector<std::vector<const ShapedBuffer*>> replicated_arguments;
  replicated_arguments.resize(options_.number_of_replicas());
  for (size_t i = 0; i < arguments.size(); ++i) {
    auto buffer_status = allocation_tracker_.Resolve(*arguments[i]);
    if (!buffer_status.ok()) {
      return Status(buffer_status.status().code(),
                    StrCat(buffer_status.status().error_message(), ", ",
                           "failed to resolve allocation for parameter ", i));
    }
    auto replicated_buffers = buffer_status.ValueOrDie();
    CHECK_EQ(options_.number_of_replicas(), replicated_buffers.size());
    for (int replica = 0; replica < options_.number_of_replicas(); ++replica) {
      const ShapedBuffer* shaped_buffer = replicated_buffers[replica];
      replicated_arguments[replica].push_back(shaped_buffer);
    }
  }
  return replicated_arguments;
}

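// Builds an HloModuleConfig for the given program shape and argument shapes,
// using this service's replica count as the default and, when an execute
// backend is present, its Eigen intra-op thread count.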
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
    const ProgramShape& program_shape,
    absl::Span<const Shape* const> argument_shapes,
    const ExecutionOptions* execution_options,
    const AotCompilationOptions* aot_options) {
  int default_num_replicas = options_.number_of_replicas();
  absl::optional<int> num_threads;
  if (execute_backend_ != nullptr &&
      execute_backend_->eigen_intra_op_thread_pool() != nullptr) {
    num_threads = execute_backend_->eigen_intra_op_thread_pool()->NumThreads();
  }

  return xla::CreateModuleConfig(program_shape, argument_shapes,
                                 execution_options, default_num_replicas,
                                 num_threads, aot_options);
}

StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
    const ProgramShape& program_shape,
    absl::Span<const ShapedBuffer* const> arguments,
    const ExecutionOptions& execution_options,
    const AotCompilationOptions* aot_options) {
  std::vector<const Shape*> argument_shapes;
  for (const auto* arg : arguments) {
    argument_shapes.push_back(&arg->on_host_shape());
  }
  return CreateModuleConfig(program_shape, argument_shapes, &execution_options,
                            aot_options);
}

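// Compiles a group of HLO module protos into executables on the given
// backend. When run_backend_only is true, the HLO optimization passes are
// skipped and only the backend (code generation) step runs.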
StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
    const std::vector<const HloModuleProto*>& module_protos,
    std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
    Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
    const Compiler::CompileOptions& options, bool run_backend_only) {
  VLOG(1) << StrFormat("BuildExecutable on service %p", this);

  // Dump computation proto state if flag is set.
  std::vector<std::unique_ptr<HloProto>> hlo_protos;
  for (int64 i = 0, end = module_protos.size(); i < end; ++i) {
    auto hlo_proto = absl::make_unique<HloProto>();
    *hlo_proto->mutable_hlo_module() = *module_protos[i];
    hlo_protos.push_back(std::move(hlo_proto));
  }

  VLOG(1) << "Computations:";
  for (const HloModuleProto* proto : module_protos) {
    VLOG(1) << proto->name();
  }

  CHECK_EQ(module_protos.size(), module_configs.size());
  auto module_group =
      absl::make_unique<HloModuleGroup>(module_protos[0]->name());
  for (int64 i = 0, end = module_protos.size(); i < end; ++i) {
    const HloModuleProto* proto = module_protos[i];
    const HloModuleConfig& config = *module_configs[i];
    TF_ASSIGN_OR_RETURN(
        auto module, CreateModuleFromProto(*proto, config, run_backend_only));
    DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
    module_group->push_back(std::move(module));
  }

  std::vector<std::unique_ptr<Executable>> executables;
  if (!run_backend_only) {
    TF_ASSIGN_OR_RETURN(executables, backend->compiler()->Compile(
                                         std::move(module_group),
                                         std::move(executors), options));
  } else {
    auto modules = module_group->ConsumeModules();
    for (std::unique_ptr<HloModule>& module : modules) {
      TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                          backend->compiler()->RunBackend(
                              std::move(module), executors[0][0], options));
      executables.push_back(std::move(executable));
    }
  }

  for (size_t i = 0; i < module_protos.size(); ++i) {
    const auto& debug_opts = module_configs[i]->debug_options();
    if (DumpingEnabledForHloModule(module_protos[i]->name(), debug_opts) &&
        debug_opts.xla_dump_hlo_snapshots()) {
      executables[i]->set_hlo_proto(std::move(hlo_protos[i]));
    }
  }

  return std::move(executables);
}

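// Launches each executable asynchronously on its assigned replicas, blocks
// until every stream has finished, and registers the per-computation results
// with the allocation tracker. When `profile` is non-null, execution on the
// 0th device is timed and the measured wall time is recorded in the profile.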
StatusOr<std::vector<GlobalDataHandle>>
Service::ExecuteParallelAndRegisterResult(
    absl::Span<Executable* const> executables,
    absl::Span<const std::vector<std::vector<const ShapedBuffer*>>> arguments,
    Backend* backend, absl::Span<const DeviceHandle> device_handles,
    absl::Span<const string> result_tags, ExecutionProfile* profile) {
  // Streams on which the computations are launched, so we can wait for the
  // streams to complete.
  std::vector<StreamPool::Ptr> streams;
  std::vector<std::unique_ptr<se::Timer>> timers;

  // Global data handles for the computation results, one for each computation.
  std::vector<GlobalDataHandle> result_handles;

  // Map from executable index to the stream of its 0th replica, populated
  // only for computations that are being profiled.
  std::map<int64, se::Stream*> index_to_profiled_streams;

  // Build DeviceAssignment for all cores based on the provided device handles.
  DeviceAssignment device_assignment(options_.number_of_replicas(),
                                     executables.size());
  for (int64 i = 0; i < executables.size(); i++) {
    TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
    CHECK_EQ(replicas.size(), arguments[i].size());
    for (int64 replica = 0, end = replicas.size(); replica < end; ++replica) {
      device_assignment(replica, i) = replicas[replica]->device_ordinal();
    }
  }

  for (int64 i = 0, end = executables.size(); i < end; i++) {
    // Stream executors for the replicas of the current computation.
    TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
    CHECK_EQ(replicas.size(), arguments[i].size());
    std::vector<ScopedShapedBuffer> result_buffers;
    for (int64 replica = 0; replica < replicas.size(); ++replica) {
      TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream,
                          backend->BorrowStream(replicas[replica]));
      streams.push_back(std::move(stream));

      if (replica == 0 && profile != nullptr) {
        timers.push_back(
            absl::make_unique<se::Timer>(streams.back()->parent()));
        streams.back()
            ->InitTimer(timers.back().get())
            .ThenStartTimer(timers.back().get());
        CHECK(timers.front() != nullptr);
      }

      if (replica == 0 &&
          executables[i]->module_config().debug_options().xla_hlo_profile() &&
          executables[i]->hlo_profiling_enabled()) {
        index_to_profiled_streams[i] = streams.back().get();
      }

      // Set up run options.
      ExecutableRunOptions options;
      options.set_stream(streams.back().get());
      options.set_allocator(backend->memory_allocator());
      options.set_intra_op_thread_pool(
          backend->eigen_intra_op_thread_pool_device());
      options.set_device_assignment(&device_assignment);
      // Use run-time profile information from execution_profile on the 0th
      // device.
      if (i == 0) {
        options.set_execution_profile(profile);
      }
      ServiceExecutableRunOptions run_options(options,
                                              backend->StreamBorrower());

      // Asynchronously launch the computation.
      TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                          executables[i]->ExecuteAsyncOnStream(
                              &run_options, arguments[i][replica],
                              /*hlo_execution_profile=*/nullptr));

      if (replica == 0 && profile != nullptr) {
        streams.back()->ThenStopTimer(timers.back().get());
      }

      result_buffers.push_back(std::move(result));
    }
    TF_ASSIGN_OR_RETURN(GlobalDataHandle handle,
                        allocation_tracker_.RegisterReplicatedBuffers(
                            std::move(result_buffers), result_tags[i]));
    result_handles.push_back(handle);
  }

  // Wait for all executions to complete.
  for (int64 i = 0, end = streams.size(); i < end; ++i) {
    Status block_status = streams[i]->BlockHostUntilDone();
    if (!block_status.ok()) {
      return InternalError("failed to complete execution for stream %d: %s", i,
                           block_status.error_message());
    }
  }

  if (profile != nullptr) {
    CHECK(!timers.empty());
    std::vector<uint64> timer_nanoseconds;
    timer_nanoseconds.reserve(timers.size());
    for (auto& timer : timers) {
      timer_nanoseconds.push_back(timer->Nanoseconds());
    }
    uint64 nanoseconds =
        *std::max_element(timer_nanoseconds.begin(), timer_nanoseconds.end());

    // Overall execution time (in nanoseconds) from the executor timer.
    profile->set_compute_and_transfer_time_ns(nanoseconds);

    // TODO(b/28123297): On GPU we end up including transfer time in
    // the compute time this way. Instead, we should get the correct
    // value by measuring it. Setting the field here at least lets
    // benchmarks provide *some* value for GPU computations.
    //
    // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually
    // the compute time without the transfer time, so this way we get the
    // correct compute time. We should instead have the correct value for
    // compute_and_transfer_time and set compute_time to the compute time.
    if (profile->compute_time_ns() == 0) {
      profile->set_compute_time_ns(profile->compute_and_transfer_time_ns());
    }
  }

  return result_handles;
}

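// Runs a single (possibly replicated) executable on the device identified by
// `device_handle` and registers the result under `result_tag`, returning the
// resulting global data handle.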
StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
    Executable* executable,
    absl::Span<const std::vector<const ShapedBuffer*>> arguments,
    Backend* backend, const DeviceHandle& device_handle,
    const string& result_tag, ExecutionProfile* profile) {
  // Set up streams.
  std::vector<StreamPool::Ptr> streams;

  TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handle));
  TF_RET_CHECK(!replicas.empty());
  for (se::StreamExecutor* executor : replicas) {
    TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream,
                        backend->BorrowStream(executor));
    streams.push_back(std::move(stream));
  }

  DeviceAssignment device_assignment(options_.number_of_replicas(),
                                     /*computation_count=*/1);
  for (int64 replica = 0; replica < replicas.size(); ++replica) {
    device_assignment(replica, 0) = replicas[replica]->device_ordinal();
  }

  // Set up run options.
  std::vector<ServiceExecutableRunOptions> run_options;
  for (const StreamPool::Ptr& stream : streams) {
    ExecutableRunOptions options;
    options.set_stream(stream.get());
    options.set_device_ordinal(stream->parent()->device_ordinal());
    options.set_allocator(backend->memory_allocator());
    options.set_intra_op_thread_pool(
        backend->eigen_intra_op_thread_pool_device());
    options.set_device_assignment(&device_assignment);
    options.set_execution_profile(profile);
    run_options.emplace_back(options, backend->StreamBorrower());
  }

  if (options_.number_of_replicas() == 1) {
    TF_ASSIGN_OR_RETURN(auto result, executable->ExecuteOnStreamWrapper(
                                         &run_options[0], arguments[0]));
    return allocation_tracker_.Register(std::move(result), result_tag);
  }

  // TODO(b/69985541): Support profiling also on this path.

  std::vector<absl::Span<const ShapedBuffer* const>> replicated_arguments;
  for (const auto& arg : arguments) {
    replicated_arguments.push_back(arg);
  }

  TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams(
                                        run_options, replicated_arguments));
  TF_RET_CHECK(!results.empty());
  return allocation_tracker_.RegisterReplicatedBuffers(std::move(results),
                                                       result_tag);
}

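// Returns one stream executor per device handle named in `execution_options`
// (the executor of replica 0 for each handle). Fails if no device handles
// were provided or if multiple parallel requests each carry several handles.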
StatusOr<std::vector<se::StreamExecutor*>> Service::GetExecutors(
    const ExecutionOptions& execution_options, int64 requests_size,
    int64 request_index) const {
  if (execution_options.device_handles().empty()) {
    return FailedPrecondition(
        "device handles must be given to execute parallel computations");
  }
  if (requests_size > 1 && execution_options.device_handles_size() > 1) {
    return InvalidArgument(
        "Parallel requests with multiple device handles is not supported. "
        "Found %d parallel requests, with request %d containing %d device "
        "handles.",
        requests_size, request_index, execution_options.device_handles_size());
  }
  std::vector<se::StreamExecutor*> executors;
  for (const auto& device_handle : execution_options.device_handles()) {
    TF_ASSIGN_OR_RETURN(auto replicas,
                        Replicas(*execute_backend_, device_handle));
    se::StreamExecutor* executor = replicas[0];
    CHECK(executor != nullptr);
    executors.push_back(executor);
  }
  return executors;
}

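// Resolves the argument handles of a request into per-replica device buffers
// for the first device handle in `execution_options`.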
StatusOr<std::vector<std::vector<const ShapedBuffer*>>> Service::GetArguments(
    const ExecutionOptions& execution_options,
    absl::Span<const GlobalDataHandle* const> arguments) const {
  // Resolve the allocations for the arguments of the computation, and create
  // a vector of device memory offsets for the arguments from the allocations.
  // In the case of partitioned computations, assume all arguments go on the
  // zeroth core.
  TF_ASSIGN_OR_RETURN(
      auto replicas,
      Replicas(*execute_backend_, execution_options.device_handles(0)));
  TF_ASSIGN_OR_RETURN(
      std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
      ResolveAndValidateArguments(arguments, replicas));
  return replicated_arguments;
}

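// Compiles and runs one computation per request in the batch, dumping HLO
// snapshots (arguments and results) for executables that have snapshot
// dumping enabled.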
Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
                                     ExecuteParallelResponse* result) {
  VLOG(1) << "running execute-graph-parallel request";

  std::vector<std::vector<std::vector<const ShapedBuffer*>>> all_arguments;
  std::vector<std::vector<se::StreamExecutor*>> all_executors;
  std::vector<const HloModuleProto*> module_protos;
  std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
  std::vector<string> computation_names;
  std::vector<DeviceHandle> device_handles;

  int num_requested_devices =
      std::accumulate(arg->requests().begin(), arg->requests().end(), 0,
                      [](int a, const ExecuteGraphRequest& r) -> int {
                        return a + r.execution_options().device_handles_size();
                      });
  if (num_requested_devices * options_.number_of_replicas() >
      execute_backend_->device_count()) {
    return FailedPrecondition(
        "there are not enough stream executors to execute %d computations",
        num_requested_devices);
  }

  for (int64 i = 0; i < arg->requests_size(); ++i) {
    // Get the stream executor for the i'th computation. This stream executor
    // is one of the executors to run the replicated computation.
    const ExecutionOptions& execution_options =
        arg->requests(i).execution_options();
    const ExecuteGraphRequest& request = arg->requests(i);
    TF_RET_CHECK(request.has_computation()) << "computations may not be empty";
    TF_RET_CHECK(request.computation().has_host_program_shape())
        << "program shape may not be empty";

    // Get the executors.
    TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options,
                                                     arg->requests_size(), i));

    // Get the replicated arguments.
    TF_ASSIGN_OR_RETURN(auto replicated_arguments,
                        GetArguments(execution_options, request.arguments()));

    // Create an HloModuleConfig object for the computation, given the shape of
    // the program and the argument allocations. Here we care only about the
    // shapes of the arguments, so it is sufficient to use the arguments of
    // replica 0.
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloModuleConfig> module_config,
        CreateModuleConfig(
            ProgramShape{request.computation().host_program_shape()},
            replicated_arguments.front(), request.execution_options()));
    VLOG(3)
        << "ExecuteGraphParallel created HloModuleConfig computation layout: "
        << module_config->entry_computation_layout().ToString();

    // Add to the vectors used to build and execute the computations after the
    // loop.
    all_arguments.push_back(replicated_arguments);
    all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}});
    module_protos.push_back(&request.computation());
    module_configs.push_back(std::move(module_config));
    computation_names.insert(computation_names.end(), executors.size(),
                             request.computation().name());
    all_executors.push_back(executors);
    device_handles.insert(device_handles.end(),
                          execution_options.device_handles().begin(),
                          execution_options.device_handles().end());
  }

  // Build the HloModules and compile to generate the executables.
  //
  // TODO(jlebar): There's currently no way to pass a device allocator to
  // ExecuteGraphParallel, so we have to pass a null device_allocator below.
  TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
                      BuildExecutables(module_protos, std::move(module_configs),
                                       execute_backend_.get(), all_executors,
                                       {/*device_allocator=*/nullptr}));
  std::vector<Executable*> executable_ptrs;
  executable_ptrs.reserve(executables.size());
  for (const auto& executable : executables) {
    executable_ptrs.push_back(executable.get());
  }

  std::vector<HloSnapshot> snapshots;
  snapshots.resize(executable_ptrs.size());
  for (int i = 0, end = executable_ptrs.size(); i < end; i++) {
    if (executable_ptrs[i]->dumping_snapshot()) {
      *snapshots[i].mutable_hlo() = *executable_ptrs[i]->hlo_proto();
      TF_ASSIGN_OR_RETURN(auto stream,
                          execute_backend_->BorrowStream(
                              all_executors[i][0]->device_ordinal()));
      TF_RETURN_IF_ERROR(RecordArguments(all_arguments[i].front(), stream.get(),
                                         execute_backend_->transfer_manager(),
                                         &snapshots[i]));
    }
  }

  // If we have multiple executables to run, execute them all in parallel. But
  // if we only have one executable, execute it using the vanilla, non-parallel
  // call.
  //
  // We do this because the Client API uses ExecuteGraphParallel when it wants
  // to compile and run one computation without caching the executable, but not
  // all backends support the async StreamExecutor API required by
  // ExecuteParallelAndRegisterResult.
  //
  // TODO(b/122731460): Consolidate Execute{,Parallel}AndRegisterResult; they do
  // basically the same thing.
  ExecutionProfile profile;
  std::vector<GlobalDataHandle> outputs;
  Status execution_status = Status::OK();

  if (executable_ptrs.size() == 1) {
    StatusOr<GlobalDataHandle> output_or_status = ExecuteAndRegisterResult(
        executable_ptrs[0], all_arguments[0], execute_backend_.get(),
        device_handles[0], computation_names[0], &profile);
    if (output_or_status.ok()) {
      outputs.push_back(std::move(output_or_status).ValueOrDie());
    } else {
      execution_status = output_or_status.status();
    }
  } else {
    StatusOr<std::vector<GlobalDataHandle>> outputs_or_status =
        ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments,
                                         execute_backend_.get(), device_handles,
                                         computation_names, &profile);
    if (outputs_or_status.ok()) {
      outputs = std::move(outputs_or_status).ValueOrDie();
    } else {
      execution_status = outputs_or_status.status();
    }
  }

  if (!execution_status.ok()) {
    // Execution failed so we don't have the results. Dump the HLO snapshot
    // with just the program arguments.
    for (int i = 0, end = executable_ptrs.size(); i < end; i++) {
      DumpHloSnapshotIfEnabled(executable_ptrs[i]->module(), snapshots[i]);
    }
  }

  TF_RETURN_IF_ERROR(execution_status);

  for (const GlobalDataHandle& output : outputs) {
    ExecuteResponse response;
    *response.mutable_output() = output;
    *response.mutable_profile() = profile;
    *result->add_responses() = response;
  }

  for (int i = 0, end = executable_ptrs.size(); i < end; i++) {
    Executable* executable = executable_ptrs[i];
    if (executable->dumping_snapshot()) {
      TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer,
                          allocation_tracker_.ResolveForReplica(outputs[i], 0));
      TF_ASSIGN_OR_RETURN(auto stream,
                          execute_backend_->BorrowStream(all_executors[i][0]));
      TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
                                      execute_backend_->transfer_manager(),
                                      &snapshots[i]));
      DumpHloSnapshotIfEnabled(executable->module(), snapshots[i]);
    }
  }

  VLOG(1) << "successfully completed 'execute-graph-parallel' request";
  return Status::OK();
}

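// Hands out `device_count` logical device handles after checking that the
// backend has enough physical devices for the requested count times the
// configured replica count.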
Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
                                 GetDeviceHandlesResponse* result) {
  const int64 available_device_count = execute_backend_->device_count();
  const int64 replica_count = options_.number_of_replicas();
  if (replica_count <= 0) {
    return FailedPrecondition("Replica count must be a positive integer");
  }
  if (available_device_count < arg->device_count() * replica_count) {
    return ResourceExhausted(
        "Requested logical device count (%d) with replica count (%d) exceeds "
        "the number of available physical devices on the target (%d)",
        arg->device_count(), replica_count, available_device_count);
  }

  for (int64 i = 0; i < arg->device_count(); ++i) {
    DeviceHandle device_handle;
    device_handle.set_handle(i);
    device_handle.set_device_count(arg->device_count());
    *result->add_device_handles() = device_handle;
  }

  return Status::OK();
}

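// Compiles a single HLO module proto into an executable for `executor`.
// When run_backend_only is true, the HLO optimization passes are skipped.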
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
    const HloModuleProto& module_proto,
    std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
    se::StreamExecutor* executor, const Compiler::CompileOptions& options,
    bool run_backend_only) {
  VLOG(1) << StrFormat(
      "BuildExecutable on service %p with serialized module proto: %s", this,
      module_proto.name());

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModule> module,
      CreateModuleFromProto(module_proto, *module_config, run_backend_only));
  DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);

  if (!run_backend_only) {
    TF_ASSIGN_OR_RETURN(module, backend->compiler()->RunHloPasses(
                                    std::move(module), executor, options));
  }

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<Executable> executable,
      backend->compiler()->RunBackend(std::move(module), executor, options));

  const auto& debug_opts = module_config->debug_options();
  if (DumpingEnabledForHloModule(module_proto.name(), debug_opts) &&
      debug_opts.xla_dump_hlo_snapshots()) {
    auto hlo_proto = absl::make_unique<HloProto>();
    *hlo_proto->mutable_hlo_module() = module_proto;
    executable->set_hlo_proto(std::move(hlo_proto));
  }

  return std::move(executable);
}

Status Service::Compile(const CompileRequest* arg, CompileResponse* result) {
  VLOG(1) << "running compile request";
  if (!arg->has_computation()) {
    return InvalidArgument("computations may not be empty");
  }
  if (!arg->computation().has_host_program_shape()) {
    return InvalidArgument("program shape may not be empty");
  }

  if (arg->execution_options().device_handles_size() > 1) {
    return InvalidArgument(
        "The compile request does not support multiple device handles.");
  }

  std::vector<Shape> argument_shapes;
  argument_shapes.reserve(arg->input_shape_with_layout_size());
  std::vector<const Shape*> argument_shape_ptrs;
  for (const ShapeProto& shape_proto : arg->input_shape_with_layout()) {
    argument_shapes.push_back(Shape(shape_proto));
    argument_shape_ptrs.push_back(&argument_shapes.back());
  }
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModuleConfig> module_config,
      CreateModuleConfig(ProgramShape{arg->computation().host_program_shape()},
                         argument_shape_ptrs, &arg->execution_options()));
  VLOG(3) << "Compile created HloModuleConfig computation layout: "
          << module_config->entry_computation_layout().ToString();

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<Executable> executable,
      BuildExecutable(arg->computation(), std::move(module_config),
                      execute_backend_.get(),
                      execute_backend_->default_stream_executor(),
                      {/*device_allocator=*/nullptr}));

  *result->mutable_handle() = compilation_cache_.Insert(std::move(executable));

  VLOG(1) << "successfully completed 'compile' request";
  return Status::OK();
}

Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) {
  VLOG(1) << "running execute request";
  if (!arg->has_handle()) {
    return InvalidArgument("execution handle should not be empty");
  }
  TF_ASSIGN_OR_RETURN(auto executable,
                      compilation_cache_.LookUp(arg->handle()));

  TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_,
                                              SingleComputationDeviceHandle()));
  TF_ASSIGN_OR_RETURN(
      std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
      ResolveAndValidateArguments(arg->arguments(), replicas));

  // Check that the replicated_arguments has the same shape and layout as the
  // module config used when creating the executable.
  const int64 num_module_args =
      executable->module_config().entry_computation_layout().parameter_count();
  if (num_module_args != arg->arguments_size()) {
    return InvalidArgument(
        "The executable expects %lld arguments, but sees %lld.",
        num_module_args, arg->arguments_size());
  }
  for (int64 i = 0; i < num_module_args; i++) {
    const Shape& shape_module =
        executable->module_config().entry_computation_layout().parameter_shape(
            i);
    const Shape& shape_arg = replicated_arguments.front()[i]->on_host_shape();
    if (!ShapeUtil::Equal(shape_module, shape_arg)) {
      return InvalidArgumentStrCat(
          "The executable expects the ", i, "th argument in shape ",
          ShapeUtil::HumanStringWithLayout(shape_module), " but sees ",
          ShapeUtil::HumanStringWithLayout(shape_arg));
    }
  }

  TF_ASSIGN_OR_RETURN(auto stream,
                      execute_backend_->BorrowStream(
                          execute_backend_->default_stream_executor()));
  HloSnapshot snapshot;
  if (executable->dumping_snapshot()) {
    *snapshot.mutable_hlo() = *executable->hlo_proto();
    snapshot.set_execution_platform(execute_backend_->platform()->Name());
    TF_RETURN_IF_ERROR(
        RecordArguments(replicated_arguments.front(), stream.get(),
                        execute_backend_->transfer_manager(), &snapshot));
  }

  TF_ASSIGN_OR_RETURN(
      *result->mutable_output(),
      ExecuteAndRegisterResult(executable.get(), replicated_arguments,
                               execute_backend_.get(),
                               SingleComputationDeviceHandle(),
                               "result of " + executable->module().name(),
                               result->mutable_profile()));

  if (executable->dumping_snapshot()) {
    TF_ASSIGN_OR_RETURN(
        const ShapedBuffer* result_buffer,
        allocation_tracker_.ResolveForReplica(result->output(), 0));
    TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
                                    execute_backend_->transfer_manager(),
                                    &snapshot));
    DumpHloSnapshotIfEnabled(executable->module(), snapshot);
  }

  VLOG(1) << "successfully completed 'execute' request";
  return Status::OK();
}

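// Blocks until the asynchronous execution named by the request handle
// completes, returns its result handle and profile, and unregisters it from
// the execution tracker.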
Status Service::WaitForExecution(const WaitForExecutionRequest* arg,
                                 WaitForExecutionResponse* result) {
  TF_ASSIGN_OR_RETURN(const auto execution,
                      execution_tracker_.Resolve(arg->execution()));

  TF_RETURN_IF_ERROR(execution->BlockUntilDone());

  *result->mutable_output() = execution->result();
  *result->mutable_profile() = execution->profile();

  TF_RETURN_IF_ERROR(execution_tracker_.Unregister(arg->execution()));
  VLOG(1) << "successfully completed 'wait-for-execution' request";
  return Status::OK();
}

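// Copies a previously registered device buffer back to the client as a
// literal, relaid out to the requested shape_with_layout when one is given.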
Status Service::TransferToClient(const TransferToClientRequest* arg,
                                 TransferToClientResponse* result) {
  TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer,
                      allocation_tracker_.ResolveForReplica(arg->data(), 0));

  Shape return_shape;
  if (arg->has_shape_with_layout()) {
    return_shape = Shape(arg->shape_with_layout());
    if (!LayoutUtil::HasLayout(return_shape)) {
      return InvalidArgument("shape_with_layout must have layout if present.");
    }
  } else {
    return_shape = Shape(shaped_buffer->on_host_shape());
  }

  TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(
                                       shaped_buffer->device_ordinal()));

  TF_ASSIGN_OR_RETURN(
      Literal result_literal,
      execute_backend_->transfer_manager()->TransferLiteralFromDevice(
          stream.get(), *shaped_buffer));

  if (LayoutUtil::LayoutsInShapesEqual(return_shape, result_literal.shape())) {
    *result->mutable_literal() = result_literal.ToProto();
  } else {
    *result->mutable_literal() =
        result_literal.Relayout(return_shape).ToProto();
  }
  return Status::OK();
}

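// Allocates device memory on every replica of the target device, copies the
// client literal into each allocation, and registers the replicated buffers
// under a single global data handle.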
Status Service::TransferToServer(const TransferToServerRequest* arg,
                                 TransferToServerResponse* result) {
  TF_ASSIGN_OR_RETURN(Literal literal,
                      Literal::CreateFromProto(arg->literal()));
  const Shape& shape = literal.shape();

  std::vector<se::StreamExecutor*> replicas;
  if (arg->has_device_handle()) {
    TF_ASSIGN_OR_RETURN(replicas,
                        Replicas(*execute_backend_, arg->device_handle()));
  } else {
    TF_ASSIGN_OR_RETURN(
        replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle()));
  }

  // Allocate memory in each replica and transfer the data to all replicas.
  std::vector<ScopedShapedBuffer> replicated_buffers;
  for (se::StreamExecutor* executor : replicas) {
    TF_ASSIGN_OR_RETURN(
        ScopedShapedBuffer shaped_buffer,
        execute_backend_->transfer_manager()->AllocateScopedShapedBuffer(
            shape, execute_backend_->memory_allocator(),
            executor->device_ordinal()));
    TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor));
    TF_RETURN_IF_ERROR(
        execute_backend_->transfer_manager()->TransferLiteralToDevice(
            stream.get(), literal, shaped_buffer));
    replicated_buffers.emplace_back(std::move(shaped_buffer));
  }
  TF_ASSIGN_OR_RETURN(*result->mutable_data(),
                      allocation_tracker_.RegisterReplicatedBuffers(
                          std::move(replicated_buffers),
                          StrCat("TransferToServer literal of shape ",
                                 ShapeUtil::HumanString(shape))));

  return Status::OK();
}

Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
                                 TransferToInfeedResponse* result) {
  const int64 replica_count = options_.number_of_replicas();
  if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
    return FailedPrecondition(
        "%s",
        StrCat("The replica_id=", arg->replica_id(),
               " on TransferToInfeedRequest not in range [0, replica_count=",
               replica_count, ")."));
  }

  se::StreamExecutor* executor;
  if (arg->has_device_handle()) {
    TF_ASSIGN_OR_RETURN(auto replicas,
                        Replicas(*execute_backend_, arg->device_handle()));
    executor = replicas[arg->replica_id()];
  } else {
    TF_ASSIGN_OR_RETURN(
        auto replicas,
        Replicas(*execute_backend_, SingleComputationDeviceHandle()));
    executor = replicas[arg->replica_id()];
  }

  TF_ASSIGN_OR_RETURN(Literal literal,
                      Literal::CreateFromProto(arg->literal()));
  return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor,
                                                                       literal);
}

Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
                                    TransferFromOutfeedResponse* result) {
  const int64 replica_count = options_.number_of_replicas();
  if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
    return FailedPrecondition(
        "The replica_id=%d on TransferFromOutfeedRequest not in range [0, %d)",
        arg->replica_id(), replica_count);
  }

  se::StreamExecutor* executor;
  if (arg->has_device_handle()) {
    TF_ASSIGN_OR_RETURN(auto replicas,
                        Replicas(*execute_backend_, arg->device_handle()));
    executor = replicas[arg->replica_id()];
  } else {
    TF_ASSIGN_OR_RETURN(
        auto replicas,
        Replicas(*execute_backend_, SingleComputationDeviceHandle()));
    executor = replicas[arg->replica_id()];
  }

  auto literal = Literal::CreateFromShape(Shape(arg->shape_with_layout()));

  TF_RETURN_IF_ERROR(
      execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
          executor, &literal));
  *result->mutable_literal() = literal.ToProto();
  return Status::OK();
}

Status Service::ResetDevice(const ResetDeviceRequest* arg,
                            ResetDeviceResponse* result) {
  return execute_backend_->ResetDevices();
}

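// Evaluates a parameterless computation with the HLO evaluator and returns
// the resulting literal, relaid out to the requested output layout when one
// is provided.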
Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
                                     ComputeConstantResponse* result) {
  if (!arg->has_computation()) {
    return InvalidArgument("computations may not be empty");
  }
  if (!arg->computation().has_host_program_shape()) {
    return InvalidArgument("program shape may not be empty");
  }
  if (arg->computation().host_program_shape().parameters_size() != 0) {
    return InvalidArgument(
        "constant computation may not depend on any parameters.");
  }

  ProgramShape program_shape(arg->computation().host_program_shape());
  TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result()));
  absl::optional<Layout> output_layout;
  if (arg->has_output_layout()) {
    output_layout = Layout::CreateFromProto(arg->output_layout());
    TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(
        *output_layout, program_shape.result()));
  }

  HloModuleConfig config(program_shape);

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                      CreateModuleFromProto(arg->computation(), config));

  TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference,
                      DynamicDimensionInference::Run(module.get()));

  HloEvaluator evaluator;
  evaluator.set_dynamic_dimension_inference(&dynamic_dimension_inference);
  TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate(*module, {}));

  // The evaluator does not apply the requested result layout, so explicitly
  // relayout the result here.
  //
  // TODO(b/77824332): Make HloEvaluator take care of the re-layout.
  if (output_layout.has_value()) {
    result_literal = result_literal.Relayout(*output_layout);
  }
  *result->mutable_literal() = result_literal.ToProto();

  return Status::OK();
}

Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) {
  TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer,
                      allocation_tracker_.ResolveForReplica(arg->data(), 0));
  *result->mutable_shape() = buffer->on_host_shape().ToProto();
  return Status::OK();
}

Status Service::GetComputationGraphStats(
    const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) {
  if (!arg->has_computation()) {
    return InvalidArgument("Computations may not be empty.");
  }
  if (!arg->computation().has_host_program_shape()) {
    return InvalidArgument("Program shape may not be empty.");
  }

  HloModuleConfig config(ProgramShape{arg->computation().host_program_shape()});
  config.set_debug_options(arg->debug_options());
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                      CreateModuleFromProto(arg->computation(), config));
  DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);

  // Run HLO analysis to get the computation statistics.
  HloCostAnalysis analysis(
      execute_backend_->compiler()->ShapeSizeBytesFunction());

  TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis));

  ComputationStats stats;
  stats.set_flop_count(analysis.flop_count());
  stats.set_transcendental_count(analysis.transcendental_count());
  *result->mutable_stats() = stats;
  return Status::OK();
}

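// Returns the device handle used when a request does not name a device:
// logical device 0 with a device count of 1.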
DeviceHandle Service::SingleComputationDeviceHandle() const {
  DeviceHandle device_handle;
  device_handle.set_handle(0);
  device_handle.set_device_count(1);
  return device_handle;
}

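// Returns the stream executors backing all replicas of the given device
// handle, as assigned by the backend's computation placer.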
StatusOr<std::vector<se::StreamExecutor*>> Service::Replicas(
    const Backend& backend, const DeviceHandle& device_handle) const {
  std::vector<se::StreamExecutor*> replicas;
  for (int replica = 0; replica < options_.number_of_replicas(); ++replica) {
    // From the computation placer, find out the device ids of the replicas for
    // the given device handle.
    TF_ASSIGN_OR_RETURN(
        int device_ordinal,
        backend.computation_placer()->DeviceId(replica, device_handle.handle(),
                                               options_.number_of_replicas(),
                                               device_handle.device_count()));
    TF_ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal));
    replicas.push_back(executor);
  }
  return replicas;
}

}  // namespace xla