/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/service.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
#include "tensorflow/compiler/xla/service/dynamic_padder.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_util.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/service/source_map_util.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"

namespace xla {
namespace {

using absl::StrCat;
using absl::StrFormat;

// Argument used when calling DumpHloModuleIfEnabled before optimizations are
// performed on an HloModule.
constexpr char kBeforeOptimizationsDumpName[] = "before_optimizations";

// Records the arguments used to invoke a computation in an HloSnapshot proto.
Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments,
                       se::Stream* stream, TransferManager* transfer_manager,
                       HloSnapshot* module) {
  module->clear_arguments();
  for (const ShapedBuffer* argument : arguments) {
    TF_ASSIGN_OR_RETURN(
        Literal literal,
        transfer_manager->TransferLiteralFromDevice(stream, *argument));
    *module->add_arguments() = literal.ToProto();
  }
  return Status::OK();
}

// Records the result of a computation in an HloSnapshot proto.
Status RecordResult(const ShapedBuffer& result, se::Stream* stream,
                    TransferManager* transfer_manager, HloSnapshot* module) {
  module->clear_result();
  TF_ASSIGN_OR_RETURN(
      Literal literal,
      transfer_manager->TransferLiteralFromDevice(stream, result));
  *module->mutable_result() = literal.ToProto();
  return Status::OK();
}

}  // namespace

ServiceOptions& ServiceOptions::set_platform(se::Platform* platform) {
  platform_ = platform;
  return *this;
}

se::Platform* ServiceOptions::platform() const { return platform_; }

ServiceOptions& ServiceOptions::set_number_of_replicas(int number_of_replicas) {
  number_of_replicas_ = number_of_replicas;
  return *this;
}

int ServiceOptions::number_of_replicas() const { return number_of_replicas_; }

ServiceOptions& ServiceOptions::set_intra_op_parallelism_threads(
    int num_threads) {
  intra_op_parallelism_threads_ = num_threads;
  return *this;
}

int ServiceOptions::intra_op_parallelism_threads() const {
  return intra_op_parallelism_threads_;
}

ServiceOptions& ServiceOptions::set_allowed_devices(
    const absl::optional<std::set<int>>& allowed_devices) {
  allowed_devices_ = allowed_devices;
  return *this;
}

const absl::optional<std::set<int>>& ServiceOptions::allowed_devices() const {
  return allowed_devices_;
}

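// A minimal usage sketch (not part of this file; error handling elided).
// Passing a null platform falls back to PlatformUtil::GetDefaultPlatform():
//
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<Service> service,
//                       Service::NewService(/*platform=*/nullptr));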
/* static */ StatusOr<std::unique_ptr<Service>> Service::NewService(
    se::Platform* platform) {
  ServiceOptions default_options;
  default_options.set_platform(platform);
  return NewService(default_options);
}

/* static */ StatusOr<std::unique_ptr<Service>> Service::NewService(
    const ServiceOptions& options) {
  se::Platform* platform = options.platform();
  std::unique_ptr<Backend> execute_backend;
  if (platform == nullptr) {
    TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
  }
  BackendOptions backend_options;
  backend_options.set_platform(platform);
  backend_options.set_allowed_devices(options.allowed_devices());
  TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options));

  std::unique_ptr<Service> service(
      new Service(options, std::move(execute_backend)));
  return std::move(service);
}

Service::Service(const ServiceOptions& options,
                 std::unique_ptr<Backend> execute_backend)
    : options_(options),
      allocation_tracker_(execute_backend.get()),
      execute_backend_(std::move(execute_backend)) {
  CHECK_GT(options_.number_of_replicas(), 0);
  if (execute_backend_) {
    if (execute_backend_->device_count() > 0) {
      CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas())
          << "Requested more replicas than there are devices.";
    }
    LOG(INFO) << StrFormat(
        "XLA service %p initialized for platform %s (this does not guarantee "
        "that XLA will be used). Devices:",
        this, execute_backend_->platform()->Name());
    auto stream_executors = execute_backend_->stream_executors();
    for (int i = 0; i < execute_backend_->device_count(); ++i) {
      se::StreamExecutor* executor = stream_executors.at(i);
      const auto& description = executor->GetDeviceDescription();
      LOG(INFO) << StrFormat("  StreamExecutor device (%d): %s, %s", i,
                             description.name(),
                             description.platform_version());
    }
  } else {
    VLOG(1) << "XLA compile-only service constructed";
  }
}

Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg,
                                    CreateChannelHandleResponse* result) {
  TF_ASSIGN_OR_RETURN(*result->mutable_channel(),
                      channel_tracker_.NewChannel(arg->channel_type()));
  return Status::OK();
}

Status Service::Unregister(const UnregisterRequest* arg,
                           UnregisterResponse* result) {
  Status status;
  for (auto& data : arg->data()) {
    Status unregister_status = allocation_tracker_.Unregister(data);
    if (!unregister_status.ok() && status.ok()) {
      status = unregister_status;
    }
  }
  return status;
}

// Deconstructs a previously-allocated global handle.
Status Service::DeconstructTuple(const DeconstructTupleRequest* arg,
                                 DeconstructTupleResponse* result) {
  TF_ASSIGN_OR_RETURN(
      std::vector<GlobalDataHandle> elements,
      allocation_tracker_.DeconstructTuple(arg->tuple_handle()));

  for (auto& element : elements) {
    *result->add_element_handles() = element;
  }
  return Status::OK();
}

Status Service::ValidateResultShape(const Shape& client_shape,
                                    const Shape& result_shape) const {
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(client_shape));
  if (!ShapeUtil::Compatible(client_shape, result_shape)) {
    return InvalidArgument(
        "Shape used to set computation result layout %s is not compatible "
        "with result shape %s",
        ShapeUtil::HumanStringWithLayout(client_shape),
        ShapeUtil::HumanString(result_shape));
  }
  return Status::OK();
}

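// Resolves each GlobalDataHandle against the allocation tracker and returns
// the per-replica argument buffers, indexed as [replica][argument]. Every
// handle is expected to resolve to one buffer per replica.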
StatusOr<std::vector<std::vector<const ShapedBuffer*>>>
Service::ResolveAndValidateArguments(
    absl::Span<const GlobalDataHandle* const> arguments,
    absl::Span<se::StreamExecutor* const> stream_executors) const {
  CHECK_EQ(options_.number_of_replicas(), stream_executors.size());
  std::vector<std::vector<const ShapedBuffer*>> replicated_arguments;
  replicated_arguments.resize(options_.number_of_replicas());
  for (size_t i = 0; i < arguments.size(); ++i) {
    auto buffer_status = allocation_tracker_.Resolve(*arguments[i]);
    if (!buffer_status.ok()) {
      return Status(buffer_status.status().code(),
                    StrCat(buffer_status.status().error_message(), ", ",
                           "failed to resolve allocation for parameter ", i));
    }
    auto replicated_buffers = buffer_status.ValueOrDie();
    CHECK_EQ(options_.number_of_replicas(), replicated_buffers.size());
    for (int replica = 0; replica < options_.number_of_replicas(); ++replica) {
      const ShapedBuffer* shaped_buffer = replicated_buffers[replica];
      replicated_arguments[replica].push_back(shaped_buffer);
    }
  }
  return replicated_arguments;
}

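// Builds an HloModuleConfig for the given program shape and argument shapes,
// defaulting the replica count and intra-op thread count from this service's
// options and backend before delegating to xla::CreateModuleConfig.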
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
    const ProgramShape& program_shape,
    absl::Span<const Shape* const> argument_shapes,
    const ExecutionOptions* execution_options,
    const AotCompilationOptions* aot_options) {
  int default_num_replicas = options_.number_of_replicas();
  absl::optional<int> num_threads;
  if (execute_backend_ != nullptr &&
      execute_backend_->eigen_intra_op_thread_pool() != nullptr) {
    num_threads = execute_backend_->eigen_intra_op_thread_pool()->NumThreads();
  }

  return xla::CreateModuleConfig(program_shape, argument_shapes,
                                 execution_options, default_num_replicas,
                                 num_threads, aot_options);
}

StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
    const ProgramShape& program_shape,
    absl::Span<const ShapedBuffer* const> arguments,
    const ExecutionOptions& execution_options,
    const AotCompilationOptions* aot_options) {
  std::vector<const Shape*> argument_shapes;
  for (const auto* arg : arguments) {
    argument_shapes.push_back(&arg->on_host_shape());
  }
  return CreateModuleConfig(program_shape, argument_shapes, &execution_options,
                            aot_options);
}

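// Compiles one executable per module proto. The modules are deserialized,
// dumped if dumping is enabled, and grouped into an HloModuleGroup; when
// run_backend_only is true, HLO passes are skipped and only the backend
// (code-generation) stage runs.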
StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
    const std::vector<const HloModuleProto*>& module_protos,
    std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
    Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
    const Compiler::CompileOptions& options, bool run_backend_only) {
  VLOG(1) << StrFormat("BuildExecutable on service %p", this);

  VLOG(1) << "Computations:";
  for (const HloModuleProto* proto : module_protos) {
    VLOG(1) << proto->name();
  }

  CHECK_EQ(module_protos.size(), module_configs.size());
  auto module_group =
      absl::make_unique<HloModuleGroup>(module_protos[0]->name());
  for (int64_t i = 0, end = module_protos.size(); i < end; ++i) {
    const HloModuleProto* proto = module_protos[i];
    const HloModuleConfig& config = *module_configs[i];
    TF_ASSIGN_OR_RETURN(
        auto module, CreateModuleFromProto(*proto, config, run_backend_only));
    DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
    module_group->push_back(std::move(module));
  }

  std::vector<std::unique_ptr<Executable>> executables;
  if (!run_backend_only) {
    TF_ASSIGN_OR_RETURN(executables, backend->compiler()->Compile(
                                         std::move(module_group),
                                         std::move(executors), options));
  } else {
    auto modules = module_group->ConsumeModules();
    for (std::unique_ptr<HloModule>& module : modules) {
      TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                          backend->compiler()->RunBackend(
                              std::move(module), executors[0][0], options));
      executables.push_back(std::move(executable));
    }
  }

  return std::move(executables);
}

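// Launches the given executables asynchronously, one computation per device
// handle with one stream per replica, waits for all streams to complete, and
// registers each computation's (replicated) result with the allocation
// tracker, returning one GlobalDataHandle per computation.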
StatusOr<std::vector<GlobalDataHandle>>
Service::ExecuteParallelAndRegisterResult(
    absl::Span<Executable* const> executables,
    absl::Span<const std::vector<std::vector<const ShapedBuffer*>>> arguments,
    Backend* backend, absl::Span<const DeviceHandle> device_handles,
    absl::Span<const string> result_tags, ExecutionProfile* profile) {
  // Streams on which the computations are launched, so we can wait on them to
  // complete.
  std::vector<StreamPool::Ptr> streams;
  std::vector<std::unique_ptr<se::Timer>> timers;

  // Global data handles for the computation results, one for each computation.
  std::vector<GlobalDataHandle> result_handles;

  // Map from computation index to the stream used by replica 0 of that
  // computation; populated only for computations that are being profiled.
  std::map<int64, se::Stream*> index_to_profiled_streams;

  // Build DeviceAssignment for all cores based on the provided device handles.
  DeviceAssignment device_assignment(options_.number_of_replicas(),
                                     executables.size());
  for (int64_t i = 0; i < executables.size(); i++) {
    TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
    CHECK_EQ(replicas.size(), arguments[i].size());
    for (int64_t replica = 0, end = replicas.size(); replica < end; ++replica) {
      device_assignment(replica, i) = replicas[replica]->device_ordinal();
    }
  }

  for (int64_t i = 0, end = executables.size(); i < end; i++) {
    // Stream executors for the replicas of the current computation.
    TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
    CHECK_EQ(replicas.size(), arguments[i].size());
    std::vector<ScopedShapedBuffer> result_buffers;
    for (int64_t replica = 0; replica < replicas.size(); ++replica) {
      TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream,
                          backend->BorrowStream(replicas[replica]));
      streams.push_back(std::move(stream));

      if (replica == 0 && profile != nullptr) {
        timers.push_back(
            absl::make_unique<se::Timer>(streams.back()->parent()));
        streams.back()
            ->InitTimer(timers.back().get())
            .ThenStartTimer(timers.back().get());
        CHECK(timers.front() != nullptr);
      }

      if (replica == 0 &&
          executables[i]->module_config().debug_options().xla_hlo_profile() &&
          executables[i]->hlo_profiling_enabled()) {
        index_to_profiled_streams[i] = streams.back().get();
      }

      // Set up run options.
      ExecutableRunOptions options;
      options.set_stream(streams.back().get());
      options.set_allocator(backend->memory_allocator());
      options.set_intra_op_thread_pool(
          backend->eigen_intra_op_thread_pool_device());
      options.set_device_assignment(&device_assignment);
      // Use run-time profile information from execution_profile on the 0th
      // device.
      if (i == 0) {
        options.set_execution_profile(profile);
      }
      ServiceExecutableRunOptions run_options(options,
                                              backend->StreamBorrower());

      // Asynchronously launch the computation.
      TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                          executables[i]->ExecuteAsyncOnStream(
                              &run_options, arguments[i][replica],
                              /*hlo_execution_profile=*/nullptr));

      if (replica == 0 && profile != nullptr) {
        streams.back()->ThenStopTimer(timers.back().get());
      }

      result_buffers.push_back(std::move(result));
    }
    TF_ASSIGN_OR_RETURN(GlobalDataHandle handle,
                        allocation_tracker_.RegisterReplicatedBuffers(
                            std::move(result_buffers), result_tags[i]));
    result_handles.push_back(handle);
  }

  // Wait for all executions to complete.
  for (int64_t i = 0, end = streams.size(); i < end; ++i) {
    Status block_status = streams[i]->BlockHostUntilDone();
    if (!block_status.ok()) {
      return InternalError("failed to complete execution for stream %d: %s", i,
                           block_status.error_message());
    }
  }

  if (profile != nullptr) {
    CHECK(!timers.empty());
    std::vector<uint64> timer_nanoseconds;
    timer_nanoseconds.reserve(timers.size());
    for (auto& timer : timers) {
      timer_nanoseconds.push_back(timer->Nanoseconds());
    }
    uint64 nanoseconds =
        *std::max_element(timer_nanoseconds.begin(), timer_nanoseconds.end());

    // Overall execution time (in nanoseconds) from the executor timer.
    profile->set_compute_and_transfer_time_ns(nanoseconds);

    // TODO(b/28123297): On GPU we end up including transfer time in
    // the compute time this way. Instead, we should get the correct
    // value by measuring it. Setting the field here at least lets
    // benchmarks provide *some* value for GPU computations.
    //
    // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually
    // the compute time without the transfer time, so this way we get the
    // correct compute time. We should instead have the correct value for
    // compute_and_transfer_time and set compute_time to the compute time.
    if (profile->compute_time_ns() == 0) {
      profile->set_compute_time_ns(profile->compute_and_transfer_time_ns());
    }
  }

  return result_handles;
}

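// Runs a single executable, replicated across the devices of `device_handle`,
// and registers the (possibly replicated) result with the allocation tracker.
// The single-replica case takes the simpler ExecuteOnStreamWrapper path.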
StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
    Executable* executable,
    absl::Span<const std::vector<const ShapedBuffer*>> arguments,
    Backend* backend, const DeviceHandle& device_handle,
    const string& result_tag, ExecutionProfile* profile) {
  // Set up streams.
  std::vector<StreamPool::Ptr> streams;

  TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handle));
  TF_RET_CHECK(!replicas.empty());
  for (se::StreamExecutor* executor : replicas) {
    TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream,
                        backend->BorrowStream(executor));
    streams.push_back(std::move(stream));
  }

  DeviceAssignment device_assignment(options_.number_of_replicas(),
                                     /*computation_count=*/1);
  for (int64_t replica = 0; replica < replicas.size(); ++replica) {
    device_assignment(replica, 0) = replicas[replica]->device_ordinal();
  }

  // Set up run options.
  std::vector<ServiceExecutableRunOptions> run_options;
  for (const StreamPool::Ptr& stream : streams) {
    ExecutableRunOptions options;
    options.set_stream(stream.get());
    options.set_device_ordinal(stream->parent()->device_ordinal());
    options.set_allocator(backend->memory_allocator());
    options.set_intra_op_thread_pool(
        backend->eigen_intra_op_thread_pool_device());
    options.set_device_assignment(&device_assignment);
    options.set_execution_profile(profile);
    run_options.emplace_back(options, backend->StreamBorrower());
  }

  if (options_.number_of_replicas() == 1) {
    TF_ASSIGN_OR_RETURN(auto result, executable->ExecuteOnStreamWrapper(
                                         &run_options[0], arguments[0]));
    return allocation_tracker_.Register(std::move(result), result_tag);
  }

  // TODO(b/69985541): Support profiling also on this path.

  std::vector<absl::Span<const ShapedBuffer* const>> replicated_arguments;
  for (const auto& arg : arguments) {
    replicated_arguments.push_back(arg);
  }

  TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams(
                                        run_options, replicated_arguments));
  TF_RET_CHECK(!results.empty());
  return allocation_tracker_.RegisterReplicatedBuffers(std::move(results),
                                                       result_tag);
}

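// Picks one stream executor (replica 0) per requested device handle,
// validating that device handles were provided and that requests with
// multiple device handles are not combined with multiple parallel requests.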
StatusOr<std::vector<se::StreamExecutor*>> Service::GetExecutors(
    const ExecutionOptions& execution_options, int64_t requests_size,
    int64_t request_index) const {
  if (execution_options.device_handles().empty()) {
    return FailedPrecondition(
        "device handles must be given to execute parallel computations");
  }
  if (requests_size > 1 && execution_options.device_handles_size() > 1) {
    return InvalidArgument(
        "Parallel requests with multiple device handles is not supported. "
        "Found %d parallel requests, with request %d containing %d device "
        "handles.",
        requests_size, request_index, execution_options.device_handles_size());
  }
  std::vector<se::StreamExecutor*> executors;
  for (const auto& device_handle : execution_options.device_handles()) {
    TF_ASSIGN_OR_RETURN(auto replicas,
                        Replicas(*execute_backend_, device_handle));
    se::StreamExecutor* executor = replicas[0];
    CHECK(executor != nullptr);
    executors.push_back(executor);
  }
  return executors;
}

StatusOr<std::vector<std::vector<const ShapedBuffer*>>> Service::GetArguments(
    const ExecutionOptions& execution_options,
    absl::Span<const GlobalDataHandle* const> arguments) const {
  // Resolve the allocations for the arguments of the computation, and create
  // a vector of device memory offsets for the arguments from the allocations.
  // In the case of partitioned computations, assume all arguments go on the
  // zeroth core.
  TF_ASSIGN_OR_RETURN(
      auto replicas,
      Replicas(*execute_backend_, execution_options.device_handles(0)));
  TF_ASSIGN_OR_RETURN(
      std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
      ResolveAndValidateArguments(arguments, replicas));
  return replicated_arguments;
}

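// Handles an execute-graph-parallel RPC: builds a module config per request,
// compiles all computations together, optionally records HLO snapshots of the
// arguments, then executes either on the parallel path or, for a single
// executable, on the plain ExecuteAndRegisterResult path.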
Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
                                     ExecuteParallelResponse* result) {
  VLOG(1) << "running execute-graph-parallel request";

  std::vector<std::vector<std::vector<const ShapedBuffer*>>> all_arguments;
  std::vector<std::vector<se::StreamExecutor*>> all_executors;
  std::vector<const HloModuleProto*> module_protos;
  std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
  std::vector<string> computation_names;
  std::vector<DeviceHandle> device_handles;

  int num_requested_devices =
      std::accumulate(arg->requests().begin(), arg->requests().end(), 0,
                      [](int a, const ExecuteGraphRequest& r) -> int {
                        return a + r.execution_options().device_handles_size();
                      });
  if (num_requested_devices * options_.number_of_replicas() >
      execute_backend_->device_count()) {
    return FailedPrecondition(
        "there are not enough stream executors to execute %d computations",
        num_requested_devices);
  }

  for (int64_t i = 0; i < arg->requests_size(); ++i) {
    // Get the stream executor for the i'th computation. This stream executor
    // is one of the executors to run the replicated computation.
    const ExecutionOptions& execution_options =
        arg->requests(i).execution_options();
    const ExecuteGraphRequest& request = arg->requests(i);
    TF_RET_CHECK(request.has_computation()) << "computations may not be empty";
    TF_RET_CHECK(request.computation().has_host_program_shape())
        << "program shape may not be empty";

    // Get the executors.
    TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options,
                                                     arg->requests_size(), i));

    // Get the replicated arguments.
    TF_ASSIGN_OR_RETURN(auto replicated_arguments,
                        GetArguments(execution_options, request.arguments()));

    // Create an HloModuleConfig object for the computation, given the shape of
    // the program and the argument allocations. Here, we care only about the
    // shapes of the arguments, so it is sufficient to use the arguments of
    // replica 0.
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloModuleConfig> module_config,
        CreateModuleConfig(
            ProgramShape{request.computation().host_program_shape()},
            replicated_arguments.front(), request.execution_options()));
    VLOG(3)
        << "ExecuteGraphParallel created HloModuleConfig computation layout: "
        << module_config->entry_computation_layout().ToString();

    // Add to the vectors used to build and execute the computations after
    // the loop.
    all_arguments.push_back(replicated_arguments);
    all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}});
    module_protos.push_back(&request.computation());
    module_configs.push_back(std::move(module_config));
    computation_names.insert(computation_names.end(), executors.size(),
                             request.computation().name());
    all_executors.push_back(executors);
    device_handles.insert(device_handles.end(),
                          execution_options.device_handles().begin(),
                          execution_options.device_handles().end());
  }

  // Build the HloModules and compile to generate the executables.
  //
  // TODO(jlebar): There's currently no way to pass a device allocator to
  // ExecuteGraphParallel, so we have to pass a null device_allocator below.
  TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
                      BuildExecutables(module_protos, std::move(module_configs),
                                       execute_backend_.get(), all_executors,
                                       {/*device_allocator=*/nullptr}));
  std::vector<Executable*> executable_ptrs;
  executable_ptrs.reserve(executables.size());
  for (const auto& executable : executables) {
    executable_ptrs.push_back(executable.get());
  }

  std::vector<HloSnapshot> snapshots;
  snapshots.resize(executable_ptrs.size());
  for (int i = 0, end = executable_ptrs.size(); i < end; i++) {
    if (executable_ptrs[i]->dumping_snapshot()) {
      *snapshots[i].mutable_hlo() = *executable_ptrs[i]->hlo_proto();
      TF_ASSIGN_OR_RETURN(auto stream,
                          execute_backend_->BorrowStream(
                              all_executors[i][0]->device_ordinal()));
      TF_RETURN_IF_ERROR(RecordArguments(all_arguments[i].front(), stream.get(),
                                         execute_backend_->transfer_manager(),
                                         &snapshots[i]));
    }
  }

  // If we have multiple executables to run, execute them all in parallel. But
  // if we only have one executable, execute it using the vanilla, non-parallel
  // call.
  //
  // We do this because the Client API uses ExecuteGraphParallel when it wants
  // to compile and run one computation without caching the executable, but not
  // all backends support the async StreamExecutor API required by
  // ExecuteParallelAndRegisterResult.
  //
  // TODO(b/122731460): Consolidate Execute{,Parallel}AndRegisterResult; they do
  // basically the same thing.
  ExecutionProfile profile;
  std::vector<GlobalDataHandle> outputs;
  Status execution_status = Status::OK();

  if (executable_ptrs.size() == 1) {
    StatusOr<GlobalDataHandle> output_or_status = ExecuteAndRegisterResult(
        executable_ptrs[0], all_arguments[0], execute_backend_.get(),
        device_handles[0], computation_names[0], &profile);
    if (output_or_status.ok()) {
      outputs.push_back(std::move(output_or_status).ValueOrDie());
    } else {
      execution_status = output_or_status.status();
    }
  } else {
    StatusOr<std::vector<GlobalDataHandle>> outputs_or_status =
        ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments,
                                         execute_backend_.get(), device_handles,
                                         computation_names, &profile);
    if (outputs_or_status.ok()) {
      outputs = std::move(outputs_or_status).ValueOrDie();
    } else {
      execution_status = outputs_or_status.status();
    }
  }

  if (!execution_status.ok()) {
    // Execution failed so we don't have the results. Dump the HLO snapshot
    // with just the program arguments.
    for (int i = 0, end = executable_ptrs.size(); i < end; i++) {
      DumpHloSnapshotIfEnabled(executable_ptrs[i]->module(), snapshots[i]);
    }
  }

  TF_RETURN_IF_ERROR(execution_status);

  for (const GlobalDataHandle& output : outputs) {
    ExecuteResponse response;
    *response.mutable_output() = output;
    *response.mutable_profile() = profile;
    *result->add_responses() = response;
  }

  for (int i = 0, end = executable_ptrs.size(); i < end; i++) {
    Executable* executable = executable_ptrs[i];
    if (executable->dumping_snapshot()) {
      TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer,
                          allocation_tracker_.ResolveForReplica(outputs[i], 0));
      TF_ASSIGN_OR_RETURN(auto stream,
                          execute_backend_->BorrowStream(all_executors[i][0]));
      TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
                                      execute_backend_->transfer_manager(),
                                      &snapshots[i]));
      DumpHloSnapshotIfEnabled(executable->module(), snapshots[i]);
    }
  }

  VLOG(1) << "successfully completed 'execute-graph-parallel' request";
  return Status::OK();
}

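// Hands out `device_count` logical device handles, failing if the requested
// logical device count times the replica count exceeds the number of
// physical devices on the target.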
Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
                                 GetDeviceHandlesResponse* result) {
  const int64_t available_device_count = execute_backend_->device_count();
  const int64_t replica_count = options_.number_of_replicas();
  if (replica_count <= 0) {
    return FailedPrecondition("Replica count must be a positive integer");
  }
  if (available_device_count < arg->device_count() * replica_count) {
    return ResourceExhausted(
        "Requested logical device count (%d) with replica count (%d) exceeds "
        "the number of available physical devices on the target (%d)",
        arg->device_count(), replica_count, available_device_count);
  }

  for (int64_t i = 0; i < arg->device_count(); ++i) {
    DeviceHandle device_handle;
    device_handle.set_handle(i);
    device_handle.set_device_count(arg->device_count());
    *result->add_device_handles() = device_handle;
  }

  return Status::OK();
}

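// Compiles a single HloModuleProto into an Executable for the given stream
// executor, running HLO passes first unless run_backend_only is set.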
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
    const HloModuleProto& module_proto,
    std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
    se::StreamExecutor* executor, const Compiler::CompileOptions& options,
    bool run_backend_only) {
  VLOG(1) << StrFormat(
      "BuildExecutable on service %p with serialized module proto: %s", this,
      module_proto.name());

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModule> module,
      CreateModuleFromProto(module_proto, *module_config, run_backend_only));
  DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);

  if (!run_backend_only) {
    TF_ASSIGN_OR_RETURN(module, backend->compiler()->RunHloPasses(
                                    std::move(module), executor, options));
  }

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<Executable> executable,
      backend->compiler()->RunBackend(std::move(module), executor, options));

  return std::move(executable);
}

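// Handles a compile RPC: builds a module config from the request's program
// shape and argument layouts, compiles on the default stream executor, and
// caches the executable. The returned handle is what a later Execute request
// (see Service::Execute below) looks up in the compilation cache.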
Status Service::Compile(const CompileRequest* arg, CompileResponse* result) {
  VLOG(1) << "running compile request";
  if (!arg->has_computation()) {
    return InvalidArgument("computations may not be empty");
  }
  if (!arg->computation().has_host_program_shape()) {
    return InvalidArgument("program shape may not be empty");
  }

  if (arg->execution_options().device_handles_size() > 1) {
    return InvalidArgument(
        "The compile request does not support multiple device handles.");
  }

  std::vector<Shape> argument_shapes;
  argument_shapes.reserve(arg->input_shape_with_layout_size());
  std::vector<const Shape*> argument_shape_ptrs;
  for (const ShapeProto& shape_proto : arg->input_shape_with_layout()) {
    argument_shapes.push_back(Shape(shape_proto));
    argument_shape_ptrs.push_back(&argument_shapes.back());
  }
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModuleConfig> module_config,
      CreateModuleConfig(ProgramShape{arg->computation().host_program_shape()},
                         argument_shape_ptrs, &arg->execution_options()));
  VLOG(3) << "Compile created HloModuleConfig computation layout: "
          << module_config->entry_computation_layout().ToString();

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<Executable> executable,
      BuildExecutable(arg->computation(), std::move(module_config),
                      execute_backend_.get(),
                      execute_backend_->default_stream_executor(),
                      {/*device_allocator=*/nullptr}));

  *result->mutable_handle() = compilation_cache_.Insert(std::move(executable));

  VLOG(1) << "successfully completed 'compile' request";
  return Status::OK();
}

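// Handles an execute RPC for a previously compiled computation: looks up the
// executable by handle, validates the argument shapes against the entry
// computation layout, runs it, and optionally records an HLO snapshot.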
Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) {
  VLOG(1) << "running execute request";
  if (!arg->has_handle()) {
    return InvalidArgument("execution handle should not be empty");
  }
  TF_ASSIGN_OR_RETURN(auto executable,
                      compilation_cache_.LookUp(arg->handle()));

  TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_,
                                              SingleComputationDeviceHandle()));
  TF_ASSIGN_OR_RETURN(
      std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
      ResolveAndValidateArguments(arg->arguments(), replicas));

  // Check that the shapes and layouts of replicated_arguments match the entry
  // computation layout of the module config used when creating the executable.
  const int64_t num_module_args =
      executable->module_config().entry_computation_layout().parameter_count();
  if (num_module_args != arg->arguments_size()) {
    return InvalidArgument(
        "The executable expects %lld arguments, but sees %lld.",
        num_module_args, arg->arguments_size());
  }
  for (int64_t i = 0; i < num_module_args; i++) {
    const Shape& shape_module =
        executable->module_config().entry_computation_layout().parameter_shape(
            i);
    const Shape& shape_arg = replicated_arguments.front()[i]->on_host_shape();
    if (!ShapeUtil::Equal(shape_module, shape_arg)) {
      return InvalidArgumentStrCat(
          "The executable expects the ", i, "th argument in shape ",
          ShapeUtil::HumanStringWithLayout(shape_module), " but sees ",
          ShapeUtil::HumanStringWithLayout(shape_arg));
    }
  }

  TF_ASSIGN_OR_RETURN(auto stream,
                      execute_backend_->BorrowStream(
                          execute_backend_->default_stream_executor()));
  HloSnapshot snapshot;
  if (executable->dumping_snapshot()) {
    *snapshot.mutable_hlo() = *executable->hlo_proto();
    snapshot.set_execution_platform(execute_backend_->platform()->Name());
    TF_RETURN_IF_ERROR(
        RecordArguments(replicated_arguments.front(), stream.get(),
                        execute_backend_->transfer_manager(), &snapshot));
  }

  TF_ASSIGN_OR_RETURN(
      *result->mutable_output(),
      ExecuteAndRegisterResult(executable.get(), replicated_arguments,
                               execute_backend_.get(),
                               SingleComputationDeviceHandle(),
                               "result of " + executable->module().name(),
                               result->mutable_profile()));

  if (executable->dumping_snapshot()) {
    TF_ASSIGN_OR_RETURN(
        const ShapedBuffer* result_buffer,
        allocation_tracker_.ResolveForReplica(result->output(), 0));
    TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
                                    execute_backend_->transfer_manager(),
                                    &snapshot));
    DumpHloSnapshotIfEnabled(executable->module(), snapshot);
  }

  VLOG(1) << "successfully completed 'execute' request";
  return Status::OK();
}

Status Service::WaitForExecution(const WaitForExecutionRequest* arg,
                                 WaitForExecutionResponse* result) {
  TF_ASSIGN_OR_RETURN(const auto execution,
                      execution_tracker_.Resolve(arg->execution()));

  TF_RETURN_IF_ERROR(execution->BlockUntilDone());

  *result->mutable_output() = execution->result();
  *result->mutable_profile() = execution->profile();

  TF_RETURN_IF_ERROR(execution_tracker_.Unregister(arg->execution()));
  VLOG(1) << "successfully completed 'wait-for-execution' request";
  return Status::OK();
}

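// Copies a device buffer back to the client as a literal, converting it to
// the requested layout when shape_with_layout is provided.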
Status Service::TransferToClient(const TransferToClientRequest* arg,
                                 TransferToClientResponse* result) {
  TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer,
                      allocation_tracker_.ResolveForReplica(arg->data(), 0));

  Shape return_shape;
  if (arg->has_shape_with_layout()) {
    return_shape = Shape(arg->shape_with_layout());
    if (!LayoutUtil::HasLayout(return_shape)) {
      return InvalidArgument("shape_with_layout must have layout if present.");
    }
  } else {
    return_shape = Shape(shaped_buffer->on_host_shape());
  }

  TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(
                                       shaped_buffer->device_ordinal()));

  TF_ASSIGN_OR_RETURN(
      Literal result_literal,
      execute_backend_->transfer_manager()->TransferLiteralFromDevice(
          stream.get(), *shaped_buffer));

  if (LayoutUtil::LayoutsInShapesEqual(return_shape, result_literal.shape())) {
    *result->mutable_literal() = result_literal.ToProto();
  } else {
    *result->mutable_literal() =
        result_literal.Relayout(return_shape).ToProto();
  }
  return Status::OK();
}

Status Service::TransferToServer(const TransferToServerRequest* arg,
                                 TransferToServerResponse* result) {
  TF_ASSIGN_OR_RETURN(Literal literal,
                      Literal::CreateFromProto(arg->literal()));
  const Shape& shape = literal.shape();

  std::vector<se::StreamExecutor*> replicas;
  if (arg->has_device_handle()) {
    TF_ASSIGN_OR_RETURN(replicas,
                        Replicas(*execute_backend_, arg->device_handle()));
  } else {
    TF_ASSIGN_OR_RETURN(
        replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle()));
  }

  // Allocate memory in each replica and transfer the data to all replicas.
  std::vector<ScopedShapedBuffer> replicated_buffers;
  for (se::StreamExecutor* executor : replicas) {
    TF_ASSIGN_OR_RETURN(
        ScopedShapedBuffer shaped_buffer,
        execute_backend_->transfer_manager()->AllocateScopedShapedBuffer(
            shape, execute_backend_->memory_allocator(),
            executor->device_ordinal()));
    TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor));
    TF_RETURN_IF_ERROR(
        execute_backend_->transfer_manager()->TransferLiteralToDevice(
            stream.get(), literal, shaped_buffer));
    replicated_buffers.emplace_back(std::move(shaped_buffer));
  }
  TF_ASSIGN_OR_RETURN(*result->mutable_data(),
                      allocation_tracker_.RegisterReplicatedBuffers(
                          std::move(replicated_buffers),
                          StrCat("TransferToServer literal of shape ",
                                 ShapeUtil::HumanString(shape))));

  return Status::OK();
}

Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
                                 TransferToInfeedResponse* result) {
  const int64_t replica_count = options_.number_of_replicas();
  if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
    return FailedPrecondition(
        "%s",
        StrCat("The replica_id=", arg->replica_id(),
               " on TransferToInfeedRequest not in range [0, replica_count=",
               replica_count, ")."));
  }

  se::StreamExecutor* executor;
  if (arg->has_device_handle()) {
    TF_ASSIGN_OR_RETURN(auto replicas,
                        Replicas(*execute_backend_, arg->device_handle()));
    executor = replicas[arg->replica_id()];
  } else {
    TF_ASSIGN_OR_RETURN(
        auto replicas,
        Replicas(*execute_backend_, SingleComputationDeviceHandle()));
    executor = replicas[arg->replica_id()];
  }

  TF_ASSIGN_OR_RETURN(Literal literal,
                      Literal::CreateFromProto(arg->literal()));
  return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor,
                                                                       literal);
}

Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
                                    TransferFromOutfeedResponse* result) {
  const int64_t replica_count = options_.number_of_replicas();
  if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
    return FailedPrecondition(
        "The replica_id=%d on TransferFromOutfeedRequest not in range [0, %d)",
        arg->replica_id(), replica_count);
  }

  se::StreamExecutor* executor;
  if (arg->has_device_handle()) {
    TF_ASSIGN_OR_RETURN(auto replicas,
                        Replicas(*execute_backend_, arg->device_handle()));
    executor = replicas[arg->replica_id()];
  } else {
    TF_ASSIGN_OR_RETURN(
        auto replicas,
        Replicas(*execute_backend_, SingleComputationDeviceHandle()));
    executor = replicas[arg->replica_id()];
  }

  auto literal = Literal::CreateFromShape(Shape(arg->shape_with_layout()));

  TF_RETURN_IF_ERROR(
      execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
          executor, &literal));
  *result->mutable_literal() = literal.ToProto();
  return Status::OK();
}

Status Service::ResetDevice(const ResetDeviceRequest* arg,
                            ResetDeviceResponse* result) {
  return execute_backend_->ResetDevices();
}

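// Evaluates a parameter-free computation on the host with HloEvaluator, after
// running the DynamicPadder and dynamic dimension inference so that dynamic
// shapes (e.g. the SliceToDynamic custom call handled below) can be evaluated.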
Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
                                     ComputeConstantResponse* result) {
  if (!arg->has_computation()) {
    return InvalidArgument("computations may not be empty");
  }
  if (!arg->computation().has_host_program_shape()) {
    return InvalidArgument("program shape may not be empty");
  }
  if (arg->computation().host_program_shape().parameters_size() != 0) {
    return InvalidArgument(
        "constant computation may not depend on any parameters.");
  }

  ProgramShape program_shape(arg->computation().host_program_shape());
  TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result()));
  absl::optional<Layout> output_layout;
  if (arg->has_output_layout()) {
    output_layout = Layout::CreateFromProto(arg->output_layout());
    TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(
        *output_layout, program_shape.result()));
  }

  HloModuleConfig config(program_shape);

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                      CreateModuleFromProto(arg->computation(), config));
  DynamicPadder dynamic_padder;
  TF_RETURN_IF_ERROR(dynamic_padder.Run(module.get()).status());

  TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference,
                      DynamicDimensionInference::Run(module.get()));

  HloEvaluator evaluator;
  evaluator.set_dynamic_dimension_inference(&dynamic_dimension_inference);
  evaluator.set_custom_call_handler(
      [](HloInstruction* custom_call,
         absl::Span<const Literal*> operands) -> StatusOr<Literal> {
        if (custom_call->custom_call_target() == "SliceToDynamic") {
          auto result = operands[0]->Clone();
          for (int64_t i = 0; i < result.shape().rank(); ++i) {
            result.SetDynamicSize(i, operands[1 + i]->Get<int32>({}));
          }
          return result.ToStatic();
        }
        return Unimplemented("Custom call %s is not supported: %s",
                             custom_call->custom_call_target(),
                             custom_call->ToString());
      });
  TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate(*module, {}));

  // The requested result layout has no effect on the Evaluator's output, so
  // explicitly relayout the result here.
  //
  // TODO(b/77824332): Make HloEvaluator take care of the re-layout.
  if (output_layout.has_value()) {
    result_literal = result_literal.Relayout(*output_layout);
  }
  *result->mutable_literal() = result_literal.ToProto();

  return Status::OK();
}

Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) {
  TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer,
                      allocation_tracker_.ResolveForReplica(arg->data(), 0));
  *result->mutable_shape() = buffer->on_host_shape().ToProto();
  return Status::OK();
}

Status Service::GetComputationGraphStats(
    const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) {
  if (!arg->has_computation()) {
    return InvalidArgument("Computations may not be empty.");
  }
  if (!arg->computation().has_host_program_shape()) {
    return InvalidArgument("Program shape may not be empty.");
  }

  HloModuleConfig config(ProgramShape{arg->computation().host_program_shape()});
  config.set_debug_options(arg->debug_options());
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                      CreateModuleFromProto(arg->computation(), config));
  DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);

  // Run HLO analysis to get the computation statistics.
  HloCostAnalysis analysis(
      execute_backend_->compiler()->ShapeSizeBytesFunction());

  TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis));

  ComputationStats stats;
  stats.set_flop_count(analysis.flop_count());
  stats.set_transcendental_count(analysis.transcendental_count());
  *result->mutable_stats() = stats;
  return Status::OK();
}

DeviceHandle Service::SingleComputationDeviceHandle() const {
  DeviceHandle device_handle;
  device_handle.set_handle(0);
  device_handle.set_device_count(1);
  return device_handle;
}

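// Returns the stream executors that back each replica of the given device
// handle, as assigned by the backend's computation placer.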
StatusOr<std::vector<se::StreamExecutor*>> Service::Replicas(
    const Backend& backend, const DeviceHandle& device_handle) const {
  std::vector<se::StreamExecutor*> replicas;
  for (int replica = 0; replica < options_.number_of_replicas(); ++replica) {
    // From the computation placer, find out the device ids of the replicas for
    // the given device handle.
    TF_ASSIGN_OR_RETURN(
        int device_ordinal,
        backend.computation_placer()->DeviceId(replica, device_handle.handle(),
                                               options_.number_of_replicas(),
                                               device_handle.device_count()));
    TF_ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal));
    replicas.push_back(executor);
  }
  return replicas;
}

}  // namespace xla