/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include <memory>
#include <string>
#include <utility>

#include "mlir/Dialect/Async/IR/AsyncTypes.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/ExecutionEngine/AsyncRuntime.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt.h"
#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_kernels_registration.h"
#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.h"
#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_query_of_death.h"
#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_request_context.h"
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
#include "tensorflow/compiler/xla/mlir/utils/runtime/async_runtime_api.h"
#include "tensorflow/compiler/xla/runtime/arguments.h"
#include "tensorflow/compiler/xla/runtime/async_runtime.h"
#include "tensorflow/compiler/xla/runtime/executable.h"
#include "tensorflow/compiler/xla/runtime/jit_executable.h"
#include "tensorflow/compiler/xla/runtime/types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tfrt/jitrt/jitrt_compiler.h"  // from @tf_runtime
#include "tfrt/jitrt/results.h"  // from @tf_runtime
#include "tfrt/dtype/dtype.h"  // from @tf_runtime
#include "tfrt/host_context/async_dispatch.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/chain.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/host_buffer.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/host_context/kernel_registry.h"  // from @tf_runtime
#include "tfrt/host_context/kernel_utils.h"  // from @tf_runtime
#include "tfrt/host_context/shared_context.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/rc_array.h"  // from @tf_runtime
#include "tfrt/support/string_util.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_metadata.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_shape.h"  // from @tf_runtime

namespace tensorflow {
namespace {

using ::std::any_cast;

using ::llvm::cast;
using ::llvm::Expected;
using ::llvm::MutableArrayRef;
using ::llvm::None;
using ::llvm::Optional;

using ::tfrt::Argument;
using ::tfrt::ArrayRef;
using ::tfrt::AsyncValue;
using ::tfrt::AsyncValuePtr;
using ::tfrt::AsyncValueRef;
using ::tfrt::Attribute;
using ::tfrt::Chain;
using ::tfrt::CompilationUnitAttribute;
using ::tfrt::DecodedDiagnostic;
using ::tfrt::DType;
using ::tfrt::EmitErrorAsync;
using ::tfrt::ExecutionContext;
using ::tfrt::HostContext;
using ::tfrt::Index;
using ::tfrt::IndirectAsyncValue;
using ::tfrt::KernelRegistry;
using ::tfrt::MakeAvailableAsyncValueRef;
using ::tfrt::MakeConstructedAsyncValueRef;
using ::tfrt::MakeErrorAsyncValueRef;
using ::tfrt::MakeStringError;
using ::tfrt::RCArray;
using ::tfrt::RCReference;
using ::tfrt::RemainingResults;
using ::tfrt::RepeatedArguments;
using ::tfrt::RequestContext;
using ::tfrt::SharedContext;
using ::tfrt::StrCat;
using ::tfrt::StringAttribute;

using ::tfrt::jitrt::CompilationPipelineOptions;
using ::tfrt::jitrt::CreateDefaultJitRtCompilationPipeline;
using ::tfrt::jitrt::RegisterDefaultJitRtDialects;
using ::tfrt::jitrt::ReturnErrors;
using ::tfrt::jitrt::ReturnStridedMemref;
using ::tfrt::jitrt::ReturnValueConversion;
using ::tfrt::jitrt::StaticRemainingResultsConverter;

using ::xla::runtime::ArgumentConstraint;
using ::xla::runtime::ArgumentsRef;
using ::xla::runtime::AsyncValuesCache;
using ::xla::runtime::EigenThreadPoolAsyncTaskRunner;
using ::xla::runtime::Executable;
using ::xla::runtime::JitExecutable;
using ::xla::runtime::MemrefDesc;
using ::xla::runtime::SpecializationListener;

using ::tensorflow::profiler::TraceMe;
using ::tensorflow::profiler::TraceMeEncode;
using ::tensorflow::tfd::KernelFallbackCompatRequestState;
using ::tensorflow::tfrt_stub::FallbackTensor;
using ::tensorflow::thread::ThreadPool;

template <typename T>
using KernelArgument = ::tfrt::Argument<T>;

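// Cache of compiled executables, keyed by the unique id of the compilation
// unit attribute (see CompileImpl below).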
using JitExecutableCache = AsyncValuesCache<size_t, JitExecutable>;

// -------------------------------------------------------------------------- //
// Dedicated thread pool for running compilation tasks.
// -------------------------------------------------------------------------- //

class CompilationThreadPool : public SharedContext {
 public:
  explicit CompilationThreadPool(HostContext* host) { Reset(); }

  static CompilationThreadPool& Get(HostContext* host) {
    return host->GetOrCreateSharedContext<CompilationThreadPool>();
  }

  template <typename Task>
  void Schedule(Task&& task) {
    // Because compilation tasks can capture move-only types, and the
    // Tensorflow thread pool requires std::function tasks, we have to do
    // manual memory management here.
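    // The task is moved onto the heap, the copyable lambda below captures only
    // the raw pointer, and ownership is reclaimed (and the task deleted) once
    // the task has run.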
    auto ptr = std::make_unique<Task>(std::forward<Task>(task));
    thread_pool_->Schedule([ptr = ptr.release()]() {
      (*ptr)();
      delete ptr;
    });
  }

  // This is an unsafe function intended only for use in tests. It is undefined
  // behavior to call it concurrently with `Schedule`.
  void Reset() {
    thread_pool_ = std::make_unique<ThreadPool>(
        Env::Default(), "tf-jitrt-compiler", /*num_threads=*/32);
  }

 private:
  std::unique_ptr<ThreadPool> thread_pool_;
};

// -------------------------------------------------------------------------- //
// JIT compiled kernels use the Eigen ThreadPool managed by the kernel fallback
// as async runtime worker threads.
// -------------------------------------------------------------------------- //

static Expected<Eigen::ThreadPoolInterface*> GetWorkerThreads(
    const ExecutionContext& exec_ctx) {
  RequestContext* req_ctx = exec_ctx.request_ctx();

  auto* fallback = req_ctx->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (LLVM_UNLIKELY(!fallback))
    return MakeStringError("fallback request state was not found");

  // Return the user-provided intra-op thread pool if it is available.
  if (LLVM_LIKELY(fallback->intra_op_threadpool()))
    return fallback->intra_op_threadpool();

  // Otherwise find the default CPU device in the device manager.
  Device* host_cpu = fallback->device_manager().HostCPU();
  assert(host_cpu && "fallback state must have a valid host cpu device");

  const Eigen::ThreadPoolDevice* eigen = host_cpu->eigen_cpu_device();
  assert(eigen && "host cpu device must have a valid Eigen thread pool device");

  return eigen->getPool();
}

// -------------------------------------------------------------------------- //
// Compile a compilation unit attribute to an executable result.
// -------------------------------------------------------------------------- //

// Options for the `tf-jitrt-pipeline`. We do not use MLIR pass options directly
// because they are not copyable or movable, and we need to pass them cheaply
// across the async compilation task boundary.
struct TfJitRtPipelineOpts {
  bool vectorize;
  bool legalize_i1_tensors;
};

// Prints a memref descriptor as a tensor type: tensor<NxMxf32>.
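// For illustration only: a 2x4 f32 memref descriptor would print as
// "tensor<2x4xf32>".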
static std::string AsTensorType(const MemrefDesc& desc) {
  std::string str;
  llvm::raw_string_ostream os(str);

  os << "tensor<";
  for (size_t size : desc.sizes()) os << size << "x";
  os << desc.dtype();
  os << ">";

  return str;
}

// Prints memref descriptor contents to trace value specializations.
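// For illustration only: a rank-1 i32 memref holding {1, 2, 3} would print as
// "i32: [1,2,3]".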
static std::string AsTensorContent(const MemrefDesc& desc) {
  std::string str;
  llvm::raw_string_ostream os(str);

  auto print_0d = [&](auto type_tag) {
    os << desc.dtype() << ": "
       << *static_cast<decltype(type_tag)*>(desc.data());
  };

  auto print_1d = [&](auto type_tag) {
    os << desc.dtype() << ": [";
    for (size_t i = 0; i < desc.size(0); ++i) {
      if (i != 0) os << ",";
      os << static_cast<decltype(type_tag)*>(desc.data())[i];
    }
    os << "]";
  };

  auto type_dispatch = [&](auto functor) {
    switch (desc.dtype()) {
      case DType::I32:
        functor(int32_t{});
        break;
      case DType::I64:
        functor(int64_t{});
        break;
      default:
        os << "<unsupported dtype " << desc.dtype() << ">";
    }
  };

  size_t rank = desc.rank();

  switch (rank) {
    case 0:
      type_dispatch(print_0d);
      break;
    case 1:
      type_dispatch(print_1d);
      break;
    default:
      os << "<unsupported rank " << desc.rank() << ">";
  }

  return str;
}

// Gets the session name from the fallback request state.
static const std::string GetSessionName(RequestContext* req_ctx) {
  auto* fallback = req_ctx->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback) return "<unknown>";

  return fallback->session_metadata().name();
}

static Expected<AsyncValuePtr<JitExecutable>> CompileImpl(
    const CompilationUnitAttribute& kernel, const ExecutionContext& exec_ctx,
    const Optional<TfJitRtPipelineOpts>& opts = None) {
  // Request context must be initialized with the tf_jitrt state.
  auto* state = exec_ctx.request_ctx()->GetDataIfExists<TfJitRtRequestState>();
  if (LLVM_UNLIKELY(!state))
    return MakeStringError("tf_jitrt state not found in the request context");

  // We rely on the unique `id` provided by the CompilationUnitAttribute to look
  // up the JitExecutable in the cache. This id is guaranteed to be unique
  // within a Bef file. Currently we rely on the fact that the SavedModel
  // never unloads a Bef file, and that there is a 1-to-1 relationship between
  // the ResourceContext and the SavedModel.
  //
  // TODO(b/206081322): Different compilation options should create unique
  // compiled kernel cache keys.
  size_t key = kernel.id();

  JitExecutableCache* jit_executable_cache = state->jit_executable_cache;

  // Maybe return the JitExecutable from the cache.
  auto cached = jit_executable_cache->Find(key);
  if (LLVM_LIKELY(cached)) return cached;

  // Get the worker threads from the execution context. Do this before
  // allocating an async value to make sure that we can try to instantiate the
  // executable.
  Expected<Eigen::ThreadPoolInterface*> worker_threads =
      GetWorkerThreads(exec_ctx);
  if (auto err = worker_threads.takeError()) return std::move(err);

  // Allocate a placeholder for the compiled JitExecutable.
  JitExecutableCache::Entry entry = jit_executable_cache->Allocate(key);

  // We lost the race; some other invocation will do the compilation.
  if (!entry.allocated) return entry.ptr;

  // Given that compilation happens asynchronously, passing (or capturing) these
  // by value prevents use-after-free errors.
  struct KernelInfo {
    intptr_t id;
    std::string entrypoint;
    std::string name;
    std::string serialized_operation;
  } kernel_info;

  // We only support functions nested in the top-level compiled module.
  if (kernel.nested_symbols().size() != 1)
    return MakeStringError(
        "kernel function has to be defined in a top-level module");

  // TODO(ecg): use designated initializers + const when C++20 is adopted.
  kernel_info.id = kernel.id();
  kernel_info.entrypoint = kernel.nested_symbols()[0];
  kernel_info.name = kernel.root_symbol();
  kernel_info.serialized_operation = kernel.serialized_operation();

  // Compilation (specialized executable compilation) events should be rare, so
  // we can afford to do detailed tracing for every compilation. If compilation
  // events happen too often, that is a much larger problem than the excessive
  // tracing.

  // Custom runner for compiling specializations that schedules the compilation
  // task into the dedicated thread pool and adds tracing.
  auto runner = [kernel_info](size_t specialization,
                              ArrayRef<ArgumentConstraint> constraints,
                              ArgumentsRef arguments,
                              JitExecutable::CompilationTask compile,
                              JitExecutable::UserData user_data) {
    assert(arguments.size() == constraints.size());

    // Get the context of the request that triggered specialization compilation.
    RequestContext* req_ctx = any_cast<RequestContext*>(user_data);
    HostContext* host = req_ctx->host();

    // Prepare arguments for the compilation tracing in the caller thread,
    // because the operands' lifetime is shorter than that of the compilation
    // task.
    using SpecializationArg = std::pair<std::string, std::string>;
    llvm::SmallVector<SpecializationArg> args;
    args.reserve(arguments.size());

    // Trace the types of all operands of the specialization.
    for (size_t i = 0; i < arguments.size(); ++i)
      args.emplace_back(StrCat("%arg", i, " type"),
                        AsTensorType(cast<MemrefDesc>(arguments[i])));

    // Trace the content of all operands that require value specialization.
    for (size_t i = 0; i < constraints.size(); ++i) {
      if (constraints[i] != ArgumentConstraint::kValue) continue;
      args.emplace_back(StrCat("%arg", i, " value"),
                        AsTensorContent(cast<MemrefDesc>(arguments[i])));
    }

    // Schedule the specialization compilation task into the dedicated thread
    // pool.
    CompilationThreadPool& thread_pool = CompilationThreadPool::Get(host);

    thread_pool.Schedule(
        [kernel_info, specialization, request_id = req_ctx->id(),
         session_name = GetSessionName(req_ctx), compile = std::move(compile),
         args = std::move(args)]() mutable {
          TraceMe trace_me([&] {
            return TraceMeEncode("tf_jitrt.CompileSpecialization",
                                 {{"id", request_id},
                                  {"kernel_id", kernel_info.id},
                                  {"executable", kernel_info.name},
                                  {"specialization", specialization}});
          });

          for (SpecializationArg& arg : args) {
            trace_me.AppendMetadata([&] {
              return TraceMeEncode({{arg.first, arg.second}});
            });
          }

          trace_me.AppendMetadata([&] {
            return TraceMeEncode({{"src", kernel_info.serialized_operation}});
          });

          auto compile_start_time = absl::Now();
          LOG(INFO) << "Started JitExecutable specialization compilation for "
                    << kernel_info.name << " (" << session_name << ")";
          compile();
          auto compile_duration = absl::Now() - compile_start_time;

          LOG(INFO) << "JitExecutable specialization compilation for "
                    << kernel_info.name << " took "
                    << absl::ToInt64Milliseconds(compile_duration) << " ms ("
                    << session_name << ")";

          if (compile_duration > absl::Seconds(1))
            LOG(INFO) << "Expensive JitExecutable specialization compilation ("
                      << absl::ToInt64Milliseconds(compile_duration)
                      << " ms):\n"
                      << kernel_info.serialized_operation;

          RecordCompileTime(session_name, kernel_info.name, specialization,
                            compile_duration);
        });
  };

  HostContext* host = exec_ctx.host();
  RequestContext* req_ctx = exec_ctx.request_ctx();

  // Compile the kernel asynchronously in the compilation thread pool.
  CompilationThreadPool& thread_pool = CompilationThreadPool::Get(host);

  thread_pool.Schedule([kernel_info, runner, workers = *worker_threads,
                        ref = entry.ptr.CopyRef(), request_id = req_ctx->id(),
                        session_name = GetSessionName(req_ctx),
                        tf_jitrt_opts = opts]() {
    TraceMe trace_me([&] {
      return TraceMeEncode("tf_jitrt.CompileDefault",
                           {{"id", request_id},
                            {"kernel_id", kernel_info.id},
                            {"executable", kernel_info.name},
                            {"src", kernel_info.serialized_operation}});
    });

    // Options for the default JitRt compilation pipeline (lowering to LLVM).
    CompilationPipelineOptions copts;
    copts.alignment = EIGEN_MAX_ALIGN_BYTES;  // Eigen included by tensor.h
    copts.num_worker_threads = workers->NumThreads();
    copts.cost_driven_async_parallel_for =
        GetJitRtFlags().cost_driven_async_parallel_for;

    // Options for the JitRt JitExecutable compilation.
    JitExecutable::Options opts;
    opts.specialization = GetJitRtFlags().always_specialize
                              ? JitExecutable::Specialization::kAlways
                              : JitExecutable::Specialization::kEnabled;

    // Register dialects and interfaces required for the compilation pipeline.
    opts.compiler.register_dialects = [](mlir::DialectRegistry& registry) {
      mlir::RegisterAllTensorFlowDialects(registry);
      RegisterDefaultJitRtDialects(registry);
    };

    // Register a custom pipeline for lowering from Tensorflow dialect to LLVM.
    opts.compiler.create_compilation_pipeline = [=](mlir::PassManager& pm) {
      if (GetJitRtFlags().enable_crash_reproducer)
        SetCrashReproducer(pm, kCrashReproducerStdErr);

      TfJitRtPipelineOptions opts;
      if (tf_jitrt_opts) {
        opts.vectorize = tf_jitrt_opts->vectorize;
        opts.legalize_i1_tensors = tf_jitrt_opts->legalize_i1_tensors;
      } else {
        opts.vectorize = GetJitRtFlags().vectorize;
      }

      // Lower from Tensorflow to Linalg on buffers.
      CreateTfJitRtPipeline(pm, opts);

      // Use the default JitRt compilation pipeline to lower to LLVM.
      CreateDefaultJitRtCompilationPipeline(pm, copts);
    };

    // Register a custom pipeline to propagate specialization information.
    opts.compiler.create_specialization_pipeline = [=](mlir::PassManager& pm) {
      if (GetJitRtFlags().enable_crash_reproducer)
        SetCrashReproducer(pm, kCrashReproducerStdErr);
      CreateJitRtSpecializationPipeline(pm);
    };

    // When lowering Tensorflow functions to JitRt we convert all input and
    // result tensors to memrefs, and add a kernel context input.
    opts.compiler.calling_convention = xla::runtime::DefaultCallingConvention(
        mlir::bufferization::BufferizeTypeConverter());

    // Instantiate a new JitExecutable from the MLIR source.
    auto compile_start_time = absl::Now();
    LOG(INFO) << "Started JitExecutable instantiation compilation for "
              << kernel_info.name << " (" << session_name << ")";
    Expected<JitExecutable> jit_executable = JitExecutable::Instantiate(
        kernel_info.serialized_operation, kernel_info.entrypoint,
        std::move(opts), session_name, runner);
    auto compile_duration = absl::Now() - compile_start_time;

    LOG(INFO) << "JitExecutable instantiation for " << kernel_info.name
              << " took " << absl::ToInt64Milliseconds(compile_duration)
              << " ms (" << session_name << ")";

    if (compile_duration > absl::Seconds(1))
      LOG(INFO) << "Expensive JitExecutable instantiation ("
                << absl::ToInt64Milliseconds(compile_duration) << " ms):\n"
                << kernel_info.serialized_operation;

    RecordCompileTime(session_name, kernel_info.name, std::nullopt,
                      compile_duration);

    // Set the entry async value state to error or concrete.
    if (auto err = jit_executable.takeError())
      ref.SetError(std::move(err));
    else
      ref.emplace(std::move(*jit_executable));
  });

  return entry.ptr;
}

// -------------------------------------------------------------------------- //
// TFRT kernel function definition for the tf_jitrt.fallback.compile operation.
// -------------------------------------------------------------------------- //

// Compiles the kernel into a JitExecutable and updates the JitExecutableCache.
static AsyncValueRef<Chain> Compile(StringAttribute device,
                                    CompilationUnitAttribute kernel,
                                    const ExecutionContext& exec_ctx) {
  // Trigger kernel compilation, which will update the JitExecutableCache.
  Expected<AsyncValuePtr<JitExecutable>> executable =
      CompileImpl(kernel, exec_ctx);

  // Return an error if we can't schedule the compilation task.
  if (auto err = executable.takeError())
    return MakeErrorAsyncValueRef(StrCat(err));

  // Mark the chain available once we compile the default executable.
  auto chain = MakeConstructedAsyncValueRef<Chain>();
  executable->AndThen([chain]() { chain.SetStateConcrete(); });

  return chain;
}

// -------------------------------------------------------------------------- //
// TFRT kernel function definition for tf_jitrt.test.wait_for_compilation.
// -------------------------------------------------------------------------- //

static AsyncValueRef<Chain> WaitForCompilation(
    KernelArgument<Chain> chain, CompilationUnitAttribute kernel,
    const ExecutionContext& exec_ctx) {
  // Request context must be initialized with the tf_jitrt state.
  auto* state = exec_ctx.request_ctx()->GetDataIfExists<TfJitRtRequestState>();
  if (!state)
    return EmitErrorAsync(exec_ctx,
                          "tf_jitrt state not found in the request context");

  // Wait for the completion of all compilation tasks.
  JitExecutableCache* jit_executable_cache = state->jit_executable_cache;
  if (auto cached = jit_executable_cache->Find(kernel.id()))
    return cached->AllExecutablesCompiled();

  return MakeAvailableAsyncValueRef<Chain>();
}

// -------------------------------------------------------------------------- //
// TFRT kernel function for tf_jitrt.test.reset_compilation_thread_pool.
// -------------------------------------------------------------------------- //

static AsyncValueRef<Chain> ResetCompilationThreadPool(
    KernelArgument<Chain> chain, const ExecutionContext& exec_ctx) {
  // Make sure that we reset the compilation thread pool only from a thread
  // pool (concurrent work queue) managed by the HostContext.
  return EnqueueWork(exec_ctx, [host = exec_ctx.host()]() -> Chain {
    CompilationThreadPool::Get(host).Reset();
    return {};
  });
}

// -------------------------------------------------------------------------- //
// Execute compiled JitRt kernels with Fallback Runtime interop.
// -------------------------------------------------------------------------- //

using ReturnTensorflowTensor =
    ReturnValueConversion<TensorflowConversionContext,
                          ReturnStridedMemref<ConvertTensor>>;

using TensorflowResultConverter =
    StaticRemainingResultsConverter<TensorflowConversionContext,
                                    ReturnTensorflowTensor>;

static MemrefDesc ConvertTensorToMemrefDesc(const tensorflow::Tensor& tensor) {
  // Fills memref sizes and strides from the tensor shape.
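  // For example, a row-major tensor of shape [2, 3, 4] yields sizes {2, 3, 4}
  // and strides {12, 4, 1} (the innermost dimension is contiguous).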
  auto fill_desc = [&](MutableArrayRef<Index> sizes,
                       MutableArrayRef<Index> strides) {
    int64_t multiplier = 1;
    for (int i = tensor.dims() - 1; i >= 0; --i) {
      int64_t dim_size = tensor.dim_size(i);
      sizes[i] = dim_size;
      strides[i] = multiplier;
      multiplier *= dim_size;
    }
  };

  return MemrefDesc(tensor.dims(), tfd::GetTfrtDtype(tensor.dtype()),
                    const_cast<void*>(tensor.data()), 0, fill_desc);
}

static std::vector<MemrefDesc> ConvertTensorOperandsToMemrefDesc(
    RepeatedArguments<FallbackTensor> operands) {
  std::vector<MemrefDesc> memrefs;
  memrefs.reserve(operands.size());
  for (FallbackTensor& operand : operands)
    memrefs.emplace_back(ConvertTensorToMemrefDesc(operand.tensor()));
  return memrefs;
}

struct DebugListener : public SpecializationListener {
  void notifyModuleSpecialized(
      ArrayRef<mlir::Type> operands,
      ArrayRef<mlir::DictionaryAttr> attrs) const override {
    std::string message;
    llvm::raw_string_ostream os(message);
    os << "Specialized operands:\n";
    for (auto& tuple : llvm::enumerate(llvm::zip(operands, attrs))) {
      mlir::Type type = std::get<0>(tuple.value());
      mlir::Attribute attr = std::get<1>(tuple.value());
      os << "%arg" << tuple.index() << ": " << type << " " << attr << "\n";
    }
    printf("%s", message.c_str());
    fflush(stdout);
  }

  void notifyValueSpecialized(unsigned index, mlir::Type type,
                              mlir::Attribute value) const override {
    std::string message;
    llvm::raw_string_ostream(message) << "%arg" << index << " "
                                      << "value specialized: " << value << "\n";
    printf("%s", message.c_str());
    fflush(stdout);
  }
};

// Emits diagnostics for the kernel invocation and returns errors for all
// remaining results.
template <typename Error>
static void ReturnErrors(RemainingResults results, Error error,
                         const ExecutionContext& exec_ctx) {
  EmitError(exec_ctx, StrCat(error));
  ReturnErrors(results, std::move(error));
}

static void ExecuteImpl(Executable& executable, ArrayRef<MemrefDesc> memrefs,
                        RepeatedArguments<FallbackTensor> operands,
                        RemainingResults results,
                        const ExecutionContext& exec_ctx) {
  // Bind the execution trace to the request context.
  TraceMe trace_me([&] {
    int64_t id = exec_ctx.request_ctx()->id();
    absl::string_view name(executable.name().data(), executable.name().size());
    return TraceMeEncode(
        "tf_jitrt.Execute",
        {{"id", id},
         {"executable", name},
         {"specialization", !executable.specialization().has_value()
                                ? "default"
                                : std::to_string(*executable.specialization())},
         {"time_to_compile_ms", executable.time_to_compile().count()}});
  });

  // TODO(ezhulenev): The conversion context and async task runner might not
  // outlive the execution of all async tasks, and should be kept alive until
  // all tasks are completed, which will require heap allocation(s).
  assert(!executable.IsAsync() && "async executables are not yet supported");

  // Keep track of the memory address to tensor mapping for result conversion.
  TensorflowConversionContext ctx(operands.size(), results.size());
  for (auto& t : operands)
    ctx.runtime_tensors.insert({t.tensor().data(), &t.tensor()});

  TensorflowResultConverter converter(results, ctx);

  // Get the worker threads from the execution context.
  Expected<Eigen::ThreadPoolInterface*> worker_threads =
      GetWorkerThreads(exec_ctx);

  if (LLVM_UNLIKELY(!worker_threads))
    return ReturnErrors(results, worker_threads.takeError(), exec_ctx);

  // Use the Eigen thread pool to execute all async tasks.
  EigenThreadPoolAsyncTaskRunner async_task_runner(*worker_threads);

  Executable::ExecuteOpts opts;
  opts.async_task_runner = &async_task_runner;

  // Execution errors are automatically forwarded to all results; we only need
  // to notify the HostContext to emit the diagnostics for the kernel
  // invocation.
  auto err = executable.Execute(memrefs, converter, opts);
  if (LLVM_UNLIKELY(err)) {
    EmitError(exec_ctx, StrCat(err));
    return;
  }
}

// Gets a specialized Executable async value from the JitExecutable, and then
// dispatches it inline or via an and-then continuation depending on the async
// value state.
static void ExecuteImpl(JitExecutable& jit_executable,
                        RepeatedArguments<FallbackTensor> operands,
                        RemainingResults results,
                        const ExecutionContext& exec_ctx, bool debug) {
  // Convert the Tensor operands to memref descriptors.
  auto memrefs = ConvertTensorOperandsToMemrefDesc(operands);

  // Get an executable that might be specialized to the operands.
  DebugListener debug_listener;

  // Pass the request context to the compilation task runner.
  JitExecutable::UserData user_data = exec_ctx.request_ctx();

  Expected<AsyncValuePtr<Executable>> executable = jit_executable.GetExecutable(
      memrefs, user_data, debug ? &debug_listener : nullptr);

  if (LLVM_UNLIKELY(!executable))
    return ReturnErrors(results, executable.takeError(), exec_ctx);

  // If the executable is available, execute it inline ...
  if (LLVM_LIKELY(executable->IsConcrete()))
    return ExecuteImpl(executable->get(), memrefs, operands, results, exec_ctx);

  // ... or maybe return errors.
  if (LLVM_UNLIKELY(executable->IsError()))
    return ReturnErrors(results, executable->GetError(), exec_ctx);

  // Otherwise execute it when the executable becomes available. This requires
  // careful lifetime extension of all async values passed as operands to the
  // kernel (and also results that will become available asynchronously).

  // Allocate indirect async values for all results; we'll forward them to the
  // actual async values computed by the executable later.
  for (unsigned i = 0; i < results.size(); ++i)
    results.AllocateIndirectResultAt(i);

  // Call the executable when it's ready, with the original operands.
  executable->AndThen([exec_ctx, executable = *executable,
                       memrefs = std::move(memrefs),
                       r = RCArray<AsyncValue>(results.values()),
                       o = RCArray<AsyncValue>(operands.values())] {
    // Allocate storage for the executable results.
    llvm::SmallVector<RCReference<AsyncValue>> results_storage;
    results_storage.resize(r.size());

    // Reconstruct arguments and results from the captured async values.
    RepeatedArguments<FallbackTensor> operands(o.values());
    RemainingResults results(results_storage);

    if (executable.IsError()) {
      ReturnErrors(results, executable.GetError(), exec_ctx);
    } else {
      ExecuteImpl(*executable, memrefs, operands, results, exec_ctx);
    }

    // Forward the previously allocated indirect results to the actual results.
    for (unsigned i = 0; i < r.size(); ++i)
      llvm::cast<IndirectAsyncValue>(*r[i]).ForwardTo(
          std::move(results_storage[i]));
  });
}

static std::string OperandsToString(
    RepeatedArguments<FallbackTensor> operands) {
  std::string out;
  llvm::raw_string_ostream os(out);
  int i = 0;
  os << "{";
  for (const auto& operand : operands) {
    os << "[" << i++ << "]: " << operand.tensor().DebugString(/*num_values=*/0);
    if (i < operands.size()) os << ", ";
  }
  os << "}";
  return out;
}

// Gets a JitExecutable async value from the cache, and then dispatches it
// inline or via an and-then continuation depending on the async value state.
static void ExecuteImpl(RepeatedArguments<FallbackTensor> operands,
                        RemainingResults results, const StringAttribute& device,
                        const CompilationUnitAttribute& kernel,
                        const ExecutionContext& exec_ctx, bool debug,
                        const Optional<TfJitRtPipelineOpts>& opts) {
  VLOG(2) << "kernel_name: " << kernel.root_symbol().str()
          << ", operands: " << OperandsToString(operands);

  // Compile the kernel module into a JitExecutable.
  Expected<AsyncValuePtr<JitExecutable>> jit_executable =
      CompileImpl(kernel, exec_ctx, opts);

  if (LLVM_UNLIKELY(!jit_executable))
    return ReturnErrors(results, jit_executable.takeError(), exec_ctx);

  // If the kernel is available, execute it inline ...
  if (LLVM_LIKELY(jit_executable->IsConcrete()))
    return ExecuteImpl(**jit_executable, operands, results, exec_ctx, debug);

  // ... or maybe return errors.
  if (LLVM_UNLIKELY(jit_executable->IsError()))
    return ReturnErrors(results, jit_executable->GetError(), exec_ctx);

  // Otherwise execute it when the executable becomes available. This requires
  // careful lifetime extension of all async values passed as operands to the
  // kernel (and also results that will become available asynchronously).

  // Allocate indirect async values for all results; we'll forward them to the
  // actual async values computed by the executable later.
  for (unsigned i = 0; i < results.size(); ++i)
    results.AllocateIndirectResultAt(i);

  // Call the executable when it's ready, with the original operands.
  jit_executable->AndThen([exec_ctx, jit_executable = *jit_executable,
                           r = RCArray<AsyncValue>(results.values()),
                           o = RCArray<AsyncValue>(operands.values()), debug] {
    // Allocate storage for the compiled executable results.
    llvm::SmallVector<RCReference<AsyncValue>> results_storage;
    results_storage.resize(r.size());

    // Reconstruct arguments and results from the captured async values.
    RepeatedArguments<FallbackTensor> operands(o.values());
    RemainingResults results(results_storage);

    if (jit_executable.IsError()) {
      ReturnErrors(results, jit_executable.GetError(), exec_ctx);
    } else {
      ExecuteImpl(*jit_executable, operands, results, exec_ctx, debug);
    }

    // Forward the previously allocated indirect results to the actual results.
    for (unsigned i = 0; i < r.size(); ++i)
      llvm::cast<IndirectAsyncValue>(*r[i]).ForwardTo(
          std::move(results_storage[i]));
  });
}

static void ExecuteImplAndMaybeLogQueryOfDeath(
    RepeatedArguments<FallbackTensor> operands, RemainingResults results,
    const StringAttribute& device, const CompilationUnitAttribute& kernel,
    const ExecutionContext& exec_ctx, bool debug = false,
    const Optional<TfJitRtPipelineOpts>& opts = None) {
  if (LLVM_LIKELY(!GetJitRtFlags().log_query_of_death)) {
    return ExecuteImpl(operands, results, device, kernel, exec_ctx, debug,
                       opts);
  }
  TfJitRtQueryOfDeathLogger qod_logger(/*kernel_name=*/kernel.root_symbol(),
                                       /*kernel_serialized_operation=*/
                                       kernel.serialized_operation(),
                                       /*operands=*/OperandsToString(operands));
  ExecuteImpl(operands, results, device, kernel, exec_ctx, debug, opts);
}

// -------------------------------------------------------------------------- //
// TFRT kernel function definitions for tf_jitrt.fallback.execute operations.
// -------------------------------------------------------------------------- //

// Compiles the kernel into a JitExecutable and executes it with the fallback
// tensor operands.
static void Execute(RepeatedArguments<FallbackTensor> operands,
                    RemainingResults results, StringAttribute device,
                    CompilationUnitAttribute kernel,
                    const ExecutionContext& exec_ctx) {
  ExecuteImplAndMaybeLogQueryOfDeath(operands, results, device, kernel,
                                     exec_ctx);
}

// Compiles the kernel into a JitExecutable and executes it with the fallback
// tensor operands in debug mode: prints compilation diagnostics to standard
// output. Should be used only in tests for verifying compiler internals.
void ExecuteDebug(RepeatedArguments<FallbackTensor> operands,
                  RemainingResults results,
                  Attribute<bool> debug_specializations, StringAttribute device,
                  CompilationUnitAttribute kernel, Attribute<bool> vectorize,
                  Attribute<bool> legalize_i1_tensors,
                  const ExecutionContext& exec_ctx) {
  TfJitRtPipelineOpts opts;
  opts.vectorize = *vectorize;
  opts.legalize_i1_tensors = *legalize_i1_tensors;
  ExecuteImplAndMaybeLogQueryOfDeath(operands, results, device, kernel,
                                     exec_ctx, *debug_specializations, opts);
}

}  // namespace

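// Registers the tf_jitrt kernels defined above with a TFRT kernel registry.
// The actual registration is typically wired up in the separate kernel
// registration target (see tf_jitrt_kernels_registration.h), e.g. via a static
// kernel registration macro; the exact wiring lives outside this file.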
void RegisterTfJitRuntimeKernels(KernelRegistry* registry) {
  registry->AddKernel("tf_jitrt.fallback.compile", TFRT_KERNEL(Compile));
  registry->AddKernel("tf_jitrt.fallback.execute", TFRT_KERNEL(Execute));
  registry->AddKernel("tf_jitrt.fallback.debug.execute",
                      TFRT_KERNEL(ExecuteDebug));

  registry->AddKernel("tf_jitrt.test.wait_for_compilation",
                      TFRT_KERNEL(WaitForCompilation));
  registry->AddKernel("tf_jitrt.test.reset_compilation_thread_pool",
                      TFRT_KERNEL(ResetCompilationThreadPool));
}

}  // namespace tensorflow