/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"

#include <algorithm>
#include <cstdint>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/service/xla_debug_info_manager.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/profiler/lib/scoped_annotation.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/stream_executor/platform.h"

#if XLA_ENABLE_XLIR
#include "tensorflow/compiler/xla/mlir/transforms/runtime/compilation_pipeline.h"
#include "tensorflow/compiler/xla/runtime/diagnostics.h"
#include "tensorflow/compiler/xla/runtime/executable.h"
#include "tensorflow/compiler/xla/runtime/jit_executable.h"
#include "tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.h"
#include "tfrt/init_tfrt_dialects.h"  // from @tf_runtime
#endif  // XLA_ENABLE_XLIR

namespace xla {
namespace gpu {

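// Returns true if the JitRt (XLA runtime) based executable is enabled for the
// given module. Crashes if the flag is set but the JitRt backend was not
// compiled in.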
bool IsJitRtExecutableEnabled(const HloModuleConfig& config) {
#if !XLA_ENABLE_XLIR
  CHECK(!config.debug_options().xla_gpu_jitrt_executable())
      << "Failed to enable JitRt backend, because it was not compiled.";
#endif  // !XLA_ENABLE_XLIR
  return config.debug_options().xla_gpu_jitrt_executable();
}

namespace {

using ::tensorflow::profiler::ScopedAnnotation;

bool NeedsAsyncCommsStream(Thunk& thunk) {
  switch (thunk.kind()) {
    case Thunk::Kind::kNcclAllReduceStart:
    case Thunk::Kind::kNcclAllReduceDone:
      return true;
    default:
      return false;
  }
}

}  // namespace

#if XLA_ENABLE_XLIR

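// JitRtExecutable wraps an XLA runtime executable, either JIT-compiled from an
// MLIR program or loaded from an AOT-compiled object file, together with the
// per-executable state needed to run it: buffer sizes, debug options, and
// caches for kernels, GEMM configs, and collective operations.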
class GpuExecutable::JitRtExecutable {
 public:
  static StatusOr<JitRtExecutable*> Create(OwnedJitRtProgram program) {
    // Options for the default JitRt compilation pipeline.
    runtime::CompilationPipelineOptions copts;

    // Populate mapping from XLA (SE) enums/structs type id to symbol names.
    copts.populate_type_id_names = PopulateXlaTypeIdNames;

    // For passing LMHLO attributes as XLA (SE) enums/structs to custom calls.
    copts.populate_attr_encodings = PopulateLmhloToXlaAttrEncoding;

    // Options for constructing XLA runtime JitExecutable.
    runtime::JitExecutable::Options opts;
    opts.specialization = runtime::JitExecutable::Specialization::kDisabled;
    opts.compiler.register_dialects = [](mlir::DialectRegistry& registry) {
      runtime::RegisterDefaultXlaRuntimeDialects(registry);
      // For the encoding of attributes to custom calls.
      registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
    };

    // Register XLA Gpu runtime custom calls with the linker.
    opts.compiler.symbols_binding = runtime::ToSymbolsBinding(
        JitRtGpuCustomCalls(), PopulateXlaTypeIdNames);

    // We just use the default compilation pipeline provided by the XLA
    // runtime. Alternatively, instead of having a separate JitRtProgram (LMHLO
    // lowered to canonical dialects), we could assemble a pipeline that
    // compiles starting from the LMHLO dialect. However, this intermediate
    // step helps with debugging by materializing IR with XLA runtime custom
    // calls.
    opts.compiler.create_compilation_pipeline = [copts](mlir::PassManager& pm) {
      runtime::CreateDefaultXlaRuntimeCompilationPipeline(pm, copts);
    };

    // TODO(b/241296710): LLVM optimizations interact badly with the memory
    // loads and stores pattern generated in very large XLA programs, and can
    // take minutes to run. Currently we do not expect any expensive code
    // running on the host, so we can safely disable optimization passes.
    opts.compiler.jit_code_opt_level = llvm::CodeGenOpt::None;

    // Instantiate new JitExecutable from the MLIR source.
    auto jit_executable = runtime::JitExecutable::Instantiate(
        program->module, program->entry_point, opts);
    if (auto err = jit_executable.takeError())
      return InternalError("Failed to compile JitRt program: %s",
                           tfrt::StrCat(err));

    // Pass ownership to the GpuExecutable.
    return new JitRtExecutable(
        std::move(program->buffer_sizes),
        std::make_unique<runtime::JitExecutable>(std::move(*jit_executable)),
        std::move(program->debug_options));
  }

  // Create JitRtExecutable from the AOT compiled binary.
  static StatusOr<JitRtExecutable*> Create(
      absl::Span<const int64_t> buffer_sizes, runtime::Executable executable,
      DebugOptions debug_options) {
    // Pass ownership to the GpuExecutable.
    return new JitRtExecutable(
        std::vector<int64_t>(buffer_sizes.begin(), buffer_sizes.end()),
        std::make_unique<runtime::Executable>(std::move(executable)),
        std::move(debug_options));
  }

  JitRtKernelsCache& kernels_cache() { return kernels_cache_; }
  JitRtGemmConfigCache& gemm_configs_cache() { return gemm_configs_cache_; }
  JitRtCollectiveSupport& collectives() { return collectives_; }

  runtime::Executable& executable() {
    // Exactly one kind of `Executable` should be available at run time.
    assert((default_executable_ || executable_) &&
           !(default_executable_ && executable_));
    return default_executable_ ? *default_executable_ : *executable_;
  }

  // We pass a pointer to the buffer size to the compiled function, so we
  // return a reference to a stable memory location.
  const int64_t& buffer_size(size_t offset) const {
    return buffer_sizes_[offset];
  }

  const DebugOptions& debug_options() const { return debug_options_; }

 private:
  JitRtExecutable(std::vector<int64_t> buffer_sizes,
                  std::unique_ptr<runtime::JitExecutable> jit_executable,
                  DebugOptions debug_options)
      : buffer_sizes_(std::move(buffer_sizes)),
        jit_executable_(std::move(jit_executable)),
        default_executable_(&jit_executable_->DefaultExecutable().get()),
        debug_options_(std::move(debug_options)) {}

  JitRtExecutable(std::vector<int64_t> buffer_sizes,
                  std::unique_ptr<runtime::Executable> aot_executable,
                  DebugOptions debug_options)
      : buffer_sizes_(std::move(buffer_sizes)),
        aot_executable_(std::move(aot_executable)),
        executable_(aot_executable_.get()),
        debug_options_(std::move(debug_options)) {}

  std::vector<int64_t> buffer_sizes_;

  // In JIT compilation mode the `JitExecutable` owns the default `Executable`.
  std::unique_ptr<runtime::JitExecutable> jit_executable_;
  runtime::Executable* default_executable_ = nullptr;

  // In AOT compilation mode we directly own the `Executable`.
  std::unique_ptr<runtime::Executable> aot_executable_;
  runtime::Executable* executable_ = nullptr;

  DebugOptions debug_options_;

  // Keep a cache of kernels instantiated by this executable.
  JitRtKernelsCache kernels_cache_;

  // Keep a cache of gemm configs for all gemm operations in the program.
  JitRtGemmConfigCache gemm_configs_cache_;

  // Support for running collective operations.
  JitRtCollectiveSupport collectives_;
};
#endif  // XLA_ENABLE_XLIR

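// Creates a GpuExecutable from either an owned thunk sequence (the default
// thunk-based runtime) or, when built with XLA_ENABLE_XLIR, an owned JitRt
// program that is JIT-compiled into an XLA runtime executable.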
StatusOr<std::unique_ptr<GpuExecutable>> GpuExecutable::Create(Params params) {
  auto executable = std::move(params.executable);
  std::unique_ptr<GpuExecutable> result(new GpuExecutable(std::move(params)));

  if (std::holds_alternative<OwnedThunkSequence>(executable)) {
    result->thunks_ = std::move(std::get<OwnedThunkSequence>(executable));
    return result;
  }

#if XLA_ENABLE_XLIR
  if (std::holds_alternative<OwnedJitRtProgram>(executable)) {
    auto& program = std::get<OwnedJitRtProgram>(executable);
    TF_ASSIGN_OR_RETURN(result->jitrt_executable_,
                        JitRtExecutable::Create(std::move(program)));
    return result;
  }
#endif  // XLA_ENABLE_XLIR

  return InternalError("No XLA gpu executable was provided");
}

// Implementation note: HLO profiling is always enabled for GPU executables,
// since we can use timers around thunks.
GpuExecutable::GpuExecutable(GpuExecutable::Params params)
    : Executable(std::move(params.debug_module)),
      text_(std::move(params.asm_text)),
      binary_(std::move(params.binary)),
      gpu_version_(params.gpu_version),
      entry_func_attrs_(params.entry_func_attrs),
      module_name_(params.module_name),
      output_shape_(params.output_shape),
      allocations_(std::move(params.allocations)),
      debug_buffer_assignment_(std::move(params.debug_buffer_assignment)),
      verbose_buffer_assignment_string_dumper_(
          params.verbose_buffer_assignment_string_dumper),
      constants_(std::move(params.constants)),
      output_info_(std::move(params.output_info)) {
  if (has_module()) {
    XlaDebugInfoManager::Get()->RegisterModule(
        module().unique_id(), shared_module(), debug_buffer_assignment_);
  }
}

GpuExecutable::~GpuExecutable() {
  if (has_module()) {
    XlaDebugInfoManager::Get()->UnregisterModule(module().unique_id());
  }

  {
    // We could have issued host->device mem copies in ResolveConstantGlobals.
    // Wait for those to finish so that we can safely deallocate the backing
    // HLO module.
    //
    // We need the host->device memcpies to finish because they are
    // concurrently reading memory (xla::Literal's) owned by the HLO module.
    absl::MutexLock lock(&module_handle_mutex_);
    for (const auto& pair : module_globals_) {
      CHECK(pair.first->SynchronizeAllActivity());
    }
  }

#if XLA_ENABLE_XLIR
  delete jitrt_executable_;
#endif
}

Status GpuExecutable::CheckCompatibilityWithServiceExecutableRunOptions(
    const ServiceExecutableRunOptions* run_options) {
  se::Stream* main_stream = run_options->stream();

  stream_executor::PlatformKind platform_kind =
      main_stream->parent()->platform_kind();
  if (platform_kind == stream_executor::PlatformKind::kROCm) {
    auto cc = main_stream->GetRocmComputeCapability();
    std::string stream_arch = cc.gcn_arch_name();
    std::string gpu_exec_arch =
        std::get<se::RocmComputeCapability>(gpu_version_).gcn_arch_name();
    TF_RET_CHECK(stream_arch == gpu_exec_arch)
        << "AMDGPU GCN ISA version mismatch; expected {" << gpu_exec_arch
        << "}, but was {" << stream_arch << "}";
  } else if (platform_kind == stream_executor::PlatformKind::kCuda) {
    GpuVersion cc = main_stream->GetCudaComputeCapability();
    TF_RET_CHECK(std::get<se::CudaComputeCapability>(cc) ==
                 std::get<se::CudaComputeCapability>(gpu_version_))
        << "Compute capability mismatch; expected {"
        << std::get<se::CudaComputeCapability>(gpu_version_).ToString()
        << "}, but was {" << std::get<se::CudaComputeCapability>(cc).ToString()
        << "}";
  } else {
    return InternalError("Unknown platform: %d", platform_kind);
  }

  return OkStatus();
}

namespace {

Status MaybeSyncAndProfile(const ServiceExecutableRunOptions* run_options,
                           uint64_t start_micros, se::Stream* stream_to_sync);

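// Runs every thunk in `thunk_sequence` on the main stream, borrowing an extra
// stream for asynchronous NCCL collectives when one is available, and then
// optionally blocks the host until all launched work has completed.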
Status ExecuteThunks(const std::string& module_name,
                     const ThunkSequence& thunk_sequence,
                     const ServiceExecutableRunOptions* run_options,
                     const BufferAllocations& buffer_allocations,
                     bool block_host_until_done) {
  se::Stream* main_stream = run_options->stream();
  se::StreamExecutor* executor = main_stream->parent();

  StatusOr<StreamPool::Ptr> async_comms_stream =
      run_options->BorrowStream(executor->device_ordinal());

  uint64_t start_micros = tensorflow::Env::Default()->NowMicros();

  tensorflow::profiler::TraceMe hlo_module_activity(
      [&] { return absl::StrCat(module_name, ":XLA GPU module"); },
      tensorflow::profiler::TraceMeLevel::kInfo);

  for (const std::unique_ptr<Thunk>& thunk : thunk_sequence) {
    // Annotate execution of this op if tracing was enabled when we started
    // running this module.  If tracing is enabled *while* we're running the
    // module, we won't get any data, but that's probably an OK trade-off.
    ScopedAnnotation annotation([&] { return thunk->profile_annotation(); });
    VLOG(2) << "Executing the thunk for " << thunk->profile_annotation();
    TF_RET_CHECK(async_comms_stream.ok() || !NeedsAsyncCommsStream(*thunk))
        << "`run_options` must have a stream borrower for async thunks.";

    Thunk::ExecuteParams thunk_params{
        *run_options, buffer_allocations, main_stream,
        async_comms_stream.ok() ? async_comms_stream->get() : nullptr};
    TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(thunk_params));
  }
  return MaybeSyncAndProfile(run_options, start_micros,
                             block_host_until_done ? main_stream : nullptr);
}

Status MaybeSyncAndProfile(const ServiceExecutableRunOptions* run_options,
                           uint64_t start_micros,
                           se::Stream* stream_to_sync = nullptr) {
  // Make sure kernels are completed before deallocating temporary buffers or
  // the profiler state.
  // TODO(b/30100571): we could potentially postpone deallocating the temp
  // buffers until a different computation is executed.
  if (stream_to_sync) {
    Status block_status = stream_to_sync->BlockHostUntilDone();
    if (!block_status.ok()) {
      return InternalError(
          "Failed to complete all kernels launched on stream %p: %s",
          stream_to_sync, block_status.error_message());
    }
  }

  // FinishExecution() blocks until main_stream has completed if profiling is
  // enabled; we therefore do not need to defer profile collection onto a
  // stream.
  uint64_t end_micros = tensorflow::Env::Default()->NowMicros();

  if (run_options->run_options().execution_profile()) {
    ExecutionProfile* profile = run_options->run_options().execution_profile();
    const double nanoseconds = (end_micros - start_micros) * 1000.0;
    profile->set_compute_time_ns(std::max(nanoseconds, 1.0));
  }

  return OkStatus();
}

}  // namespace

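// Loads the compiled GPU module (cubin/PTX) on the stream's executor if it has
// not been loaded yet, and resolves the device addresses of all constants,
// allocating and initializing any constant that is not defined in the module.
// The result is cached per StreamExecutor.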
StatusOr<const GpuExecutable::BufferAllocToDeviceMemoryMap*>
GpuExecutable::ResolveConstantGlobals(se::Stream* stream) {
  se::StreamExecutor* executor = stream->parent();

  absl::MutexLock lock(&module_handle_mutex_);
  auto it = module_globals_.find(executor);
  if (it != module_globals_.end()) {
    return &it->second;
  }

  se::MultiModuleLoaderSpec module_spec;
  if (!binary().empty()) {
    module_spec.AddCudaCubinInMemory(binary());
  }
  module_spec.AddCudaPtxInMemory(text().c_str());

  absl::flat_hash_map<int64_t, se::DeviceMemoryBase> globals;
  se::ModuleHandle module_handle;
  // The CUDA driver isn't able to load empty PTX. It's okay if we skip loading
  // in this case; if the module isn't loaded, all symbol lookups will fail,
  // just as they should for an empty module.
  if (!(executor->platform_kind() == se::PlatformKind::kCuda &&
        module_spec.cuda_ptx_in_memory() == nullptr)) {
    TF_RETURN_IF_ERROR(executor->LoadModule(module_spec, &module_handle));
  }

  for (const ConstantInfo& info : constants_) {
    StatusOr<stream_executor::DeviceMemoryBase> global_status;
    if (static_cast<bool>(module_handle)) {
      global_status =
          executor->GetUntypedSymbol(info.symbol_name, module_handle);
    }

    se::DeviceMemoryBase global;
    if (static_cast<bool>(module_handle) && global_status.ok()) {
      // The constant was defined in the PTX and has been allocated by the CUDA
      // driver.
      global = *global_status;
      VLOG(3) << "Resolved global " << info.symbol_name << " to "
              << global.opaque();

      if (!info.content.empty()) {
        // This means the constant did not have an initializer in the PTX and
        // therefore must be initialized by XLA here.
        stream->ThenMemcpy(&global, info.content.data(), info.content.size());
      }
    } else {
      // The constant was not defined in the PTX and therefore must be both
      // allocated and initialized by XLA here.
      CHECK(!info.content.empty());

      TF_ASSIGN_OR_RETURN(
          auto shared, executor->CreateOrShareConstant(stream, info.content));
      global = *shared;
      VLOG(3) << "Allocated (or shared) global " << info.symbol_name << " at "
              << global.opaque();
      // XLA will continue to own this global at least until this executable is
      // destroyed (longer if another, longer-lived executable shares the same
      // constant).
      shared_constants_.push_back(std::move(shared));
    }

    if (info.allocation_index != -1) {
      InsertOrDie(&globals, info.allocation_index, global);
    }
  }

  module_handles_.emplace(executor,
                          se::ScopedModuleHandle(executor, module_handle));
  return &module_globals_.emplace(executor, std::move(globals)).first->second;
}

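// Resolves the device memory that backs `allocation` for this execution:
// entry parameters come from the caller-provided arguments, constants from the
// resolved globals, and temp/output buffers are freshly allocated here.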
StatusOr<se::DeviceMemoryBase> GpuExecutable::BufferForAllocation(
    VariantArguments arguments,
    const GpuExecutable::BufferAllocToDeviceMemoryMap* globals,
    const BufferAllocation& allocation,
    se::DeviceMemoryAllocator* const memory_allocator, int device_ordinal,
    int64_t arg_idx) {
  if (allocation.is_thread_local()) {
    return se::DeviceMemoryBase{};
  } else if (allocation.is_entry_computation_parameter()) {
    int64_t param_no = allocation.parameter_number();
    se::DeviceMemoryBase registered_buffer = [&] {
      if (auto unowned_shapedbuffers =
              std::get_if<absl::Span<const ShapedBuffer* const>>(&arguments)) {
        return (*unowned_shapedbuffers)[param_no]->buffers().element(
            allocation.param_shape_index());
      } else {
        return std::get<absl::Span<ExecutionInput>>(arguments)[param_no]
            .Buffer(allocation.param_shape_index())
            .AsDeviceMemoryBase();
      }
    }();
    if (registered_buffer.is_null() && registered_buffer.size() > 0) {
      return FailedPrecondition(
          "Cannot run XLA computation because pointer to (sub-)buffer at "
          "index %s of parameter %d was null.  All pointers to "
          "(sub-)buffers must not be null, unless the (sub-)buffer has "
          "zero elements.",
          allocation.param_shape_index().ToString(), param_no);
    }
    return registered_buffer;
  } else if (allocation.is_constant()) {
    auto it = globals->find(arg_idx);
    if (it == globals->end()) {
      return se::DeviceMemoryBase();
    }
    return it->second;
  } else {
    // Allocate each allocation that might escape, or is the temp buffer.
    CHECK(allocation.maybe_live_out() || allocation.IsPreallocatedTempBuffer());
    const int64_t buffer_size = allocation.size();
    se::DeviceMemoryBase buffer_address;
    if (buffer_size > 0) {
      StatusOr<se::OwningDeviceMemory> buffer =
          memory_allocator->Allocate(device_ordinal, buffer_size);
      if (!buffer.ok()) {
        return ResourceExhausted("%s\n%s\n", buffer.status().error_message(),
                                 verbose_buffer_assignment_string_dumper_());
      }
      buffer_address = buffer->Release();
    }
    return buffer_address;
  }
}

static Status CheckAlignment(const BufferAllocation& allocation,
                             se::DeviceMemoryBase buffer, int arg_idx) {
  const int64_t expected_alignment = [&] {
    if (allocation.is_entry_computation_parameter()) {
      return kEntryParameterAlignBytes;
    } else if (allocation.is_constant()) {
      return kConstantBufferAlignBytes;
    } else {
      return kXlaAllocatedBufferAlignBytes;
    }
  }();
  if (!buffer.is_null() &&
      reinterpret_cast<uintptr_t>(buffer.opaque()) % expected_alignment != 0) {
    return InternalError(
        "Address of buffer %d must be a multiple of %x, but "
        "was %p",
        arg_idx, expected_alignment, buffer.opaque());
  }
  return OkStatus();
}

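// Builds the BufferAllocations for one execution by resolving a device address
// for every BufferAllocation and verifying its alignment.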
StatusOr<BufferAllocations> GpuExecutable::GenerateBufferAllocations(
    VariantArguments arguments,
    const GpuExecutable::BufferAllocToDeviceMemoryMap* globals,
    se::DeviceMemoryAllocator* const memory_allocator, int device_ordinal) {
  tensorflow::profiler::TraceMe hlo_module_activity(
      [&] { return std::string("Build buffer allocations"); },
      tensorflow::profiler::TraceMeLevel::kInfo);

  const int64_t num_buffers = allocations_.size();
  std::vector<se::DeviceMemoryBase> buffers;
  buffers.reserve(num_buffers);
  for (int64_t i = 0; i < num_buffers; ++i) {
    const BufferAllocation& allocation = allocations_[i];
    TF_ASSIGN_OR_RETURN(
        se::DeviceMemoryBase buffer,
        BufferForAllocation(arguments, globals, allocation, memory_allocator,
                            device_ordinal, i));
    buffers.push_back(buffer);
    TF_RETURN_IF_ERROR(CheckAlignment(allocation, buffer, i));
  }
  return {{buffers, device_ordinal, memory_allocator}};
}

StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
    const ServiceExecutableRunOptions* run_options,
    std::vector<ExecutionInput> arguments,
    HloExecutionProfile* hlo_execution_profile) {
  return ExecuteAsyncOnStreamImpl(run_options, absl::MakeSpan(arguments));
}

StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteAsyncOnStream(
    const ServiceExecutableRunOptions* run_options,
    absl::Span<const ShapedBuffer* const> arguments,
    HloExecutionProfile* hlo_execution_profile) {
  TF_ASSIGN_OR_RETURN(ExecutionOutput out,
                      ExecuteAsyncOnStreamImpl(run_options, arguments));
  return out.ConsumeResult();
}

#if XLA_ENABLE_XLIR
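// Executes a JitRt (XLA runtime) executable: packs the buffer allocations as
// 1d memref arguments, wires up custom-call support (kernel, GEMM config and
// collective caches) and diagnostics, runs the compiled entry function, and
// then optionally blocks the host until the stream is done.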
static Status ExecuteJitRt(const std::string& module_name,
                           GpuExecutable::JitRtExecutable* jitrt_executable,
                           const ServiceExecutableRunOptions* run_options,
                           const BufferAllocations& buffer_allocations,
                           size_t num_allocations, bool block_host_until_done) {
  uint64_t start_micros = tensorflow::Env::Default()->NowMicros();

  tensorflow::profiler::TraceMe hlo_module_activity(
      [&] { return absl::StrCat(module_name, ":XLA GPU module"); },
      tensorflow::profiler::TraceMeLevel::kInfo);

  ScopedAnnotation annotation(
      []() -> std::string { return "JitRtExecutable"; });

  // TODO(ezhulenev): Here we rely on implementation details of passing memrefs
  // to the compiled kernel. We should have a nicer API to do this, without
  // creating a vector of temporary MemrefDesc for passing operands.

  // Pack buffer allocations as executable arguments. It is guaranteed that the
  // compiled function will make a copy of all arguments and will write all
  // results after the call to `Execute` completes, so it is safe to keep the
  // call frame on the stack.
  runtime::Executable::CallFrame call_frame;

  // Each buffer allocation is passed as a 1d memref to the compiled kernel:
  //   {basePtr, dataPtr, offset, [sizes, ...], [strides, ...]}
  size_t num_args_ptrs = 1 + num_allocations * 5;
  call_frame.args.resize_for_overwrite(num_args_ptrs);

  // Pass pointers to these constants as a memref offset and stride.
  int64_t zero = 0;
  int64_t one = 1;
  void* offset = &zero;
  void* stride = &one;

  // Add a placeholder for the kernel context as the first argument.
  call_frame.args[0] = nullptr;

  // Storage for data pointers.
  llvm::SmallVector<void*, 16> ptrs;
  ptrs.resize_for_overwrite(num_allocations);

  // Initialize arguments for the buffer operands.
  for (unsigned i = 0; i < num_allocations; ++i) {
    void* data = &(ptrs[i] = buffer_allocations.GetDeviceAddress(i).opaque());
    void* size = const_cast<int64_t*>(&jitrt_executable->buffer_size(i));
    unsigned idx = 1 + i * 5;
    call_frame.args[idx + 0] = data;
    call_frame.args[idx + 1] = data;
    call_frame.args[idx + 2] = offset;
    call_frame.args[idx + 3] = size;
    call_frame.args[idx + 4] = stride;
  }

  // JitRt executables do not return any values.
  runtime::NoResultConverter converter;

  // Prepare options for executing JitRt program.
  runtime::Executable::ExecuteOpts opts;

  // We don't expect to see any async tasks in the JitRt executable.
  opts.async_task_runner =
      reinterpret_cast<runtime::AsyncTaskRunner*>(0XDEADBEEF);

  // Get the async communications stream for async collectives.
  int device_ordinal = run_options->stream()->parent()->device_ordinal();
  StatusOr<StreamPool::Ptr> async_comms_stream =
      run_options->BorrowStream(device_ordinal);

  // Async collective support is instantiated for each GPU executable run, so
  // that concurrent executions can run independently using a separate set of
  // events for communication.
  JitRtAsyncCollectiveSupport async_collectives(
      async_comms_stream.ok() ? async_comms_stream->get() : nullptr);

  // Pass auxiliary data to the custom call handlers.
  runtime::CustomCall::UserData user_data;
  user_data.insert_all(
      run_options, &jitrt_executable->debug_options(),
      &jitrt_executable->kernels_cache(),
      &jitrt_executable->gemm_configs_cache(), &jitrt_executable->collectives(),
      async_collectives.async_comm_stream() ? &async_collectives : nullptr);
  opts.custom_call_data = &user_data;

  // Collect all emitted diagnostic messages.
  runtime::DiagnosticEngine diagnostic_engine;
  std::string diagnostic;
  diagnostic_engine.AddHandler([&](runtime::Diagnostic& d) {
    llvm::raw_string_ostream(diagnostic) << d.str();
    return mlir::success();
  });

  opts.diagnostic_engine = &diagnostic_engine;

  // Execute with the prepared call frame.
  runtime::Executable& executable = jitrt_executable->executable();
  executable.Execute(call_frame, opts);

  if (auto err = executable.ReturnResults(converter, &call_frame)) {
    return InternalError(
        "Failed to execute JitRt executable: %s.",
        tfrt::StrCat(err,
                     diagnostic.empty() ? "" : tfrt::StrCat(": ", diagnostic)));
  }

  return MaybeSyncAndProfile(
      run_options, start_micros,
      block_host_until_done ? run_options->stream() : nullptr);
}
#endif  // XLA_ENABLE_XLIR

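// Shared implementation behind both ExecuteAsyncOnStream overloads: resolves
// constant globals, builds buffer allocations, handles output aliasing and
// buffer donation, runs the thunks or the JitRt executable, and finally frees
// temporary and argument allocations.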
StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStreamImpl(
    const ServiceExecutableRunOptions* run_options,
    VariantArguments arguments) {
  XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
      "GpuExecutable::ExecuteAsyncOnStreamImpl(", module_name_, ")"));
  se::DeviceMemoryAllocator* const memory_allocator = run_options->allocator();
  // Force synchronous execution if the allocator requires it.
  const bool block_host_until_done =
      !memory_allocator->AllowsAsynchronousDeallocation();

  se::StreamExecutor* executor = run_options->stream()->parent();

  // Lock the GPU with a shared lock so that we don't interfere with autotuning
  // that may be running during JIT compilation while allowing multiple XLA
  // computations to use the same GPU simultaneously.
  absl::ReaderMutexLock gpu_lock(&GetGpuMutex(executor));

  const GpuExecutable::BufferAllocToDeviceMemoryMap* globals;
  {
    tensorflow::profiler::TraceMe hlo_module_activity(
        [&] { return std::string("Resolve constant globals"); },
        tensorflow::profiler::TraceMeLevel::kInfo);

    TF_ASSIGN_OR_RETURN(globals, ResolveConstantGlobals(run_options->stream()));
  }

  auto device_ordinal = executor->device_ordinal();
  ExecutionOutput result(/*on_device_shape=*/output_shape_, memory_allocator,
                         device_ordinal);

  TF_ASSIGN_OR_RETURN(
      BufferAllocations buffer_allocations,
      GenerateBufferAllocations(arguments, globals, memory_allocator,
                                device_ordinal));
  VLOG(2) << buffer_allocations.ToString();
  std::set<se::DeviceMemoryBase> buffers_in_result;

  const bool is_entire_tuple_contents_aliased = [&] {
    for (auto& p : result.MutableResult()->buffers().leaves()) {
      if (!output_info_.contains(p.first)) {
        continue;
      }
      const OutputInfo& output_info = output_info_.at(p.first);
      if (!output_info.alias_config.has_value()) {
        return false;
      }
    }
    return true;
  }();

  for (auto& p : result.MutableResult()->buffers()) {
    const ShapeIndex& index = p.first;
    if (!output_info_.contains(index)) {
      continue;
    }
    const OutputInfo& output_info = output_info_.at(index);
    const BufferAllocation* allocation =
        &allocations_[output_info.allocation_index];
    se::DeviceMemoryBase& result_buffer = p.second;

    VLOG(4) << "Looking at: allocation " << output_info.allocation_index
            << " @ index: " << index.ToString();

    if (output_info.alias_config) {
      MaybeOwningDeviceMemory* maybe_owning_memory =
          [&]() -> xla::MaybeOwningDeviceMemory* {
        // ShapedBuffer arguments are never owned buffers.
        if (auto* unowned_shapedbuffers =
                std::get_if<absl::Span<const ShapedBuffer* const>>(
                    &arguments)) {
          return nullptr;
        } else {
          auto unowned_execution_input =
              std::get<absl::Span<ExecutionInput>>(arguments);
          ExecutionInput& input =
              unowned_execution_input[allocation->parameter_number()];
          return input.MutableBuffer(allocation->param_shape_index());
        }
      }();
      if (output_info.alias_config->must_alias() && maybe_owning_memory &&
          !maybe_owning_memory->HasOwnership()) {
        return InvalidArgument(
            "An input was configured to be must-alias at "
            "compile time but not donated at runtime: allocation %d",
            output_info.allocation_index);
      }
      if (maybe_owning_memory && maybe_owning_memory->HasOwnership()) {
        std::optional<tensorflow::se::OwningDeviceMemory> owning =
            maybe_owning_memory->Release();
        // If the caller passes the ownership of the device memory, reuse it
        // as the output buffer. It is up to the caller whether or not to
        // donate a buffer; the aliasing information describes which buffers
        // may alias, not buffers that must alias.
        se::DeviceMemoryBase argument_buffer = owning->Release();
        *maybe_owning_memory = argument_buffer;
        result_buffer = argument_buffer;
        // The caller is giving us the input buffer, but in case of an error
        // from the execute call, we should not release it, as it contains
        // valid data (for example, it is a parameter which the user wants us
        // to alias, in a gradient update computation). So we store the index
        // into the result in the aliased vector, which will be fed to the
        // ExecutionOutput, which will use the indices to drop the addresses
        // from its own ScopedShapedBuffer result, if the ExecutionOutput is
        // not committed.
        result.AddAliasedIndex(index);
      } else if (!output_info.passthrough &&
                 !ShapeUtil::GetSubshape(output_shape_, index).IsTuple()) {
        // The guard above is there to avoid inserting copy-protection when
        // aliasing pass-through params, as we do not need to write into the
        // output buffer.
        VLOG(3) << "Using copy-protection: aliasing is specified, but the "
                   "buffer is not donated; allocating a fresh buffer";
        int64_t allocation_size =
            ShapeUtil::ByteSizeOf(ShapeUtil::GetSubshape(output_shape_, index));
        StatusOr<se::OwningDeviceMemory> allocated_buffer =
            memory_allocator->Allocate(device_ordinal, allocation_size);
        if (!allocated_buffer.ok()) {
          return ResourceExhausted("%s\n%s\n",
                                   allocated_buffer.status().error_message(),
                                   verbose_buffer_assignment_string_dumper_());
        }
        result_buffer = allocated_buffer->Release();
        se::DeviceMemoryBase& aliased_buffer =
            buffer_allocations.GetMutableDeviceAddress(
                output_info.allocation_index);
        CHECK_EQ(aliased_buffer.size(), result_buffer.size());
        run_options->stream()->ThenMemcpyD2D(&result_buffer, aliased_buffer,
                                             aliased_buffer.size());
        aliased_buffer = result_buffer;
      }
    }

    if (result_buffer.is_null()) {
      // The source instruction should have a non-parameter buffer
      // assigned.
      result_buffer =
          buffer_allocations.GetDeviceAddress(output_info.allocation_index);

      // If the entire tuple contents is aliased, the copy insertion will *not*
      // materialize a new tuple, so we mark it as aliased as well.
      if (is_entire_tuple_contents_aliased) {
        result.AddAliasedIndex(index);
      }
    }
    buffers_in_result.insert(result_buffer);
  }

  TF_RETURN_IF_ERROR(ExecuteThunksOrJitRt(run_options, buffer_allocations,
                                          block_host_until_done));

  // Free all temporary allocations.
  TF_RETURN_IF_ERROR(
      buffer_allocations.TearDown(buffers_in_result, allocations_));

  // Free allocations for arguments.
  if (auto args = std::get_if<absl::Span<ExecutionInput>>(&arguments)) {
    MarkToBeReleasedArguments(*args, result);
  }
  return std::move(result);
}

Status GpuExecutable::ExecuteThunksOrJitRt(
    const ServiceExecutableRunOptions* run_options,
    const BufferAllocations& buffer_allocations, bool block_host_until_done) {
  TF_RETURN_IF_ERROR(
      CheckCompatibilityWithServiceExecutableRunOptions(run_options));

  if (thunks_) {
    se::StreamExecutor* executor = run_options->stream()->parent();
    for (const std::unique_ptr<Thunk>& thunk : *thunks_) {
      TF_RETURN_IF_ERROR(thunk->Initialize(*this, executor));
    }
    return ExecuteThunks(module_name_, *thunks_, run_options,
                         buffer_allocations, block_host_until_done);
  }

#if XLA_ENABLE_XLIR
  if (jitrt_executable_) {
    return ExecuteJitRt(module_name_, jitrt_executable_, run_options,
                        buffer_allocations, allocations_.size(),
                        block_host_until_done);
  }
#endif  // XLA_ENABLE_XLIR

  return FailedPrecondition("Expected XLA gpu executable is not supplied.");
}

int64_t GpuExecutable::SizeOfGeneratedCodeInBytes() const {
  // Non-empty PTX but empty cubin: compilation must have failed, return
  // "unknown".
  if (binary().empty() && !text_.empty()) {
    return -1;
  }
  int64_t size = binary().size();
  for (BufferAllocation::Index i = 0; i < allocations_.size(); ++i) {
    const BufferAllocation& allocation = allocations_[i];
    if (allocation.is_constant()) {
      size += allocation.size();
    }
  }
  return size;
}

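// Reconstructs buffer allocations, output info, and the result shape from the
// LMHLO entry function's argument attributes (lmhlo.params,
// lmhlo.param_shape_index, lmhlo.constant_name, lmhlo.output_index, ...) and
// the given buffer sizes.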
Status GpuExecutable::SetUpMlirAllocation(
    mlir::func::FuncOp func, llvm::ArrayRef<int64_t> buffer_sizes,
    std::vector<BufferAllocation>* allocations,
    absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>* output_info,
    Shape* output_shape, int buffer_param_offset) {
  for (int i = 0; i < buffer_sizes.size(); i++) {
    allocations->emplace_back(i, buffer_sizes[i], 0);
  }

  for (int i = 0; i < func.getNumArguments(); i++) {
    if (i < buffer_param_offset) {
      continue;
    }
    const int buffer_index = i - buffer_param_offset;

    if (auto param_attr = func.getArgAttr(i, "lmhlo.params")) {
      xla::ShapeIndex shape_index;
      if (auto shape_index_attr =
              func.getArgAttrOfType<mlir::DenseIntElementsAttr>(
                  i, "lmhlo.param_shape_index")) {
        for (const llvm::APInt& element : shape_index_attr) {
          shape_index.push_back(element.getSExtValue());
        }
      }
      allocations->at(buffer_index)
          .set_entry_computation_parameter(
              param_attr.cast<mlir::IntegerAttr>().getInt(), shape_index,
              static_cast<bool>(func.getArgAttr(i, "lmhlo.output_index")));
    }
    // TODO(timshen): this information is redundant. It is here only for
    // smooth migration to LMHLO. Remove it.
    if (func.getArgAttr(i, "lmhlo.constant_name")) {
      allocations->at(buffer_index).set_constant(true);
    }
    if (auto output_index_attr = func.getArgAttr(i, "lmhlo.output_index")) {
      allocations->at(buffer_index).set_maybe_live_out(true);

      // Reconstruct a shape index from output_index.
      ShapeIndex shape_index;
      for (const llvm::APInt& element :
           output_index_attr.cast<mlir::DenseIntElementsAttr>()) {
        shape_index.push_back(element.getSExtValue());
      }
      auto& o = (*output_info)[shape_index];
      o.allocation_index = buffer_index;
      if (auto param_attr = func.getArgAttr(i, "lmhlo.params")) {
        HloInputOutputAliasConfig::AliasKind kind =
            HloInputOutputAliasConfig::kMayAlias;
        if (func.getArgAttr(i, "lmhlo.must_alias")) {
          kind = HloInputOutputAliasConfig::kMustAlias;
        }
        o.alias_config.emplace(param_attr.cast<mlir::IntegerAttr>().getInt(),
                               ShapeIndex{}, kind);
      }
      if (func.getArgument(i).use_empty()) {
        o.passthrough = true;
      }
    }
  }
  // Expects result_xla_shape as an XLA shape in string form.
  //
  // The attribute is necessary, because GpuExecutable/ExecutionOutput supports
  // tuples / tree-like shapes, while the LMHLO argument list loses the tree
  // form.
  //
  // The string format is necessary since MLIR doesn't support XLA shapes with
  // dynamic_dimension.
  //
  // TODO(timshen): now this field is mandatory. Make it optional for
  // non-GpuExecutable outputs.
  TF_ASSIGN_OR_RETURN(
      *output_shape,
      ParseShape(func->getAttrOfType<mlir::StringAttr>("result_xla_shape")
                     .getValue()
                     .str()));

  return OkStatus();
}

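// Computes, for every leaf of the entry computation's result shape, which
// buffer allocation holds it, whether it is a pass-through parameter, and its
// input/output aliasing configuration.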
StatusOr<absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>>
GetOutputInfo(const HloModule& hlo_module, const BufferAssignment& assignment) {
  const HloInstruction* root =
      hlo_module.entry_computation()->root_instruction();

  InstructionValueSet root_value_set =
      assignment.dataflow_analysis().GetInstructionValueSet(root);

  if (root_value_set.IsAmbiguous()) {
    return Unimplemented("Points-to set of root instruction is ambiguous");
  }

  using OutputInfoMap =
      absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>;
  OutputInfoMap output;
  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
      root->shape(),
      [&](const Shape& /*sub_shape*/, const ShapeIndex& index) -> Status {
        const auto& sources = root_value_set.element(index);
        // The points-to set is unambiguous so the set should be a
        // singleton. That is, we know exactly which instruction
        // produced the array at this element.
        CHECK_EQ(1, sources.values().size());
        HloInstruction* src_hlo = sources.values()[0]->instruction();

        GpuExecutable::OutputInfo& info = output[index];
        info.passthrough = src_hlo->opcode() == HloOpcode::kParameter;
        TF_ASSIGN_OR_RETURN(
            const BufferAllocation::Slice slice,
            assignment.GetUniqueSlice(src_hlo, sources.values()[0]->index()));
        CHECK_EQ(slice.offset(), 0) << "Parameter should get its own slice";
        info.allocation_index = slice.index();

        output[index].alias_config =
            hlo_module.input_output_alias_config().GetAliasedParameter(index);

        return OkStatus();
      }));
  return output;
}

GpuExecutable::GpuExecutable(
    std::shared_ptr<HloModule> hlo_module, GpuVersion gpu_version,
    xla::EntryFunctionAttributes entry_func_attrs,
    absl::string_view module_name, Shape xla_output_shape,
    std::vector<BufferAllocation> allocations,
    absl::flat_hash_map<ShapeIndex, OutputInfo> output_info,
    JitRtExecutable* jitrt_executable)
    : Executable(std::move(hlo_module)),
      gpu_version_(gpu_version),
      entry_func_attrs_(entry_func_attrs),
      module_name_(module_name),
      output_shape_(xla_output_shape),
      allocations_(std::move(allocations)),
      output_info_(std::move(output_info)),
      jitrt_executable_(jitrt_executable) {
  XlaDebugInfoManager::Get()->RegisterModule(
      module().unique_id(), shared_module(), debug_buffer_assignment_);
}

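// Loads an AOT-compiled GpuExecutable: parses the serialized MLIR module to
// recover buffer sizes, allocations and output info, loads the XLA runtime
// executable from the object file, and wraps everything in a GpuExecutable.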
StatusOr<std::unique_ptr<Executable>> GpuExecutable::LoadFromObjFile(
    std::shared_ptr<HloModule> hlo_module, absl::string_view obj_file,
    absl::string_view mlir_module,
    xla::EntryFunctionAttributes entry_func_attrs, DebugOptions debug_options,
    GpuVersion gpu_version, se::StreamExecutor* executor) {
#if XLA_ENABLE_XLIR
  // Load the MLIR module that accompanies the compiled object file to recover
  // XLA allocations and output info details. Also recover buffer sizes from
  // the entrypoint function signature.
  mlir::MLIRContext context;

  mlir::DialectRegistry registry;
  tfrt::RegisterTFRTDialects(registry);
  tfrt::RegisterTFRTCompiledDialects(registry);
  context.appendDialectRegistry(registry);

  auto module = mlir::parseSourceString<mlir::ModuleOp>(mlir_module, &context);
  if (!module) return InternalError("Failed to parse AOT compiled module");

  // Get the XLA module entrypoint function.
  auto func = mlir::cast<mlir::func::FuncOp>(
      module->lookupSymbol(hlo_module->entry_computation()->name()));

  // Get the buffer sizes from the entrypoint function signature.
  std::vector<int64_t> buffer_sizes;
  buffer_sizes.reserve(func.getNumArguments());
  for (auto type : func.getArgumentTypes()) {
    auto memref = type.dyn_cast<mlir::MemRefType>();
    if (!memref || !memref.hasStaticShape() || memref.getRank() != 1)
      return InternalError("Illegal entrypoint argument type: %s",
                           tfrt::StrCat(type));
    buffer_sizes.push_back(memref.getDimSize(0));
  }

  // Infer XLA allocations and output info from the MLIR module.
  std::vector<BufferAllocation> allocations;
  absl::flat_hash_map<ShapeIndex, OutputInfo> output_info;
  Shape result_xla_shape;
  TF_RETURN_IF_ERROR(SetUpMlirAllocation(func, buffer_sizes, &allocations,
                                         &output_info, &result_xla_shape,
                                         /*buffer_param_offset=*/0));

  // Create a named buffer from the compiled object file.
  llvm::StringRef data(obj_file.data(), obj_file.size());
  auto buffer = llvm::MemoryBuffer::getMemBuffer(data, hlo_module->name());

  // Create a JitRt function signature (all arguments passed as 1d memrefs).
  llvm::SmallVector<std::unique_ptr<runtime::Type>> args;
  llvm::SmallVector<std::unique_ptr<runtime::Type>> rt_args;
  rt_args.push_back(std::make_unique<runtime::KernelContextOperandType>());

  for (int64_t size : buffer_sizes) {
    auto i8 = tfrt::DType::I8;
    args.push_back(std::make_unique<runtime::MemrefType>(size, i8));
    rt_args.push_back(std::make_unique<runtime::MemrefType>(size, i8));
  }

  runtime::FunctionType signature(std::move(args), /*results=*/{});
  runtime::FunctionType rt_signature(std::move(rt_args), /*results=*/{});

  auto symbol_map =
      runtime::ToSymbolsBinding(JitRtGpuCustomCalls(), PopulateXlaTypeIdNames);

  // Load the JitRt executable from the object file, and link it with the Gpu
  // runtime intrinsics implementing Gpu custom calls.
  auto executable = runtime::Executable::LoadFromObjFile(
      hlo_module->name(), std::move(buffer),
      hlo_module->entry_computation()->name(), std::move(signature),
      std::move(rt_signature), symbol_map);
  if (auto err = executable.takeError())
    return InternalError("Failed to load JitRt executable: %s",
                         tfrt::StrCat(err));

  // Move runtime::Executable ownership to the JitRtExecutable.
  TF_ASSIGN_OR_RETURN(
      JitRtExecutable * jitrt_executable,
      JitRtExecutable::Create(buffer_sizes, std::move(*executable),
                              std::move(debug_options)));

  // Construct a GpuExecutable for the loaded JitRt executable.
  std::string name = hlo_module->name();
  return std::unique_ptr<Executable>(
      new GpuExecutable(std::move(hlo_module), gpu_version, entry_func_attrs,
                        name, result_xla_shape, std::move(allocations),
                        std::move(output_info), jitrt_executable));

#else   // XLA_ENABLE_XLIR
  return FailedPrecondition("Not built with XLA_ENABLE_XLIR");
#endif  // XLA_ENABLE_XLIR
}
}  // namespace gpu
}  // namespace xla