1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <utility>
24 #include <variant>
25 #include <vector>
26
27 #include "absl/cleanup/cleanup.h"
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/synchronization/mutex.h"
30 #include "mlir/Parser/Parser.h" // from @llvm-project
31 #include "tensorflow/compiler/xla/map_util.h"
32 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
33 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
34 #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
35 #include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
36 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
37 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
38 #include "tensorflow/compiler/xla/service/hlo_parser.h"
39 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
40 #include "tensorflow/compiler/xla/service/logical_buffer.h"
41 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
42 #include "tensorflow/compiler/xla/service/transfer_manager.h"
43 #include "tensorflow/compiler/xla/service/xla_debug_info_manager.h"
44 #include "tensorflow/compiler/xla/shape_tree.h"
45 #include "tensorflow/compiler/xla/shape_util.h"
46 #include "tensorflow/compiler/xla/status_macros.h"
47 #include "tensorflow/compiler/xla/util.h"
48 #include "tensorflow/core/lib/gtl/map_util.h"
49 #include "tensorflow/core/platform/casts.h"
50 #include "tensorflow/core/platform/errors.h"
51 #include "tensorflow/core/platform/logging.h"
52 #include "tensorflow/core/profiler/lib/scoped_annotation.h"
53 #include "tensorflow/core/profiler/lib/traceme.h"
54 #include "tensorflow/stream_executor/platform.h"
55
56 #if XLA_ENABLE_XLIR
57 #include "tensorflow/compiler/xla/mlir/transforms/runtime/compilation_pipeline.h"
58 #include "tensorflow/compiler/xla/runtime/diagnostics.h"
59 #include "tensorflow/compiler/xla/runtime/executable.h"
60 #include "tensorflow/compiler/xla/runtime/jit_executable.h"
61 #include "tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.h"
62 #include "tfrt/init_tfrt_dialects.h" // from @tf_runtime
63 #endif // XLA_ENABLE_XLIR
64
65 namespace xla {
66 namespace gpu {
67
bool IsJitRtExecutableEnabled(const HloModuleConfig& config) {
#if !XLA_ENABLE_XLIR
  CHECK(!config.debug_options().xla_gpu_jitrt_executable())
      << "Failed to enable JitRt backend, because it was not compiled.";
#endif  // !XLA_ENABLE_XLIR
  return config.debug_options().xla_gpu_jitrt_executable();
}
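// Note: xla_gpu_jitrt_executable is a DebugOptions field; it is typically
// toggled through the usual XLA debug-options plumbing (for example via
// XLA_FLAGS), though the exact mechanism depends on how the client builds its
// HloModuleConfig.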
75
76 namespace {
77
78 using ::tensorflow::profiler::ScopedAnnotation;
79
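// Returns true for thunks that must be handed a dedicated asynchronous
// communication stream (currently the NCCL all-reduce start/done pair).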
bool NeedsAsyncCommsStream(Thunk& thunk) {
  switch (thunk.kind()) {
    case Thunk::Kind::kNcclAllReduceStart:
    case Thunk::Kind::kNcclAllReduceDone:
      return true;
    default:
      return false;
  }
}
89
90 } // namespace
91
92 #if XLA_ENABLE_XLIR
93
94 class GpuExecutable::JitRtExecutable {
95 public:
  static StatusOr<JitRtExecutable*> Create(OwnedJitRtProgram program) {
97 // Options for the default JitRt compilation pipeline.
98 runtime::CompilationPipelineOptions copts;
99
100 // Populate mapping from XLA (SE) enums/structs type id to symbol names.
101 copts.populate_type_id_names = PopulateXlaTypeIdNames;
102
103 // For passing LMHLO attributes as XLA (SE) enums/structs to custom calls.
104 copts.populate_attr_encodings = PopulateLmhloToXlaAttrEncoding;
105
106 // Options for constructing XLA runtime JitExecutable.
107 runtime::JitExecutable::Options opts;
108 opts.specialization = runtime::JitExecutable::Specialization::kDisabled;
109 opts.compiler.register_dialects = [](mlir::DialectRegistry& registry) {
110 runtime::RegisterDefaultXlaRuntimeDialects(registry);
111 // For the encoding of attributes to custom calls.
112 registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
113 };
114
115 // Register XLA Gpu runtime custom calls with the linker.
116 opts.compiler.symbols_binding = runtime::ToSymbolsBinding(
117 JitRtGpuCustomCalls(), PopulateXlaTypeIdNames);
118
    // We use the default compilation pipeline provided by the XLA runtime.
    // Alternatively, instead of having a separate JitRtProgram (LMHLO lowered
    // to canonical dialects), we could assemble a pipeline that compiles
    // starting from the LMHLO dialect. However, this intermediate step helps
    // with debugging by materializing IR with XLA runtime custom calls.
124 opts.compiler.create_compilation_pipeline = [copts](mlir::PassManager& pm) {
125 runtime::CreateDefaultXlaRuntimeCompilationPipeline(pm, copts);
126 };
127
128 // TODO(b/241296710): LLVM optimizations interact badly with the memory
129 // loads and stores pattern generated in very large XLA programs, and can
130 // take minutes to run. Currently we do not expect any expensive code
131 // running on the host, so we can safely disable optimization passes.
132 opts.compiler.jit_code_opt_level = llvm::CodeGenOpt::None;
133
134 // Instantiate new JitExecutable from the MLIR source.
135 auto jit_executable = runtime::JitExecutable::Instantiate(
136 program->module, program->entry_point, opts);
137 if (auto err = jit_executable.takeError())
138 return InternalError("Failed to compile JitRt program: %s",
139 tfrt::StrCat(err));
140
141 // Pass ownership to the GpuExecutable.
142 return new JitRtExecutable(
143 std::move(program->buffer_sizes),
144 std::make_unique<runtime::JitExecutable>(std::move(*jit_executable)),
145 std::move(program->debug_options));
146 }
147
148 // Create JitRtExecutable from the AOT compiled binary.
  static StatusOr<JitRtExecutable*> Create(
      absl::Span<const int64_t> buffer_sizes, runtime::Executable executable,
      DebugOptions debug_options) {
152 // Pass ownership to the GpuExecutable.
153 return new JitRtExecutable(
154 std::vector<int64_t>(buffer_sizes.begin(), buffer_sizes.end()),
155 std::make_unique<runtime::Executable>(std::move(executable)),
156 std::move(debug_options));
157 }
158
  JitRtKernelsCache& kernels_cache() { return kernels_cache_; }
  JitRtGemmConfigCache& gemm_configs_cache() { return gemm_configs_cache_; }
  JitRtCollectiveSupport& collectives() { return collectives_; }
162
  runtime::Executable& executable() {
164 // Exactly one kind of `Executable` should be available at run time.
165 assert((default_executable_ || executable_) &&
166 !(default_executable_ && executable_));
167 return default_executable_ ? *default_executable_ : *executable_;
168 }
169
170 // We pass a pointer to the buffer size to the compiled function, so we return
171 // a reference to a stable memory location.
  const int64_t& buffer_size(size_t offset) const {
173 return buffer_sizes_[offset];
174 }
175
  const DebugOptions& debug_options() const { return debug_options_; }
177
178 private:
  JitRtExecutable(std::vector<int64_t> buffer_sizes,
                  std::unique_ptr<runtime::JitExecutable> jit_executable,
                  DebugOptions debug_options)
182 : buffer_sizes_(std::move(buffer_sizes)),
183 jit_executable_(std::move(jit_executable)),
184 default_executable_(&jit_executable_->DefaultExecutable().get()),
185 debug_options_(std::move(debug_options)) {}
186
  JitRtExecutable(std::vector<int64_t> buffer_sizes,
                  std::unique_ptr<runtime::Executable> aot_executable,
                  DebugOptions debug_options)
190 : buffer_sizes_(std::move(buffer_sizes)),
191 aot_executable_(std::move(aot_executable)),
192 executable_(aot_executable_.get()),
193 debug_options_(std::move(debug_options)) {}
194
195 std::vector<int64_t> buffer_sizes_;
196
197 // In JIT compilation mode the `JitExecutable` owns the default `Executable`.
198 std::unique_ptr<runtime::JitExecutable> jit_executable_;
199 runtime::Executable* default_executable_ = nullptr;
200
201 // In AOT compilation mode we directly own the `Executable`.
202 std::unique_ptr<runtime::Executable> aot_executable_;
203 runtime::Executable* executable_ = nullptr;
204
205 DebugOptions debug_options_;
206
207 // Keep a cache of kernels instantiated by this executable.
208 JitRtKernelsCache kernels_cache_;
209
  // Keep a cache of gemm configs for all gemm operations in the program.
211 JitRtGemmConfigCache gemm_configs_cache_;
212
213 // Support for running collective operations.
214 JitRtCollectiveSupport collectives_;
215 };
216 #endif // XLA_ENABLE_XLIR
217
StatusOr<std::unique_ptr<GpuExecutable>> GpuExecutable::Create(Params params) {
219 auto executable = std::move(params.executable);
220 std::unique_ptr<GpuExecutable> result(new GpuExecutable(std::move(params)));
221
222 if (std::holds_alternative<OwnedThunkSequence>(executable)) {
223 result->thunks_ = std::move(std::get<OwnedThunkSequence>(executable));
224 return result;
225 }
226
227 #if XLA_ENABLE_XLIR
228 if (std::holds_alternative<OwnedJitRtProgram>(executable)) {
229 auto& program = std::get<OwnedJitRtProgram>(executable);
230 TF_ASSIGN_OR_RETURN(result->jitrt_executable_,
231 JitRtExecutable::Create(std::move(program)));
232 return result;
233 }
234 #endif // XLA_ENABLE_XLIR
235
236 return InternalError("No XLA gpu executable was provided");
237 }
238
239 // Implementation note: HLO profiling is always enabled for GPU executables,
240 // since we can use timers around thunks.
GpuExecutable::GpuExecutable(GpuExecutable::Params params)
242 : Executable(std::move(params.debug_module)),
243 text_(std::move(params.asm_text)),
244 binary_(std::move(params.binary)),
245 gpu_version_(params.gpu_version),
246 entry_func_attrs_(params.entry_func_attrs),
247 module_name_(params.module_name),
248 output_shape_(params.output_shape),
249 allocations_(std::move(params.allocations)),
250 debug_buffer_assignment_(std::move(params.debug_buffer_assignment)),
251 verbose_buffer_assignment_string_dumper_(
252 params.verbose_buffer_assignment_string_dumper),
253 constants_(std::move(params.constants)),
254 output_info_(std::move(params.output_info)) {
255 if (has_module()) {
256 XlaDebugInfoManager::Get()->RegisterModule(
257 module().unique_id(), shared_module(), debug_buffer_assignment_);
258 }
259 }
260
GpuExecutable::~GpuExecutable() {
262 if (has_module()) {
263 XlaDebugInfoManager::Get()->UnregisterModule(module().unique_id());
264 }
265
  {
    // We could have issued host->device mem copies in ResolveConstantGlobals.
    // Wait for those to finish so that we can safely deallocate the backing
    // HLO module.
    //
    // We need to wait for the host->device memcpies to finish because they are
    // concurrently reading memory (xla::Literal's) owned by the HLO module.
273 absl::MutexLock lock(&module_handle_mutex_);
274 for (const auto& pair : module_globals_) {
275 CHECK(pair.first->SynchronizeAllActivity());
276 }
277 }
278
279 #if XLA_ENABLE_XLIR
280 delete jitrt_executable_;
281 #endif
282 }
283
Status GpuExecutable::CheckCompatibilityWithServiceExecutableRunOptions(
    const ServiceExecutableRunOptions* run_options) {
286 se::Stream* main_stream = run_options->stream();
287
288 stream_executor::PlatformKind platform_kind =
289 main_stream->parent()->platform_kind();
290 if (platform_kind == stream_executor::PlatformKind::kROCm) {
291 auto cc = main_stream->GetRocmComputeCapability();
292 std::string stream_arch = cc.gcn_arch_name();
293 std::string gpu_exec_arch =
294 std::get<se::RocmComputeCapability>(gpu_version_).gcn_arch_name();
    TF_RET_CHECK(stream_arch == gpu_exec_arch)
        << "AMDGPU GCN ISA version mismatch; expected {" << gpu_exec_arch
        << "}, but was {" << stream_arch << "}";
298 } else if (platform_kind == stream_executor::PlatformKind::kCuda) {
299 GpuVersion cc = main_stream->GetCudaComputeCapability();
300 TF_RET_CHECK(std::get<se::CudaComputeCapability>(cc) ==
301 std::get<se::CudaComputeCapability>(gpu_version_))
302 << "Compute capability mismatch; expected {"
303 << std::get<se::CudaComputeCapability>(gpu_version_).ToString()
304 << "}, but was {" << std::get<se::CudaComputeCapability>(cc).ToString()
305 << "}";
306 } else {
307 return InternalError("Unknown platform: %d", platform_kind);
308 }
309
310 return OkStatus();
311 }
312
313 namespace {
314
315 Status MaybeSyncAndProfile(const ServiceExecutableRunOptions* run_options,
316 uint64_t start_micros, se::Stream* stream_to_sync);
317
Status ExecuteThunks(const std::string& module_name,
                     const ThunkSequence& thunk_sequence,
                     const ServiceExecutableRunOptions* run_options,
                     const BufferAllocations& buffer_allocations,
                     bool block_host_until_done) {
323 se::Stream* main_stream = run_options->stream();
324 se::StreamExecutor* executor = main_stream->parent();
325
326 StatusOr<StreamPool::Ptr> async_comms_stream =
327 run_options->BorrowStream(executor->device_ordinal());
328
329 uint64_t start_micros = tensorflow::Env::Default()->NowMicros();
330
331 tensorflow::profiler::TraceMe hlo_module_activity(
332 [&] { return absl::StrCat(module_name, ":XLA GPU module"); },
333 tensorflow::profiler::TraceMeLevel::kInfo);
334
335 for (const std::unique_ptr<Thunk>& thunk : thunk_sequence) {
336 // Annotate execution of this op if tracing was enabled when we started
337 // running this module. If tracing is enabled *while* we're running the
338 // module, we won't get any data, but that's probably an OK trade-off.
339 ScopedAnnotation annotation([&] { return thunk->profile_annotation(); });
340 VLOG(2) << "Executing the thunk for " << thunk->profile_annotation();
341 TF_RET_CHECK(async_comms_stream.ok() || !NeedsAsyncCommsStream(*thunk))
342 << "`run_options` must have a stream borrower for async thunks.";
343
344 Thunk::ExecuteParams thunk_params{
345 *run_options, buffer_allocations, main_stream,
346 async_comms_stream.ok() ? async_comms_stream->get() : nullptr};
347 TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(thunk_params));
348 }
349 return MaybeSyncAndProfile(run_options, start_micros,
350 block_host_until_done ? main_stream : nullptr);
351 }
352
Status MaybeSyncAndProfile(const ServiceExecutableRunOptions* run_options,
                           uint64_t start_micros,
                           se::Stream* stream_to_sync = nullptr) {
356 // Make sure kernels are completed before deallocating temporary buffers or
357 // the profiler state.
358 // TODO(b/30100571): we could potentially postpone deallocating the temp
359 // buffers until a different computation is executed.
360 if (stream_to_sync) {
361 Status block_status = stream_to_sync->BlockHostUntilDone();
362 if (!block_status.ok()) {
363 return InternalError(
364 "Failed to complete all kernels launched on stream %p: %s",
365 stream_to_sync, block_status.error_message());
366 }
367 }
368
369 // FinishExecution() blocks until main_stream has completed if profiling is
370 // enabled; we therefore do not need to defer profile collection onto a
371 // stream.
372 uint64_t end_micros = tensorflow::Env::Default()->NowMicros();
373
374 if (run_options->run_options().execution_profile()) {
375 ExecutionProfile* profile = run_options->run_options().execution_profile();
376 const double nanoseconds = (end_micros - start_micros) * 1000.0;
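    // NowMicros() measures in microseconds; clamp to at least 1ns so a
    // sub-microsecond execution still reports a nonzero compute time.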
377 profile->set_compute_time_ns(std::max(nanoseconds, 1.0));
378 }
379
380 return OkStatus();
381 }
382
383 } // namespace
384
385 StatusOr<const GpuExecutable::BufferAllocToDeviceMemoryMap*>
GpuExecutable::ResolveConstantGlobals(se::Stream* stream) {
387 se::StreamExecutor* executor = stream->parent();
388
389 absl::MutexLock lock(&module_handle_mutex_);
390 auto it = module_globals_.find(executor);
391 if (it != module_globals_.end()) {
392 return &it->second;
393 }
394
395 se::MultiModuleLoaderSpec module_spec;
396 if (!binary().empty()) {
397 module_spec.AddCudaCubinInMemory(binary());
398 }
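  // PTX is always added; if the cubin is missing or does not match the device,
  // the driver can fall back to JIT-compiling the PTX.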
399 module_spec.AddCudaPtxInMemory(text().c_str());
400
401 absl::flat_hash_map<int64_t, se::DeviceMemoryBase> globals;
402 se::ModuleHandle module_handle;
403 // The CUDA driver isn't able to load empty PTX. It's okay if we skip loading
404 // in this case; if the module isn't loaded, all symbol lookups will fail,
405 // just as they should for an empty module.
406 if (!(executor->platform_kind() == se::PlatformKind::kCuda &&
407 module_spec.cuda_ptx_in_memory() == nullptr)) {
408 TF_RETURN_IF_ERROR(executor->LoadModule(module_spec, &module_handle));
409 }
410
411 for (const ConstantInfo& info : constants_) {
412 StatusOr<stream_executor::DeviceMemoryBase> global_status;
413 if (static_cast<bool>(module_handle)) {
414 global_status =
415 executor->GetUntypedSymbol(info.symbol_name, module_handle);
416 }
417
418 se::DeviceMemoryBase global;
419 if (static_cast<bool>(module_handle) && global_status.ok()) {
420 // The constant was defined in the PTX and has been allocated by the CUDA
421 // driver.
422 global = *global_status;
423 VLOG(3) << "Resolved global " << info.symbol_name << " to "
424 << global.opaque();
425
426 if (!info.content.empty()) {
427 // This means the constant did not have an initializer in the PTX and
428 // therefore must be initialized by XLA here.
429 stream->ThenMemcpy(&global, info.content.data(), info.content.size());
430 }
431 } else {
432 // The constant was not defined in the PTX and therefore must be both
433 // allocated and initialized by XLA here.
434 CHECK(!info.content.empty());
435
436 TF_ASSIGN_OR_RETURN(
437 auto shared, executor->CreateOrShareConstant(stream, info.content));
438 global = *shared;
439 VLOG(3) << "Allocated (or shared) global " << info.symbol_name << " at "
440 << global.opaque();
441 // XLA will continue to own this global at least until this executable is
442 // destroyed (longer if another, longer-lived executable shares the same
443 // constant).
444 shared_constants_.push_back(std::move(shared));
445 }
446
447 if (info.allocation_index != -1) {
448 InsertOrDie(&globals, info.allocation_index, global);
449 }
450 }
451
452 module_handles_.emplace(executor,
453 se::ScopedModuleHandle(executor, module_handle));
454 return &module_globals_.emplace(executor, std::move(globals)).first->second;
455 }
456
StatusOr<se::DeviceMemoryBase> GpuExecutable::BufferForAllocation(
    VariantArguments arguments,
    const GpuExecutable::BufferAllocToDeviceMemoryMap* globals,
    const BufferAllocation& allocation,
    se::DeviceMemoryAllocator* const memory_allocator, int device_ordinal,
    int64_t arg_idx) {
463 if (allocation.is_thread_local()) {
464 return se::DeviceMemoryBase{};
465 } else if (allocation.is_entry_computation_parameter()) {
466 int64_t param_no = allocation.parameter_number();
467 se::DeviceMemoryBase registered_buffer = [&] {
468 if (auto unowned_shapedbuffers =
469 std::get_if<absl::Span<const ShapedBuffer* const>>(&arguments)) {
470 return (*unowned_shapedbuffers)[param_no]->buffers().element(
471 allocation.param_shape_index());
472 } else {
473 return std::get<absl::Span<ExecutionInput>>(arguments)[param_no]
474 .Buffer(allocation.param_shape_index())
475 .AsDeviceMemoryBase();
476 }
477 }();
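    // A null registered buffer is only acceptable when the (sub-)buffer is
    // zero-sized; otherwise it indicates a missing argument.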
478 if (registered_buffer.is_null() && registered_buffer.size() > 0) {
479 return FailedPrecondition(
480 "Cannot run XLA computation because pointer to (sub-)buffer at "
481 "index %s of parameter %d was null. All pointers to "
482 "(sub-)buffers must not be null, unless the (sub-)buffer has "
483 "zero elements.",
484 allocation.param_shape_index().ToString(), param_no);
485 }
486 return registered_buffer;
487 } else if (allocation.is_constant()) {
488 auto it = globals->find(arg_idx);
489 if (it == globals->end()) {
490 return se::DeviceMemoryBase();
491 }
492 return it->second;
493 } else {
494 // Allocate each allocation that might escape, or is the temp buffer.
495 CHECK(allocation.maybe_live_out() || allocation.IsPreallocatedTempBuffer());
496 const int64_t buffer_size = allocation.size();
497 se::DeviceMemoryBase buffer_address;
498 if (buffer_size > 0) {
499 StatusOr<se::OwningDeviceMemory> buffer =
500 memory_allocator->Allocate(device_ordinal, buffer_size);
501 if (!buffer.ok()) {
502 return ResourceExhausted("%s\n%s\n", buffer.status().error_message(),
503 verbose_buffer_assignment_string_dumper_());
504 }
505 buffer_address = buffer->Release();
506 }
507 return buffer_address;
508 }
509 }
510
static Status CheckAlignment(const BufferAllocation& allocation,
                             se::DeviceMemoryBase buffer, int arg_idx) {
513 const int64_t expected_alignment = [&] {
514 if (allocation.is_entry_computation_parameter()) {
515 return kEntryParameterAlignBytes;
516 } else if (allocation.is_constant()) {
517 return kConstantBufferAlignBytes;
518 } else {
519 return kXlaAllocatedBufferAlignBytes;
520 }
521 }();
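  // A null buffer (e.g. a zero-sized allocation) trivially satisfies the
  // alignment requirement, so only non-null buffers are checked.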
522 if (!buffer.is_null() &&
523 reinterpret_cast<uintptr_t>(buffer.opaque()) % expected_alignment != 0) {
524 return InternalError(
525 "Address of buffer %d must be a multiple of %x, but "
526 "was %p",
527 arg_idx, expected_alignment, buffer.opaque());
528 }
529 return OkStatus();
530 }
531
StatusOr<BufferAllocations> GpuExecutable::GenerateBufferAllocations(
    VariantArguments arguments,
    const GpuExecutable::BufferAllocToDeviceMemoryMap* globals,
    se::DeviceMemoryAllocator* const memory_allocator, int device_ordinal) {
536 tensorflow::profiler::TraceMe hlo_module_activity(
537 [&] { return std::string("Build buffer allocations"); },
538 tensorflow::profiler::TraceMeLevel::kInfo);
539
540 const int64_t num_buffers = allocations_.size();
541 std::vector<se::DeviceMemoryBase> buffers;
542 buffers.reserve(num_buffers);
543 for (int64_t i = 0; i < num_buffers; ++i) {
544 const BufferAllocation& allocation = allocations_[i];
545 TF_ASSIGN_OR_RETURN(
546 se::DeviceMemoryBase buffer,
547 BufferForAllocation(arguments, globals, allocation, memory_allocator,
548 device_ordinal, i));
549 buffers.push_back(buffer);
550 TF_RETURN_IF_ERROR(CheckAlignment(allocation, buffer, i));
551 }
552 return {{buffers, device_ordinal, memory_allocator}};
553 }
554
StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
    const ServiceExecutableRunOptions* run_options,
    std::vector<ExecutionInput> arguments,
    HloExecutionProfile* hlo_execution_profile) {
559 return ExecuteAsyncOnStreamImpl(run_options, absl::MakeSpan(arguments));
560 }
561
StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteAsyncOnStream(
    const ServiceExecutableRunOptions* run_options,
    absl::Span<const ShapedBuffer* const> arguments,
    HloExecutionProfile* hlo_execution_profile) {
566 TF_ASSIGN_OR_RETURN(ExecutionOutput out,
567 ExecuteAsyncOnStreamImpl(run_options, arguments));
568 return out.ConsumeResult();
569 }
570
571 #if XLA_ENABLE_XLIR
static Status ExecuteJitRt(const std::string& module_name,
                           GpuExecutable::JitRtExecutable* jitrt_executable,
                           const ServiceExecutableRunOptions* run_options,
                           const BufferAllocations& buffer_allocations,
                           size_t num_allocations,
                           bool block_host_until_done) {
577 uint64_t start_micros = tensorflow::Env::Default()->NowMicros();
578
579 tensorflow::profiler::TraceMe hlo_module_activity(
580 [&] { return absl::StrCat(module_name, ":XLA GPU module"); },
581 tensorflow::profiler::TraceMeLevel::kInfo);
582
583 ScopedAnnotation annotation(
584 []() -> std::string { return "JitRtExecutable"; });
585
586 // TODO(ezhulenev): Here we rely on implementation details of passing memrefs
587 // to the compiled kernel. We should have a nicer API to do this, without
588 // creating a vector of temporary MemrefDesc for passing operands.
589
  // Pack buffer allocations as executable arguments. It is guaranteed that the
  // compiled function will make a copy of all arguments and will write all
  // results after the call to `Execute` completes, so it is safe to keep the
  // call frame on the stack.
594 runtime::Executable::CallFrame call_frame;
595
  // Each buffer allocation is passed as a 1d memref to the compiled kernel:
  //   {basePtr, dataPtr, offset, [sizes, ...], [strides, ...]}
598 size_t num_args_ptrs = 1 + num_allocations * 5;
599 call_frame.args.resize_for_overwrite(num_args_ptrs);
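  // For illustration, with two allocations the packed argument array is:
  //   args[0]      - placeholder for the kernel context
  //   args[1..5]   - {base0, data0, offset, size0, stride}
  //   args[6..10]  - {base1, data1, offset, size1, stride}
  // (base and data both point at the allocation's device address here.)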
600
601 // Pass pointers to these constants as a memref offset and stride.
602 int64_t zero = 0;
603 int64_t one = 1;
604 void* offset = &zero;
605 void* stride = &one;
606
607 // Add a placeholder for the kernel context as the first argument.
608 call_frame.args[0] = nullptr;
609
610 // Storage for data pointers.
611 llvm::SmallVector<void*, 16> ptrs;
612 ptrs.resize_for_overwrite(num_allocations);
613
614 // Initialize arguments for the buffer operands.
615 for (unsigned i = 0; i < num_allocations; ++i) {
616 void* data = &(ptrs[i] = buffer_allocations.GetDeviceAddress(i).opaque());
617 void* size = const_cast<int64_t*>(&jitrt_executable->buffer_size(i));
618 unsigned idx = 1 + i * 5;
619 call_frame.args[idx + 0] = data;
620 call_frame.args[idx + 1] = data;
621 call_frame.args[idx + 2] = offset;
622 call_frame.args[idx + 3] = size;
623 call_frame.args[idx + 4] = stride;
624 }
625
626 // JitRt executables do not return any values.
627 runtime::NoResultConverter converter;
628
629 // Prepare options for executing JitRt program.
630 runtime::Executable::ExecuteOpts opts;
631
632 // We don't expect to see any async tasks in the JitRt executable.
633 opts.async_task_runner =
634 reinterpret_cast<runtime::AsyncTaskRunner*>(0XDEADBEEF);
635
636 // Get the async communications stream for async collectives.
637 int device_ordinal = run_options->stream()->parent()->device_ordinal();
638 StatusOr<StreamPool::Ptr> async_comms_stream =
639 run_options->BorrowStream(device_ordinal);
640
  // Async collective support is instantiated for each GpuExecutable run, so
  // that concurrent executions can run independently using a separate set of
  // events for communication.
644 JitRtAsyncCollectiveSupport async_collectives(
645 async_comms_stream.ok() ? async_comms_stream->get() : nullptr);
646
647 // Pass auxiliary data to the custom call handlers.
648 runtime::CustomCall::UserData user_data;
649 user_data.insert_all(
650 run_options, &jitrt_executable->debug_options(),
651 &jitrt_executable->kernels_cache(),
652 &jitrt_executable->gemm_configs_cache(), &jitrt_executable->collectives(),
653 async_collectives.async_comm_stream() ? &async_collectives : nullptr);
654 opts.custom_call_data = &user_data;
655
656 // Collect all emitted diagnostic messages.
657 runtime::DiagnosticEngine diagnostic_engine;
658 std::string diagnostic;
659 diagnostic_engine.AddHandler([&](runtime::Diagnostic& d) {
660 llvm::raw_string_ostream(diagnostic) << d.str();
661 return mlir::success();
662 });
663
664 opts.diagnostic_engine = &diagnostic_engine;
665
666 // Execute with the prepared call frame.
667 runtime::Executable& executable = jitrt_executable->executable();
668 executable.Execute(call_frame, opts);
669
670 if (auto err = executable.ReturnResults(converter, &call_frame)) {
671 return InternalError(
672 "Failed to execute JitRt executable: %s.",
673 tfrt::StrCat(err,
674 diagnostic.empty() ? "" : tfrt::StrCat(": ", diagnostic)));
675 }
676
677 return MaybeSyncAndProfile(
678 run_options, start_micros,
679 block_host_until_done ? run_options->stream() : nullptr);
680 }
681 #endif // XLA_ENABLE_XLIR
682
StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStreamImpl(
    const ServiceExecutableRunOptions* run_options,
    VariantArguments arguments) {
686 XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
687 "GpuExecutable::ExecuteAsyncOnStreamImpl(", module_name_, ")"));
688 se::DeviceMemoryAllocator* const memory_allocator = run_options->allocator();
689 // Force synchronous execution if the allocator requires it.
690 const bool block_host_until_done =
691 !memory_allocator->AllowsAsynchronousDeallocation();
692
693 se::StreamExecutor* executor = run_options->stream()->parent();
694
695 // Lock the GPU with a shared lock so that we don't interfere with autotuning
696 // that may be running during JIT compilation while allowing multiple XLA
697 // computations to use the same GPU simultaneously.
698 absl::ReaderMutexLock gpu_lock(&GetGpuMutex(executor));
699
700 const GpuExecutable::BufferAllocToDeviceMemoryMap* globals;
701 {
702 tensorflow::profiler::TraceMe hlo_module_activity(
703 [&] { return std::string("Resolve constant globals"); },
704 tensorflow::profiler::TraceMeLevel::kInfo);
705
706 TF_ASSIGN_OR_RETURN(globals, ResolveConstantGlobals(run_options->stream()));
707 }
708
709 auto device_ordinal = executor->device_ordinal();
710 ExecutionOutput result(/*on_device_shape=*/output_shape_, memory_allocator,
711 device_ordinal);
712
713 TF_ASSIGN_OR_RETURN(
714 BufferAllocations buffer_allocations,
715 GenerateBufferAllocations(arguments, globals, memory_allocator,
716 device_ordinal));
717 VLOG(2) << buffer_allocations.ToString();
718 std::set<se::DeviceMemoryBase> buffers_in_result;
719
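  // The result is treated as entirely aliased only if every leaf in the output
  // shape that has output info also carries an alias config.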
720 const bool is_entire_tuple_contents_aliased = [&] {
721 for (auto& p : result.MutableResult()->buffers().leaves()) {
722 if (!output_info_.contains(p.first)) {
723 continue;
724 }
725 const OutputInfo& output_info = output_info_.at(p.first);
726 if (!output_info.alias_config.has_value()) {
727 return false;
728 }
729 }
730 return true;
731 }();
732
733 for (auto& p : result.MutableResult()->buffers()) {
734 const ShapeIndex& index = p.first;
735 if (!output_info_.contains(index)) {
736 continue;
737 }
738 const OutputInfo& output_info = output_info_.at(index);
739 const BufferAllocation* allocation =
740 &allocations_[output_info.allocation_index];
741 se::DeviceMemoryBase& result_buffer = p.second;
742
743 VLOG(4) << "Looking at: allocation " << output_info.allocation_index
744 << " @ index: " << index.ToString();
745
746 if (output_info.alias_config) {
747 MaybeOwningDeviceMemory* maybe_owning_memory =
748 [&]() -> xla::MaybeOwningDeviceMemory* {
        // A ShapedBuffer argument is never an owned buffer.
750 if (auto* unowned_shapedbuffers =
751 std::get_if<absl::Span<const ShapedBuffer* const>>(
752 &arguments)) {
753 return nullptr;
754 } else {
755 auto unowned_execution_input =
756 std::get<absl::Span<ExecutionInput>>(arguments);
757 ExecutionInput& input =
758 unowned_execution_input[allocation->parameter_number()];
759 return input.MutableBuffer(allocation->param_shape_index());
760 }
761 }();
762 if (output_info.alias_config->must_alias() && maybe_owning_memory &&
763 !maybe_owning_memory->HasOwnership()) {
764 return InvalidArgument(
765 "An input was configured to be must-alias at "
766 "compile time but not donated at runtime: allocation %d",
767 output_info.allocation_index);
768 }
769 if (maybe_owning_memory && maybe_owning_memory->HasOwnership()) {
770 std::optional<tensorflow::se::OwningDeviceMemory> owning =
771 maybe_owning_memory->Release();
772 // If the caller passes the ownership of the device memory, reuse it
773 // as the output buffer. It is up to the caller whether or not to
774 // donate a buffer; the aliasing information describes which buffers
775 // may alias, not buffers that must alias.
776 se::DeviceMemoryBase argument_buffer = owning->Release();
777 *maybe_owning_memory = argument_buffer;
778 result_buffer = argument_buffer;
779 // The caller is giving us the
780 // input buffer, but in case of error from the execute call, we should
781 // not be releasing it as it contains valid data (for example, it is a
782 // parameter which the user wants us to alias, in a gradient update
783 // computation). So we store the index into the result in the aliased
784 // vector, which will be fed to the ExecutionOutput, which will use
785 // the indices to drop the addresses from its own ScopedShapedBuffer
786 // result, if the ExecutionOutput is not committed.
787 result.AddAliasedIndex(index);
788 } else if (!output_info.passthrough &&
789 !ShapeUtil::GetSubshape(output_shape_, index).IsTuple()) {
        // The guard above avoids inserting copy-protection when aliasing
        // pass-through params, as we do not need to write into the output
        // buffer.
793 VLOG(3) << "Using copy-protection: aliasing is specified, but the "
794 "buffer is not donated; allocating a fresh buffer";
795 int64_t allocation_size =
796 ShapeUtil::ByteSizeOf(ShapeUtil::GetSubshape(output_shape_, index));
797 StatusOr<se::OwningDeviceMemory> allocated_buffer =
798 memory_allocator->Allocate(device_ordinal, allocation_size);
799 if (!allocated_buffer.ok()) {
800 return ResourceExhausted("%s\n%s\n",
801 allocated_buffer.status().error_message(),
802 verbose_buffer_assignment_string_dumper_());
803 }
804 result_buffer = allocated_buffer->Release();
805 se::DeviceMemoryBase& aliased_buffer =
806 buffer_allocations.GetMutableDeviceAddress(
807 output_info.allocation_index);
808 CHECK_EQ(aliased_buffer.size(), result_buffer.size());
809 run_options->stream()->ThenMemcpyD2D(&result_buffer, aliased_buffer,
810 aliased_buffer.size());
811 aliased_buffer = result_buffer;
812 }
813 }
814
815 if (result_buffer.is_null()) {
816 // The source instruction should have a non-parameter buffer
817 // assigned.
818 result_buffer =
819 buffer_allocations.GetDeviceAddress(output_info.allocation_index);
820
821 // If the entire tuple contents is aliased, the copy insertion will *not*
822 // materialize a new tuple, so we mark it as aliased as well.
823 if (is_entire_tuple_contents_aliased) {
824 result.AddAliasedIndex(index);
825 }
826 }
827 buffers_in_result.insert(result_buffer);
828 }
829
830 TF_RETURN_IF_ERROR(ExecuteThunksOrJitRt(run_options, buffer_allocations,
831 block_host_until_done));
832
833 // Free all temporary allocations.
834 TF_RETURN_IF_ERROR(
835 buffer_allocations.TearDown(buffers_in_result, allocations_));
836
837 // Free allocations for arguments.
838 if (auto args = std::get_if<absl::Span<ExecutionInput>>(&arguments)) {
839 MarkToBeReleasedArguments(*args, result);
840 }
841 return std::move(result);
842 }
843
Status GpuExecutable::ExecuteThunksOrJitRt(
    const ServiceExecutableRunOptions* run_options,
    const BufferAllocations& buffer_allocations, bool block_host_until_done) {
847 TF_RETURN_IF_ERROR(
848 CheckCompatibilityWithServiceExecutableRunOptions(run_options));
849
850 if (thunks_) {
851 se::StreamExecutor* executor = run_options->stream()->parent();
852 for (const std::unique_ptr<Thunk>& thunk : *thunks_) {
853 TF_RETURN_IF_ERROR(thunk->Initialize(*this, executor));
854 }
855 return ExecuteThunks(module_name_, *thunks_, run_options,
856 buffer_allocations, block_host_until_done);
857 }
858
859 #if XLA_ENABLE_XLIR
860 if (jitrt_executable_) {
861 return ExecuteJitRt(module_name_, jitrt_executable_, run_options,
862 buffer_allocations, allocations_.size(),
863 block_host_until_done);
864 }
865 #endif // XLA_ENABLE_XLIR
866
867 return FailedPrecondition("Expected XLA gpu executable is not supplied.");
868 }
869
int64_t GpuExecutable::SizeOfGeneratedCodeInBytes() const {
871 // Non-empty PTX but empty cubin: compilation must have failed, return
872 // "unknown".
873 if (binary().empty() && !text_.empty()) {
874 return -1;
875 }
876 int64_t size = binary().size();
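  // Constant buffers are emitted alongside the binary, so count them as part
  // of the generated code as well.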
877 for (BufferAllocation::Index i = 0; i < allocations_.size(); ++i) {
878 const BufferAllocation& allocation = allocations_[i];
879 if (allocation.is_constant()) {
880 size += allocation.size();
881 }
882 }
883 return size;
884 }
885
Status GpuExecutable::SetUpMlirAllocation(
    mlir::func::FuncOp func, llvm::ArrayRef<int64_t> buffer_sizes,
    std::vector<BufferAllocation>* allocations,
    absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>* output_info,
    Shape* output_shape, int buffer_param_offset) {
891 for (int i = 0; i < buffer_sizes.size(); i++) {
892 allocations->emplace_back(i, buffer_sizes[i], 0);
893 }
894
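  // The lmhlo.* argument attributes below refine each allocation into an entry
  // parameter, a constant, and/or a (maybe-live-out) output, and recover the
  // output shape indices.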
895 for (int i = 0; i < func.getNumArguments(); i++) {
896 if (i < buffer_param_offset) {
897 continue;
898 }
899 const int buffer_index = i - buffer_param_offset;
900
901 if (auto param_attr = func.getArgAttr(i, "lmhlo.params")) {
902 xla::ShapeIndex shape_index;
903 if (auto shape_index_attr =
904 func.getArgAttrOfType<mlir::DenseIntElementsAttr>(
905 i, "lmhlo.param_shape_index")) {
906 for (const llvm::APInt& element : shape_index_attr) {
907 shape_index.push_back(element.getSExtValue());
908 }
909 }
910 allocations->at(buffer_index)
911 .set_entry_computation_parameter(
912 param_attr.cast<mlir::IntegerAttr>().getInt(), shape_index,
913 static_cast<bool>(func.getArgAttr(i, "lmhlo.output_index")));
914 }
915 // TODO(timshen): this information is redundant. This is here only for
916 // smooth migration to LMHLO. Remove it.
917 if (func.getArgAttr(i, "lmhlo.constant_name")) {
918 allocations->at(buffer_index).set_constant(true);
919 }
920 if (auto output_index_attr = func.getArgAttr(i, "lmhlo.output_index")) {
921 allocations->at(buffer_index).set_maybe_live_out(true);
922
923 // Reconstruct a shape index from output_index.
924 ShapeIndex shape_index;
925 for (const llvm::APInt& element :
926 output_index_attr.cast<mlir::DenseIntElementsAttr>()) {
927 shape_index.push_back(element.getSExtValue());
928 }
929 auto& o = (*output_info)[shape_index];
930 o.allocation_index = buffer_index;
931 if (auto param_attr = func.getArgAttr(i, "lmhlo.params")) {
932 HloInputOutputAliasConfig::AliasKind kind =
933 HloInputOutputAliasConfig::kMayAlias;
934 if (func.getArgAttr(i, "lmhlo.must_alias")) {
935 kind = HloInputOutputAliasConfig::kMustAlias;
936 }
937 o.alias_config.emplace(param_attr.cast<mlir::IntegerAttr>().getInt(),
938 ShapeIndex{}, kind);
939 }
940 if (func.getArgument(i).use_empty()) {
941 o.passthrough = true;
942 }
943 }
944 }
945 // Expects result_xla_shape as a XLA shape in string form.
946 //
947 // The attribute is necessary, because GpuExecutable/ExecutionOutput supports
948 // tuples / tree-like shapes, while the LMHLO argument list loses the tree
949 // form.
950 //
951 // The string format is necessary since MLIR doesn't support XLA shape with
952 // dynamic_dimension.
953 //
954 // TODO(timshen): now this field is mandatory. Make it optional for
955 // non-GpuExecutable outputs.
956 TF_ASSIGN_OR_RETURN(
957 *output_shape,
958 ParseShape(func->getAttrOfType<mlir::StringAttr>("result_xla_shape")
959 .getValue()
960 .str()));
961
962 return OkStatus();
963 }
964
965 StatusOr<absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>>
GetOutputInfo(const HloModule& hlo_module, const BufferAssignment& assignment) {
967 const HloInstruction* root =
968 hlo_module.entry_computation()->root_instruction();
969
970 InstructionValueSet root_value_set =
971 assignment.dataflow_analysis().GetInstructionValueSet(root);
972
973 if (root_value_set.IsAmbiguous()) {
974 return Unimplemented("Points-to set of root instruction is ambiguous");
975 }
976
977 using OutputInfoMap =
978 absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>;
979 OutputInfoMap output;
980 TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
981 root->shape(),
982 [&](const Shape& /*sub_shape*/, const ShapeIndex& index) -> Status {
983 const auto& sources = root_value_set.element(index);
984 // The points-to set is unambiguous so the set should be a
985 // singleton. That is, we know exactly which instruction
986 // produced the array at this element.
987 CHECK_EQ(1, sources.values().size());
988 HloInstruction* src_hlo = sources.values()[0]->instruction();
989
990 GpuExecutable::OutputInfo& info = output[index];
991 info.passthrough = src_hlo->opcode() == HloOpcode::kParameter;
992 TF_ASSIGN_OR_RETURN(
993 const BufferAllocation::Slice slice,
994 assignment.GetUniqueSlice(src_hlo, sources.values()[0]->index()));
995 CHECK_EQ(slice.offset(), 0) << "Parameter should get its own slice";
996 info.allocation_index = slice.index();
997
998 output[index].alias_config =
999 hlo_module.input_output_alias_config().GetAliasedParameter(index);
1000
1001 return OkStatus();
1002 }));
1003 return output;
1004 }
1005
GpuExecutable::GpuExecutable(
    std::shared_ptr<HloModule> hlo_module, GpuVersion gpu_version,
    xla::EntryFunctionAttributes entry_func_attrs,
    absl::string_view module_name, Shape xla_output_shape,
    std::vector<BufferAllocation> allocations,
    absl::flat_hash_map<ShapeIndex, OutputInfo> output_info,
    JitRtExecutable* jitrt_executable)
1013 : Executable(std::move(hlo_module)),
1014 gpu_version_(gpu_version),
1015 entry_func_attrs_(entry_func_attrs),
1016 module_name_(module_name),
1017 output_shape_(xla_output_shape),
1018 allocations_(std::move(allocations)),
1019 output_info_(std::move(output_info)),
1020 jitrt_executable_(jitrt_executable) {
1021 XlaDebugInfoManager::Get()->RegisterModule(
1022 module().unique_id(), shared_module(), debug_buffer_assignment_);
1023 }
1024
StatusOr<std::unique_ptr<Executable>> GpuExecutable::LoadFromObjFile(
    std::shared_ptr<HloModule> hlo_module, absl::string_view obj_file,
    absl::string_view mlir_module,
    xla::EntryFunctionAttributes entry_func_attrs, DebugOptions debug_options,
    GpuVersion gpu_version, se::StreamExecutor* executor) {
1030 #if XLA_ENABLE_XLIR
  // Load the MLIR module that accompanies the compiled object file to recover
  // XLA allocations and output info details. Also recover buffer sizes from
  // the entrypoint function signature.
1034 mlir::MLIRContext context;
1035
1036 mlir::DialectRegistry registry;
1037 tfrt::RegisterTFRTDialects(registry);
1038 tfrt::RegisterTFRTCompiledDialects(registry);
1039 context.appendDialectRegistry(registry);
1040
1041 auto module = mlir::parseSourceString<mlir::ModuleOp>(mlir_module, &context);
1042 if (!module) return InternalError("Failed to parse AOT compiled module");
1043
1044 // Get the XLA module entrypoint function.
1045 auto func = mlir::cast<mlir::func::FuncOp>(
1046 module->lookupSymbol(hlo_module->entry_computation()->name()));
1047
1048 // Get the buffer sizes from the entrypoint function signature.
1049 std::vector<int64_t> buffer_sizes;
1050 buffer_sizes.reserve(func.getNumArguments());
1051 for (auto type : func.getArgumentTypes()) {
1052 auto memref = type.dyn_cast<mlir::MemRefType>();
1053 if (!memref || !memref.hasStaticShape() || memref.getRank() != 1)
1054 return InternalError("Illegal entrypoint argument type: %s",
1055 tfrt::StrCat(type));
1056 buffer_sizes.push_back(memref.getDimSize(0));
1057 }
1058
1059 // Infer XLA allocations and output info from the MLIR module.
1060 std::vector<BufferAllocation> allocations;
1061 absl::flat_hash_map<ShapeIndex, OutputInfo> output_info;
1062 Shape result_xla_shape;
1063 TF_RETURN_IF_ERROR(SetUpMlirAllocation(func, buffer_sizes, &allocations,
1064 &output_info, &result_xla_shape,
1065 /*buffer_param_offset=*/0));
1066
1067 // Create a named buffer from compiled object file.
1068 llvm::StringRef data(obj_file.data(), obj_file.size());
1069 auto buffer = llvm::MemoryBuffer::getMemBuffer(data, hlo_module->name());
1070
1071 // Create a JitRt function signature (all arguments passed as 1d memrefs).
1072 llvm::SmallVector<std::unique_ptr<runtime::Type>> args;
1073 llvm::SmallVector<std::unique_ptr<runtime::Type>> rt_args;
1074 rt_args.push_back(std::make_unique<runtime::KernelContextOperandType>());
1075
1076 for (int64_t size : buffer_sizes) {
1077 auto i8 = tfrt::DType::I8;
1078 args.push_back(std::make_unique<runtime::MemrefType>(size, i8));
1079 rt_args.push_back(std::make_unique<runtime::MemrefType>(size, i8));
1080 }
1081
1082 runtime::FunctionType signature(std::move(args), /*results=*/{});
1083 runtime::FunctionType rt_signature(std::move(rt_args), /*results=*/{});
1084
1085 auto symbol_map =
1086 runtime::ToSymbolsBinding(JitRtGpuCustomCalls(), PopulateXlaTypeIdNames);
1087
1088 // Load JitRt executable from an object file, and link it with Gpu runtime
1089 // intrinsics implementing Gpu custom calls.
1090 auto executable = runtime::Executable::LoadFromObjFile(
1091 hlo_module->name(), std::move(buffer),
1092 hlo_module->entry_computation()->name(), std::move(signature),
1093 std::move(rt_signature), symbol_map);
1094 if (auto err = executable.takeError())
1095 return InternalError("Failed to load JitRt executable: %s",
1096 tfrt::StrCat(err));
1097
1098 // Move runtime::Executable ownership to the JitRtExecutable.
1099 TF_ASSIGN_OR_RETURN(
1100 JitRtExecutable * jitrt_executable,
1101 JitRtExecutable::Create(buffer_sizes, std::move(*executable),
1102 std::move(debug_options)));
1103
1104 // Construct GpuExecutable for the loaded JitRt executable.
1105 std::string name = hlo_module->name();
1106 return std::unique_ptr<Executable>(
1107 new GpuExecutable(std::move(hlo_module), gpu_version, entry_func_attrs,
1108 name, result_xla_shape, std::move(allocations),
1109 std::move(output_info), jitrt_executable));
1110
1111 #else // XLA_ENABLE_XLIR
1112 return FailedPrecondition("Not built with XLA_ENABLE_XLIR");
1113 #endif // XLA_ENABLE_XLIR
1114 }
1115 } // namespace gpu
1116 } // namespace xla
1117