/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "mlir/Dialect/Async/IR/AsyncTypes.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/ExecutionEngine/AsyncRuntime.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt.h"
#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_kernels_registration.h"
#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.h"
#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_query_of_death.h"
#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_request_context.h"
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
#include "tensorflow/compiler/xla/mlir/utils/runtime/async_runtime_api.h"
#include "tensorflow/compiler/xla/runtime/arguments.h"
#include "tensorflow/compiler/xla/runtime/async_runtime.h"
#include "tensorflow/compiler/xla/runtime/executable.h"
#include "tensorflow/compiler/xla/runtime/jit_executable.h"
#include "tensorflow/compiler/xla/runtime/types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tfrt/dtype/dtype.h"  // from @tf_runtime
#include "tfrt/host_context/async_dispatch.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/chain.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/host_buffer.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/host_context/kernel_registry.h"  // from @tf_runtime
#include "tfrt/host_context/kernel_utils.h"  // from @tf_runtime
#include "tfrt/host_context/shared_context.h"  // from @tf_runtime
#include "tfrt/jitrt/jitrt_compiler.h"  // from @tf_runtime
#include "tfrt/jitrt/results.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/rc_array.h"  // from @tf_runtime
#include "tfrt/support/string_util.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_metadata.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_shape.h"  // from @tf_runtime

namespace tensorflow {
namespace {

using ::std::any_cast;

using ::llvm::cast;
using ::llvm::Expected;
using ::llvm::MutableArrayRef;
using ::llvm::None;
using ::llvm::Optional;

using ::tfrt::Argument;
using ::tfrt::ArrayRef;
using ::tfrt::AsyncValue;
using ::tfrt::AsyncValuePtr;
using ::tfrt::AsyncValueRef;
using ::tfrt::Attribute;
using ::tfrt::Chain;
using ::tfrt::CompilationUnitAttribute;
using ::tfrt::DecodedDiagnostic;
using ::tfrt::DType;
using ::tfrt::EmitErrorAsync;
using ::tfrt::ExecutionContext;
using ::tfrt::HostContext;
using ::tfrt::Index;
using ::tfrt::IndirectAsyncValue;
using ::tfrt::KernelRegistry;
using ::tfrt::MakeAvailableAsyncValueRef;
using ::tfrt::MakeConstructedAsyncValueRef;
using ::tfrt::MakeErrorAsyncValueRef;
using ::tfrt::MakeStringError;
using ::tfrt::RCArray;
using ::tfrt::RCReference;
using ::tfrt::RemainingResults;
using ::tfrt::RepeatedArguments;
using ::tfrt::RequestContext;
using ::tfrt::SharedContext;
using ::tfrt::StrCat;
using ::tfrt::StringAttribute;

using ::tfrt::jitrt::CompilationPipelineOptions;
using ::tfrt::jitrt::CreateDefaultJitRtCompilationPipeline;
using ::tfrt::jitrt::RegisterDefaultJitRtDialects;
using ::tfrt::jitrt::ReturnErrors;
using ::tfrt::jitrt::ReturnStridedMemref;
using ::tfrt::jitrt::ReturnValueConversion;
using ::tfrt::jitrt::StaticRemainingResultsConverter;

using ::xla::runtime::ArgumentConstraint;
using ::xla::runtime::ArgumentsRef;
using ::xla::runtime::AsyncValuesCache;
using ::xla::runtime::EigenThreadPoolAsyncTaskRunner;
using ::xla::runtime::Executable;
using ::xla::runtime::JitExecutable;
using ::xla::runtime::MemrefDesc;
using ::xla::runtime::SpecializationListener;

using ::tensorflow::profiler::TraceMe;
using ::tensorflow::profiler::TraceMeEncode;
using ::tensorflow::tfd::KernelFallbackCompatRequestState;
using ::tensorflow::tfrt_stub::FallbackTensor;
using ::tensorflow::thread::ThreadPool;

template <typename T>
using KernelArgument = ::tfrt::Argument<T>;

using JitExecutableCache = AsyncValuesCache<size_t, JitExecutable>;

// -------------------------------------------------------------------------- //
// Dedicated thread pool for running compilation tasks.
// -------------------------------------------------------------------------- //

class CompilationThreadPool : public SharedContext {
 public:
  explicit CompilationThreadPool(HostContext* host) { Reset(); }

  static CompilationThreadPool& Get(HostContext* host) {
    return host->GetOrCreateSharedContext<CompilationThreadPool>();
  }

  template <typename Task>
  void Schedule(Task&& task) {
    // Because compilation tasks can capture move-only types, and the
    // TensorFlow thread pool requires copyable std::function tasks, we have to
    // do manual memory management here.
    auto ptr = std::make_unique<Task>(std::forward<Task>(task));
    thread_pool_->Schedule([ptr = ptr.release()]() {
      (*ptr)();
      delete ptr;
    });
  }

  // This is an unsafe function intended only for use in tests. It is undefined
  // behavior to call it concurrently with `Schedule`.
  void Reset() {
    thread_pool_ = std::make_unique<ThreadPool>(
        Env::Default(), "tf-jitrt-compiler", /*num_threads=*/32);
  }

 private:
  std::unique_ptr<ThreadPool> thread_pool_;
};
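
// Illustrative usage sketch (not part of the original kernel code): the
// release/delete dance above exists because std::function must be copyable,
// while a compilation task may own move-only state.
//
//   CompilationThreadPool& pool = CompilationThreadPool::Get(host);
//   pool.Schedule([state = std::make_unique<int>(42)] {
//     // `state` is a move-only capture: wrapping this lambda directly in a
//     // std::function would not compile, but Schedule accepts it.
//   });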

// -------------------------------------------------------------------------- //
// JIT compiled kernels use the Eigen ThreadPool managed by the kernel fallback
// as async runtime worker threads.
// -------------------------------------------------------------------------- //

static Expected<Eigen::ThreadPoolInterface*> GetWorkerThreads(
    const ExecutionContext& exec_ctx) {
  RequestContext* req_ctx = exec_ctx.request_ctx();

  auto* fallback = req_ctx->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (LLVM_UNLIKELY(!fallback))
    return MakeStringError("fallback request state was not found");

  // Return the user-provided intra-op thread pool if it is available.
  if (LLVM_LIKELY(fallback->intra_op_threadpool()))
    return fallback->intra_op_threadpool();

  // Otherwise find the default CPU device in the device manager.
  Device* host_cpu = fallback->device_manager().HostCPU();
  assert(host_cpu && "fallback state must have a valid host cpu device");

  const Eigen::ThreadPoolDevice* eigen = host_cpu->eigen_cpu_device();
  assert(eigen && "host cpu device must have a valid Eigen thread pool device");

  return eigen->getPool();
}

// -------------------------------------------------------------------------- //
// Compile a compilation unit attribute to an executable result.
// -------------------------------------------------------------------------- //

// Options for the `tf-jitrt-pipeline`. We do not use MLIR pass options
// directly because they are not copyable or movable, and we need to pass them
// cheaply across async compilation task boundaries.
struct TfJitRtPipelineOpts {
  bool vectorize;
  bool legalize_i1_tensors;
};

// Prints a memref descriptor as a tensor type, e.g. tensor<NxMxf32>.
static std::string AsTensorType(const MemrefDesc& desc) {
  std::string str;
  llvm::raw_string_ostream os(str);

  os << "tensor<";
  for (size_t size : desc.sizes()) os << size << "x";
  os << desc.dtype();
  os << ">";

  return str;
}
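
// Illustrative example (not from the original file): a MemrefDesc with sizes
// [2, 3] and dtype f32 is printed by AsTensorType as "tensor<2x3xf32>".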

// Prints memref descriptor contents to trace value specializations.
static std::string AsTensorContent(const MemrefDesc& desc) {
  std::string str;
  llvm::raw_string_ostream os(str);

  auto print_0d = [&](auto type_tag) {
    os << desc.dtype() << ": "
       << *static_cast<decltype(type_tag)*>(desc.data());
  };

  auto print_1d = [&](auto type_tag) {
    os << desc.dtype() << ": [";
    for (size_t i = 0; i < desc.size(0); ++i) {
      if (i != 0) os << ",";
      os << static_cast<decltype(type_tag)*>(desc.data())[i];
    }
    os << "]";
  };

  auto type_dispatch = [&](auto functor) {
    switch (desc.dtype()) {
      case DType::I32:
        functor(int32_t{});
        break;
      case DType::I64:
        functor(int64_t{});
        break;
      default:
        os << "<unsupported dtype " << desc.dtype() << ">";
    }
  };

  size_t rank = desc.rank();

  switch (rank) {
    case 0:
      type_dispatch(print_0d);
      break;
    case 1:
      type_dispatch(print_1d);
      break;
    default:
      os << "<unsupported rank " << desc.rank() << ">";
  }

  return str;
}
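
// Illustrative example (not from the original file): a rank-1 I32 memref
// holding {1, 2, 3} is printed by AsTensorContent as a string like
// "i32: [1,2,3]"; ranks above 1 and dtypes other than I32/I64 fall back to
// the "<unsupported ...>" placeholders.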

// Gets the session name from the fallback request state.
static const std::string GetSessionName(RequestContext* req_ctx) {
  auto* fallback = req_ctx->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback) return "<unknown>";

  return fallback->session_metadata().name();
}

static Expected<AsyncValuePtr<JitExecutable>> CompileImpl(
    const CompilationUnitAttribute& kernel, const ExecutionContext& exec_ctx,
    const Optional<TfJitRtPipelineOpts>& opts = None) {
  // Request context must be initialized with the tf_jitrt state.
  auto* state = exec_ctx.request_ctx()->GetDataIfExists<TfJitRtRequestState>();
  if (LLVM_UNLIKELY(!state))
    return MakeStringError("tf_jitrt state not found in the request context");

  // We rely on the unique `id` provided by the CompilationUnitAttribute to
  // look up the JitExecutable in the cache. This id is guaranteed to be unique
  // within a BEF file. Currently we rely on the fact that the SavedModel never
  // unloads a BEF file, and that there is a 1-to-1 relationship between the
  // ResourceContext and the SavedModel.
  //
  // TODO(b/206081322): Different compilation options should create unique
  // compiled kernel cache keys.
  size_t key = kernel.id();

  JitExecutableCache* jit_executable_cache = state->jit_executable_cache;

  // Maybe return the JitExecutable from the cache.
  auto cached = jit_executable_cache->Find(key);
  if (LLVM_LIKELY(cached)) return cached;

  // Get the worker threads from the execution context. Do this before
  // allocating an async value to make sure that we can try to instantiate the
  // executable.
  Expected<Eigen::ThreadPoolInterface*> worker_threads =
      GetWorkerThreads(exec_ctx);
  if (auto err = worker_threads.takeError()) return std::move(err);

  // Allocate a placeholder for the compiled JitExecutable.
  JitExecutableCache::Entry entry = jit_executable_cache->Allocate(key);

  // We lost the race; some other invocation will do the compilation.
  if (!entry.allocated) return entry.ptr;

  // Given that compilation happens asynchronously, passing (or capturing)
  // these by value prevents use-after-free errors.
  struct KernelInfo {
    intptr_t id;
    std::string entrypoint;
    std::string name;
    std::string serialized_operation;
  } kernel_info;

  // We only support functions nested in the top-level compiled module.
  if (kernel.nested_symbols().size() != 1)
    return MakeStringError(
        "kernel function has to be defined in a top-level module");

  // TODO(ecg): use designated initializers + const when C++20 is adopted.
  kernel_info.id = kernel.id();
  kernel_info.entrypoint = kernel.nested_symbols()[0];
  kernel_info.name = kernel.root_symbol();
  kernel_info.serialized_operation = kernel.serialized_operation();

  // Compilation (specialized executable compilation) events should be rare,
  // so we can afford to do detailed tracing for every compilation. If
  // compilation events happen too often, that is a much larger problem than
  // the excessive tracing.

  // Custom runner for compiling specializations that schedules the compilation
  // task into the dedicated thread pool and adds tracing.
  auto runner = [kernel_info](size_t specialization,
                              ArrayRef<ArgumentConstraint> constraints,
                              ArgumentsRef arguments,
                              JitExecutable::CompilationTask compile,
                              JitExecutable::UserData user_data) {
    assert(arguments.size() == constraints.size());

    // Get the context of the request that triggered specialization
    // compilation.
    RequestContext* req_ctx = any_cast<RequestContext*>(user_data);
    HostContext* host = req_ctx->host();

    // Prepare arguments for the compilation tracing in the caller thread,
    // because the operands' lifetime is shorter than the compilation task's.
    using SpecializationArg = std::pair<std::string, std::string>;
    llvm::SmallVector<SpecializationArg> args;
    args.reserve(arguments.size());

    // Trace the types of all operands of the specialization.
    for (size_t i = 0; i < arguments.size(); ++i)
      args.emplace_back(StrCat("%arg", i, " type"),
                        AsTensorType(cast<MemrefDesc>(arguments[i])));

    // Trace the contents of all operands that require value specialization.
    for (size_t i = 0; i < constraints.size(); ++i) {
      if (constraints[i] != ArgumentConstraint::kValue) continue;
      args.emplace_back(StrCat("%arg", i, " value"),
                        AsTensorContent(cast<MemrefDesc>(arguments[i])));
    }

    // Schedule the specialization compilation task into the dedicated thread
    // pool.
    CompilationThreadPool& thread_pool = CompilationThreadPool::Get(host);

    thread_pool.Schedule(
        [kernel_info, specialization, request_id = req_ctx->id(),
         session_name = GetSessionName(req_ctx), compile = std::move(compile),
         args = std::move(args)]() mutable {
          TraceMe trace_me([&] {
            return TraceMeEncode("tf_jitrt.CompileSpecialization",
                                 {{"id", request_id},
                                  {"kernel_id", kernel_info.id},
                                  {"executable", kernel_info.name},
                                  {"specialization", specialization}});
          });

          for (SpecializationArg& arg : args) {
            trace_me.AppendMetadata([&] {
              return TraceMeEncode({{arg.first, arg.second}});
            });
          }

          trace_me.AppendMetadata([&] {
            return TraceMeEncode({{"src", kernel_info.serialized_operation}});
          });

          auto compile_start_time = absl::Now();
          LOG(INFO) << "Started JitExecutable specialization compilation for "
                    << kernel_info.name << " (" << session_name << ")";
          compile();
          auto compile_duration = absl::Now() - compile_start_time;

          LOG(INFO) << "JitExecutable specialization compilation for "
                    << kernel_info.name << " took "
                    << absl::ToInt64Milliseconds(compile_duration) << " ms ("
                    << session_name << ")";

          if (compile_duration > absl::Seconds(1))
            LOG(INFO) << "Expensive JitExecutable specialization compilation ("
                      << absl::ToInt64Milliseconds(compile_duration)
                      << " ms):\n"
                      << kernel_info.serialized_operation;

          RecordCompileTime(session_name, kernel_info.name, specialization,
                            compile_duration);
        });
  };

  HostContext* host = exec_ctx.host();
  RequestContext* req_ctx = exec_ctx.request_ctx();

  // Compile the kernel asynchronously in the compilation thread pool.
  CompilationThreadPool& thread_pool = CompilationThreadPool::Get(host);

  thread_pool.Schedule([kernel_info, runner, workers = *worker_threads,
                        ref = entry.ptr.CopyRef(), request_id = req_ctx->id(),
                        session_name = GetSessionName(req_ctx),
                        tf_jitrt_opts = opts]() {
    TraceMe trace_me([&] {
      return TraceMeEncode("tf_jitrt.CompileDefault",
                           {{"id", request_id},
                            {"kernel_id", kernel_info.id},
                            {"executable", kernel_info.name},
                            {"src", kernel_info.serialized_operation}});
    });

    // Options for the default JitRt compilation pipeline (lowering to LLVM).
    CompilationPipelineOptions copts;
    copts.alignment = EIGEN_MAX_ALIGN_BYTES;  // Eigen included by tensor.h
    copts.num_worker_threads = workers->NumThreads();
    copts.cost_driven_async_parallel_for =
        GetJitRtFlags().cost_driven_async_parallel_for;

    // Options for the JitRt JitExecutable compilation.
    JitExecutable::Options opts;
    opts.specialization = GetJitRtFlags().always_specialize
                              ? JitExecutable::Specialization::kAlways
                              : JitExecutable::Specialization::kEnabled;

    // Register dialects and interfaces required for the compilation pipeline.
    opts.compiler.register_dialects = [](mlir::DialectRegistry& registry) {
      mlir::RegisterAllTensorFlowDialects(registry);
      RegisterDefaultJitRtDialects(registry);
    };

    // Register a custom pipeline for lowering from the TensorFlow dialect to
    // LLVM.
    opts.compiler.create_compilation_pipeline = [=](mlir::PassManager& pm) {
      if (GetJitRtFlags().enable_crash_reproducer)
        SetCrashReproducer(pm, kCrashReproducerStdErr);

      TfJitRtPipelineOptions opts;
      if (tf_jitrt_opts) {
        opts.vectorize = tf_jitrt_opts->vectorize;
        opts.legalize_i1_tensors = tf_jitrt_opts->legalize_i1_tensors;
      } else {
        opts.vectorize = GetJitRtFlags().vectorize;
      }

      // Lower from TensorFlow to Linalg on buffers.
      CreateTfJitRtPipeline(pm, opts);

      // Use the default JitRt compilation pipeline to lower to LLVM.
      CreateDefaultJitRtCompilationPipeline(pm, copts);
    };

    // Register a custom pipeline to propagate specialization information.
    opts.compiler.create_specialization_pipeline = [=](mlir::PassManager& pm) {
      if (GetJitRtFlags().enable_crash_reproducer)
        SetCrashReproducer(pm, kCrashReproducerStdErr);
      CreateJitRtSpecializationPipeline(pm);
    };

    // When lowering TensorFlow functions to JitRt we convert all input and
    // result tensors to memrefs, and add a kernel context input.
    opts.compiler.calling_convention = xla::runtime::DefaultCallingConvention(
        mlir::bufferization::BufferizeTypeConverter());

    // Instantiate a new JitExecutable from the MLIR source.
    auto compile_start_time = absl::Now();
    LOG(INFO) << "Started JitExecutable instantiation for " << kernel_info.name
              << " (" << session_name << ")";
    Expected<JitExecutable> jit_executable = JitExecutable::Instantiate(
        kernel_info.serialized_operation, kernel_info.entrypoint,
        std::move(opts), session_name, runner);
    auto compile_duration = absl::Now() - compile_start_time;

    LOG(INFO) << "JitExecutable instantiation for " << kernel_info.name
              << " took " << absl::ToInt64Milliseconds(compile_duration)
              << " ms (" << session_name << ")";

    if (compile_duration > absl::Seconds(1))
      LOG(INFO) << "Expensive JitExecutable instantiation ("
                << absl::ToInt64Milliseconds(compile_duration) << " ms):\n"
                << kernel_info.serialized_operation;

    RecordCompileTime(session_name, kernel_info.name, std::nullopt,
                      compile_duration);

    // Set the entry async value state to error or concrete.
    if (auto err = jit_executable.takeError())
      ref.SetError(std::move(err));
    else
      ref.emplace(std::move(*jit_executable));
  });

  return entry.ptr;
}

// -------------------------------------------------------------------------- //
// TFRT kernel function definition for the tf_jitrt.fallback.compile operation.
// -------------------------------------------------------------------------- //

// Compiles the kernel into a JitExecutable and updates the JitExecutableCache.
static AsyncValueRef<Chain> Compile(StringAttribute device,
                                    CompilationUnitAttribute kernel,
                                    const ExecutionContext& exec_ctx) {
  // Trigger kernel compilation, which will update the JitExecutableCache.
  Expected<AsyncValuePtr<JitExecutable>> executable =
      CompileImpl(kernel, exec_ctx);

  // Return an error if we can't schedule the compilation task.
  if (auto err = executable.takeError())
    return MakeErrorAsyncValueRef(StrCat(err));

  // Mark the chain available once we compile the default executable.
  auto chain = MakeConstructedAsyncValueRef<Chain>();
  executable->AndThen([chain]() { chain.SetStateConcrete(); });

  return chain;
}

// -------------------------------------------------------------------------- //
// TFRT kernel function definition for tf_jitrt.test.wait_for_compilation.
// -------------------------------------------------------------------------- //

static AsyncValueRef<Chain> WaitForCompilation(
    KernelArgument<Chain> chain, CompilationUnitAttribute kernel,
    const ExecutionContext& exec_ctx) {
  // Request context must be initialized with the tf_jitrt state.
  auto* state = exec_ctx.request_ctx()->GetDataIfExists<TfJitRtRequestState>();
  if (!state)
    return EmitErrorAsync(exec_ctx,
                          "tf_jitrt state not found in the request context");

  // Wait for the completion of all compilation tasks.
  JitExecutableCache* jit_executable_cache = state->jit_executable_cache;
  if (auto cached = jit_executable_cache->Find(kernel.id()))
    return cached->AllExecutablesCompiled();

  return MakeAvailableAsyncValueRef<Chain>();
}

// -------------------------------------------------------------------------- //
// TFRT kernel function for tf_jitrt.test.reset_compilation_thread_pool.
// -------------------------------------------------------------------------- //

static AsyncValueRef<Chain> ResetCompilationThreadPool(
    KernelArgument<Chain> chain, const ExecutionContext& exec_ctx) {
  // Make sure that we reset the compilation thread pool only from a thread
  // pool (concurrent work queue) managed by the HostContext.
  return EnqueueWork(exec_ctx, [host = exec_ctx.host()]() -> Chain {
    CompilationThreadPool::Get(host).Reset();
    return {};
  });
}

// -------------------------------------------------------------------------- //
// Execute compiled JitRt kernels with Fallback Runtime interop.
// -------------------------------------------------------------------------- //

using ReturnTensorflowTensor =
    ReturnValueConversion<TensorflowConversionContext,
                          ReturnStridedMemref<ConvertTensor>>;

using TensorflowResultConverter =
    StaticRemainingResultsConverter<TensorflowConversionContext,
                                    ReturnTensorflowTensor>;

static MemrefDesc ConvertTensorToMemrefDesc(const tensorflow::Tensor& tensor) {
  // Fills memref sizes and strides from the tensor shape (TensorFlow tensors
  // are row-major).
  auto fill_desc = [&](MutableArrayRef<Index> sizes,
                       MutableArrayRef<Index> strides) {
    int64_t multiplier = 1;
    for (int i = tensor.dims() - 1; i >= 0; --i) {
      int64_t dim_size = tensor.dim_size(i);
      sizes[i] = dim_size;
      strides[i] = multiplier;
      multiplier *= dim_size;
    }
  };

  return MemrefDesc(tensor.dims(), tfd::GetTfrtDtype(tensor.dtype()),
                    const_cast<void*>(tensor.data()), 0, fill_desc);
}
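
// Worked example (illustrative, not part of the original file): a 2x3 tensor
// produces sizes = [2, 3] and row-major strides = [3, 1], so element (i, j)
// lives at linear offset i * 3 + j from the tensor's base pointer.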

static std::vector<MemrefDesc> ConvertTensorOperandsToMemrefDesc(
    RepeatedArguments<FallbackTensor> operands) {
  std::vector<MemrefDesc> memrefs;
  memrefs.reserve(operands.size());
  for (FallbackTensor& operand : operands)
    memrefs.emplace_back(ConvertTensorToMemrefDesc(operand.tensor()));
  return memrefs;
}

struct DebugListener : public SpecializationListener {
  void notifyModuleSpecialized(
      ArrayRef<mlir::Type> operands,
      ArrayRef<mlir::DictionaryAttr> attrs) const override {
    std::string message;
    llvm::raw_string_ostream os(message);
    os << "Specialized operands:\n";
    for (auto& tuple : llvm::enumerate(llvm::zip(operands, attrs))) {
      mlir::Type type = std::get<0>(tuple.value());
      mlir::Attribute attr = std::get<1>(tuple.value());
      os << "%arg" << tuple.index() << ": " << type << " " << attr << "\n";
    }
    printf("%s", message.c_str());
    fflush(stdout);
  }

  void notifyValueSpecialized(unsigned index, mlir::Type type,
                              mlir::Attribute value) const override {
    std::string message;
    llvm::raw_string_ostream(message) << "%arg" << index << " "
                                      << "value specialized: " << value << "\n";
    printf("%s", message.c_str());
    fflush(stdout);
  }
};
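
// Illustrative example (not from the original file): when %arg0 is value
// specialized to an i32 scalar, notifyValueSpecialized prints a line to
// stdout roughly like
//   %arg0 value specialized: 42 : i32
// which tests can match against to verify compiler internals.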

// Emits diagnostics for the kernel invocation and returns errors for all
// remaining results.
template <typename Error>
static void ReturnErrors(RemainingResults results, Error error,
                         const ExecutionContext& exec_ctx) {
  EmitError(exec_ctx, StrCat(error));
  ReturnErrors(results, std::move(error));
}

static void ExecuteImpl(Executable& executable, ArrayRef<MemrefDesc> memrefs,
                        RepeatedArguments<FallbackTensor> operands,
                        RemainingResults results,
                        const ExecutionContext& exec_ctx) {
  // Bind the execution trace to the request context.
  TraceMe trace_me([&] {
    int64_t id = exec_ctx.request_ctx()->id();
    absl::string_view name(executable.name().data(), executable.name().size());
    return TraceMeEncode(
        "tf_jitrt.Execute",
        {{"id", id},
         {"executable", name},
         {"specialization", !executable.specialization().has_value()
                                ? "default"
                                : std::to_string(*executable.specialization())},
         {"time_to_compile_ms", executable.time_to_compile().count()}});
  });

  // TODO(ezhulenev): The conversion context and async task runner might not
  // outlive the execution of all async tasks, and should be kept alive until
  // all tasks are completed, which will require heap allocation(s).
  assert(!executable.IsAsync() && "async executables are not yet supported");

  // Keep track of the memory address to tensor mapping for result conversion.
  TensorflowConversionContext ctx(operands.size(), results.size());
  for (auto& t : operands)
    ctx.runtime_tensors.insert({t.tensor().data(), &t.tensor()});

  TensorflowResultConverter converter(results, ctx);

  // Get the worker threads from the execution context.
  Expected<Eigen::ThreadPoolInterface*> worker_threads =
      GetWorkerThreads(exec_ctx);

  if (LLVM_UNLIKELY(!worker_threads))
    return ReturnErrors(results, worker_threads.takeError(), exec_ctx);

  // Use the Eigen thread pool to execute all async tasks.
  EigenThreadPoolAsyncTaskRunner async_task_runner(*worker_threads);

  Executable::ExecuteOpts opts;
  opts.async_task_runner = &async_task_runner;

  // Execution errors are automatically forwarded to all results; we only need
  // to notify the HostContext to emit the diagnostics for the kernel
  // invocation.
  auto err = executable.Execute(memrefs, converter, opts);
  if (LLVM_UNLIKELY(err)) {
    EmitError(exec_ctx, StrCat(err));
    return;
  }
}

// Gets a specialized Executable async value from the JitExecutable, and then
// dispatches it inline or via an and-then continuation depending on the async
// value state.
static void ExecuteImpl(JitExecutable& jit_executable,
                        RepeatedArguments<FallbackTensor> operands,
                        RemainingResults results,
                        const ExecutionContext& exec_ctx, bool debug) {
  // Convert Tensor operands to memref descriptors.
  auto memrefs = ConvertTensorOperandsToMemrefDesc(operands);

  // Get an executable that might be specialized to the operands.
  DebugListener debug_listener;

  // Pass the request context to the compilation task runner.
  JitExecutable::UserData user_data = exec_ctx.request_ctx();

  Expected<AsyncValuePtr<Executable>> executable = jit_executable.GetExecutable(
      memrefs, user_data, debug ? &debug_listener : nullptr);

  if (LLVM_UNLIKELY(!executable))
    return ReturnErrors(results, executable.takeError(), exec_ctx);

  // If the executable is available, execute it inline ...
  if (LLVM_LIKELY(executable->IsConcrete()))
    return ExecuteImpl(executable->get(), memrefs, operands, results, exec_ctx);

  // ... or maybe return errors.
  if (LLVM_UNLIKELY(executable->IsError()))
    return ReturnErrors(results, executable->GetError(), exec_ctx);

  // Otherwise execute it when the executable becomes available. This requires
  // careful lifetime extension of all async values passed as operands to the
  // kernel (and also results that will become available asynchronously).

  // Allocate indirect async values for all results; we'll forward them to the
  // actual async values computed by the executable later.
  for (unsigned i = 0; i < results.size(); ++i)
    results.AllocateIndirectResultAt(i);

  // Call the executable with the original operands when it's ready.
  executable->AndThen([exec_ctx, executable = *executable,
                       memrefs = std::move(memrefs),
                       r = RCArray<AsyncValue>(results.values()),
                       o = RCArray<AsyncValue>(operands.values())] {
    // Allocate storage for the executable results.
    llvm::SmallVector<RCReference<AsyncValue>> results_storage;
    results_storage.resize(r.size());

    // Reconstruct arguments and results from the captured async values.
    RepeatedArguments<FallbackTensor> operands(o.values());
    RemainingResults results(results_storage);

    if (executable.IsError()) {
      ReturnErrors(results, executable.GetError(), exec_ctx);
    } else {
      ExecuteImpl(*executable, memrefs, operands, results, exec_ctx);
    }

    // Forward the previously allocated indirect results to the actual results.
    for (unsigned i = 0; i < r.size(); ++i)
      llvm::cast<IndirectAsyncValue>(*r[i]).ForwardTo(
          std::move(results_storage[i]));
  });
}

static std::string OperandsToString(
    RepeatedArguments<FallbackTensor> operands) {
  std::string out;
  llvm::raw_string_ostream os(out);
  int i = 0;
  os << "{";
  for (const auto& operand : operands) {
    os << "[" << i++ << "]: " << operand.tensor().DebugString(/*num_values=*/0);
    if (i < operands.size()) os << ", ";
  }
  os << "}";
  return out;
}
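
// Illustrative example (not from the original file): for two operands the
// function above produces a string of the form
//   {[0]: Tensor<type: float shape: [2,2] values: ...>, [1]: Tensor<...>}
// where the exact per-tensor text comes from Tensor::DebugString.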

// Gets a JitExecutable async value from the cache, and then dispatches it
// inline or via an and-then continuation depending on the async value state.
static void ExecuteImpl(RepeatedArguments<FallbackTensor> operands,
                        RemainingResults results, const StringAttribute& device,
                        const CompilationUnitAttribute& kernel,
                        const ExecutionContext& exec_ctx, bool debug,
                        const Optional<TfJitRtPipelineOpts>& opts) {
  VLOG(2) << "kernel_name: " << kernel.root_symbol().str()
          << ", operands: " << OperandsToString(operands);

  // Compile the kernel module into a JitExecutable.
  Expected<AsyncValuePtr<JitExecutable>> jit_executable =
      CompileImpl(kernel, exec_ctx, opts);

  if (LLVM_UNLIKELY(!jit_executable))
    return ReturnErrors(results, jit_executable.takeError(), exec_ctx);

  // If the kernel is available, execute it inline ...
  if (LLVM_LIKELY(jit_executable->IsConcrete()))
    return ExecuteImpl(**jit_executable, operands, results, exec_ctx, debug);

  // ... or maybe return errors.
  if (LLVM_UNLIKELY(jit_executable->IsError()))
    return ReturnErrors(results, jit_executable->GetError(), exec_ctx);

  // Otherwise execute it when the executable becomes available. This requires
  // careful lifetime extension of all async values passed as operands to the
  // kernel (and also results that will become available asynchronously).

  // Allocate indirect async values for all results; we'll forward them to the
  // actual async values computed by the executable later.
  for (unsigned i = 0; i < results.size(); ++i)
    results.AllocateIndirectResultAt(i);

  // Call the executable with the original operands when it's ready.
  jit_executable->AndThen([exec_ctx, jit_executable = *jit_executable,
                           r = RCArray<AsyncValue>(results.values()),
                           o = RCArray<AsyncValue>(operands.values()), debug] {
    // Allocate storage for the compiled executable results.
    llvm::SmallVector<RCReference<AsyncValue>> results_storage;
    results_storage.resize(r.size());

    // Reconstruct arguments and results from the captured async values.
    RepeatedArguments<FallbackTensor> operands(o.values());
    RemainingResults results(results_storage);

    if (jit_executable.IsError()) {
      ReturnErrors(results, jit_executable.GetError(), exec_ctx);
    } else {
      ExecuteImpl(*jit_executable, operands, results, exec_ctx, debug);
    }

    // Forward the previously allocated indirect results to the actual results.
    for (unsigned i = 0; i < r.size(); ++i)
      llvm::cast<IndirectAsyncValue>(*r[i]).ForwardTo(
          std::move(results_storage[i]));
  });
}

static void ExecuteImplAndMaybeLogQueryOfDeath(
    RepeatedArguments<FallbackTensor> operands, RemainingResults results,
    const StringAttribute& device, const CompilationUnitAttribute& kernel,
    const ExecutionContext& exec_ctx, bool debug = false,
    const Optional<TfJitRtPipelineOpts>& opts = None) {
  if (LLVM_LIKELY(!GetJitRtFlags().log_query_of_death)) {
    return ExecuteImpl(operands, results, device, kernel, exec_ctx, debug,
                       opts);
  }
  TfJitRtQueryOfDeathLogger qod_logger(/*kernel_name=*/kernel.root_symbol(),
                                       /*kernel_serialized_operation=*/
                                       kernel.serialized_operation(),
                                       /*operands=*/OperandsToString(operands));
  ExecuteImpl(operands, results, device, kernel, exec_ctx, debug, opts);
}

// -------------------------------------------------------------------------- //
// TFRT kernel function definitions for tf_jitrt.fallback.execute operations.
// -------------------------------------------------------------------------- //

// Compiles the kernel into a JitExecutable and executes it with the fallback
// tensor operands.
static void Execute(RepeatedArguments<FallbackTensor> operands,
                    RemainingResults results, StringAttribute device,
                    CompilationUnitAttribute kernel,
                    const ExecutionContext& exec_ctx) {
  ExecuteImplAndMaybeLogQueryOfDeath(operands, results, device, kernel,
                                     exec_ctx);
}

// Compiles the kernel into a JitExecutable and executes it with the fallback
// tensor operands in debug mode: prints compilation diagnostics to standard
// output. Should be used only in tests for verifying compiler internals.
void ExecuteDebug(RepeatedArguments<FallbackTensor> operands,
                  RemainingResults results,
                  Attribute<bool> debug_specializations, StringAttribute device,
                  CompilationUnitAttribute kernel, Attribute<bool> vectorize,
                  Attribute<bool> legalize_i1_tensors,
                  const ExecutionContext& exec_ctx) {
  TfJitRtPipelineOpts opts;
  opts.vectorize = *vectorize;
  opts.legalize_i1_tensors = *legalize_i1_tensors;
  ExecuteImplAndMaybeLogQueryOfDeath(operands, results, device, kernel,
                                     exec_ctx, *debug_specializations, opts);
}

}  // namespace

void RegisterTfJitRuntimeKernels(KernelRegistry* registry) {
  registry->AddKernel("tf_jitrt.fallback.compile", TFRT_KERNEL(Compile));
  registry->AddKernel("tf_jitrt.fallback.execute", TFRT_KERNEL(Execute));
  registry->AddKernel("tf_jitrt.fallback.debug.execute",
                      TFRT_KERNEL(ExecuteDebug));

  registry->AddKernel("tf_jitrt.test.wait_for_compilation",
                      TFRT_KERNEL(WaitForCompilation));
  registry->AddKernel("tf_jitrt.test.reset_compilation_thread_pool",
                      TFRT_KERNEL(ResetCompilationThreadPool));
}

}  // namespace tensorflow