/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h"

#include <optional>
#include <string>

#include "llvm/ADT/StringRef.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/logging.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h"
#include "tensorflow/core/runtime_fallback/kernel/op_kernel_runner.h"
#include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h"
#include "tensorflow/core/runtime_fallback/runtime/op_logger.h"
#include "tensorflow/core/runtime_fallback/util/attr_util.h"
#include "tensorflow/core/tfrt/utils/error_util.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tensorflow/core/tfrt/utils/tensor_util.h"
#include "tfrt/core_runtime/execute_op_impl.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
#include "tfrt/host_context/async_dispatch.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/chain.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/kernel_registry.h"  // from @tf_runtime
#include "tfrt/host_context/sync_kernel_frame.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/pointer_util.h"  // from @tf_runtime
#include "tfrt/support/string_util.h"  // from @tf_runtime
#include "tfrt/tensor/tensor.h"  // from @tf_runtime
#include "tfrt/tracing/tracing.h"  // from @tf_runtime

namespace tensorflow {
namespace tfd {
namespace {

using ::tfrt::AsyncValue;
using ::tfrt::AsyncValueRef;
using ::tfrt::Chain;
using ::tfrt::OpAttrsRef;
using ::tfrt::RCReference;
using ::tfrt::SmallVector;
using ::tfrt::string_view;

constexpr char kOpKernelRunnerTableResourceName[] =
    "OpKernelRunnerTableResourceName";

constexpr char kOpKernelRunnerCacheResourceName[] =
    "OpKernelRunnerCacheResourceName";

constexpr char kFallbackResourceArray[] = "FallbackResourceArray";

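// Sets every result in `results` (and `op_chain`, if provided) to an error
// value derived from `status`.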
void KernelFallbackEmitError(
    const tfrt::ExecutionContext& exec_ctx, tfrt::string_view op_name,
    tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results,
    const tensorflow::Status& status) {
  // Set all results to error, with the correct TFRT error code according to
  // the error propagated from runtime fallback execution.
  auto error =
      EmitErrorAsync(exec_ctx,
                     tfrt::StrCat("error running kernel fallback kernel ",
                                  op_name, ": ", status.error_message()),
                     tfrt::ConvertTfErrorCodeToTfrtErrorCode(status));
  for (auto& result : results) result = error.CopyRef();
  if (op_chain) *op_chain = std::move(error);
}

}  // namespace

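// Sets up the KernelFallbackCompatRequestState in the request context, using
// an explicitly provided DeviceMgr and ProcessFunctionLibraryRuntime.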
Status SetUpKernelFallbackCompatRequestContext(
    tfrt::RequestContextBuilder* builder,
    const tensorflow::DeviceMgr* device_manager,
    const tensorflow::ProcessFunctionLibraryRuntime* pflr,
    tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool,
    const absl::optional<tfrt::ModelMetadata>& model_metadata) {
  DCHECK(builder);
  DCHECK(device_manager);
  DCHECK(pflr);

  auto* runner_table =
      builder->resource_context()->GetOrCreateResource<OpKernelRunnerTable>(
          kOpKernelRunnerTableResourceName);

  auto* resource_array =
      builder->resource_context()->GetOrCreateResource<FallbackResourceArray>(
          kFallbackResourceArray);

  builder->context_data().emplace<KernelFallbackCompatRequestState>(
      device_manager, builder->id(), runner_table, resource_array,
      user_intra_op_threadpool, model_metadata, pflr);

  return Status::OK();
}

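// Sets up the KernelFallbackCompatRequestState using the device manager, step
// container, rendezvous and collective executor owned by the given
// EagerContext.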
Status SetUpKernelFallbackCompatRequestContext(
    tfrt::RequestContextBuilder* builder, OpKernelRunnerTable* runner_table,
    tensorflow::EagerContext* eager_context,
    tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool,
    const absl::optional<tfrt::ModelMetadata>& model_metadata) {
  auto* resource_array =
      builder->resource_context()->GetOrCreateResource<FallbackResourceArray>(
          kFallbackResourceArray);

  if (runner_table == nullptr)
    runner_table =
        builder->resource_context()->GetOrCreateResource<OpKernelRunnerTable>(
            kOpKernelRunnerTableResourceName);

  auto step_id = builder->id();

  auto& fallback_request_state =
      builder->context_data().emplace<KernelFallbackCompatRequestState>(
          eager_context->local_device_mgr(), step_id,
          tfrt::OwnedOrUnownedPtr<ScopedStepContainer>{
              eager_context->StepContainer()},
          eager_context->GetCollectiveExecutorHandle(),
          tensorflow::core::RefCountPtr<tensorflow::Rendezvous>(
              eager_context->RendezvousCreator()(step_id)),
          runner_table, resource_array, user_intra_op_threadpool,
          model_metadata, eager_context->pflr());

  fallback_request_state.set_log_device_placement(
      eager_context->LogDevicePlacement());

  return Status::OK();
}

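// Converts the given TFRT tensors to tensorflow::Tensor values, returning an
// error if any conversion fails.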
static llvm::Expected<gtl::InlinedVector<tensorflow::Tensor, 4>>
ConvertInputTensors(llvm::ArrayRef<tfrt::Tensor*> arguments,
                    const tfrt::ExecutionContext& exec_ctx) {
  gtl::InlinedVector<tensorflow::Tensor, 4> input_tf_tensors;
  input_tf_tensors.reserve(arguments.size());
  for (tfrt::Tensor* argument : arguments) {
    auto expected_tf_tensor = TFRTTensorToTFTensor(*argument, exec_ctx.host());
    if (!expected_tf_tensor) {
      return tfrt::MakeStringError(
          tfrt::StrCat(expected_tf_tensor.takeError()));
    }
    input_tf_tensors.push_back(std::move(expected_tf_tensor.get()));
  }

  return input_tf_tensors;
}

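// Returns an InvalidArgument error if the number or dtypes of the input
// tensors do not match the kernel's declared input types.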
static Status ValidateInputTypes(
    tfrt::string_view op_name,
    const gtl::InlinedVector<tensorflow::Tensor, 4>& input_tf_tensors,
    const DataTypeVector& input_types) {
  const size_t n_inputs = input_tf_tensors.size();

  if (input_types.size() != n_inputs) {
    return tensorflow::errors::InvalidArgument("expected ", input_types.size(),
                                               " inputs, got ", n_inputs);
  }

  for (size_t i = 0; i < n_inputs; ++i) {
    if (input_tf_tensors[i].dtype() != input_types[i]) {
      return tensorflow::errors::InvalidArgument(
          "cannot compute ", op_name.str(), " as input #", i, "(zero-based)",
          " was expected to be a ", DataTypeString(input_types[i]),
          " tensor but is a ", DataTypeString(input_tf_tensors[i].dtype()),
          " tensor");
    }
  }

  return Status::OK();
}

namespace {

// OpKernelRunState keeps the states needed for per-kernel execution.
struct OpKernelRunState {
  gtl::InlinedVector<tensorflow::Tensor, 4> input_tf_tensors;
  gtl::InlinedVector<tensorflow::TensorValue, 4> input_tf_tensor_values;
  OpKernelContext::Params params;

  OpKernelRunState() = default;
  OpKernelRunState(
      const gtl::InlinedVector<tensorflow::TensorValue, 4>& tensor_values,
      const OpKernelContext::Params& p) {
    // `input_tf_tensor_values` contains references to all tensors used, while
    // `input_tf_tensors` only contains those that need ownership, so their
    // sizes may not match. For this copy, we conservatively copy all tensors.
    input_tf_tensors.reserve(tensor_values.size());
    for (const auto& tensor_value : tensor_values) {
      input_tf_tensors.push_back(*tensor_value.tensor);
    }
    for (auto& tensor : input_tf_tensors) {
      input_tf_tensor_values.emplace_back(&tensor);
    }

    // Since `input_tf_tensor_values` and `params` contain pointers to
    // `input_tf_tensors`, we need to change those pointers to the correct ones
    // after copying.
    params = p;
    params.inputs = &input_tf_tensor_values;
  }

  OpKernelRunState(const OpKernelRunState& other) = delete;
  OpKernelRunState& operator=(const OpKernelRunState& other) = delete;

  ~OpKernelRunState() = default;

  void SetUpParams(
      const OpKernelRunner& runner,
      const KernelFallbackCompatRequestState& fallback_request_state) {
    params.inputs = &input_tf_tensor_values;

    // Replace the thread pool device if a custom device is specified.
    //
    // The device handling is copied from the following link:
    // http://cs/?q=f:common_runtime%2Fexecutor.cc:692%20package:piper&rcl=351575626
    if (auto* custom_device = fallback_request_state.custom_device()) {
      params.device = custom_device;
    } else {
      params.device = runner.device();
    }

    params.op_kernel = runner.op_kernel();
    // Still use the original device's resource_manager.
    params.resource_manager = runner.resource_manager();
    params.input_alloc_attrs = &runner.input_alloc_attrs();
    params.output_attr_array = runner.output_alloc_attrs().data();
    params.step_container = fallback_request_state.step_container();
    // The following two parameters are used to support executing tf.data via
    // fallback.
    params.function_library = runner.function_library_runtime();
    params.runner = fallback_request_state.runner();
    params.collective_executor = fallback_request_state.collective_executor();
    params.rendezvous = fallback_request_state.rendezvous();
    params.session_metadata = &fallback_request_state.session_metadata();
    params.cancellation_manager = fallback_request_state.cancellation_manager();
  }
};

// Keep the states needed by kernel execution in thread-local storage to avoid
// repeated reallocation and destruction.
OpKernelRunState& GetThreadLocalOpKernelRunState() {
  thread_local OpKernelRunState run_state;
  return run_state;
}

}  // namespace

// Execute a tensorflow::OpKernel asynchronously. `kernel_runner` and
// `input_tf_tensors` are expected to be alive during the call to this
// function. Set result AsyncValues in `results` and return a Chain that
// indicates the completion of execution, or an error otherwise.
template <typename TensorType>
static void KernelFallbackExecuteCompatAsyncInternal(
    const tfrt::ExecutionContext& exec_ctx, OpKernelRunState* run_state,
    const OpKernelRunner& kernel_runner,
    tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results) {
  auto chain =
      tfrt::MakeUnconstructedAsyncValueRef<tfrt::Chain>(exec_ctx.host());
  if (op_chain) *op_chain = chain.CopyRef();

  // Allocate unconstructed result tensors and set them in the output
  // `results`.
  llvm::SmallVector<AsyncValueRef<TensorType>, 4> result_refs;
  result_refs.reserve(results.size());
  for (auto& result : results) {
    result_refs.emplace_back(
        tfrt::MakeUnconstructedAsyncValueRef<TensorType>(exec_ctx.host()));
    result = result_refs.back().CopyRef();
  }

  struct AsyncState {
    explicit AsyncState(const OpKernelRunState& rs, int num_outputs)
        : run_state(rs.input_tf_tensor_values, rs.params),
          context(&run_state.params, num_outputs) {}

    OpKernelRunState run_state;
    OpKernelContext context;

    tfrt::AsyncValueRef<tfrt::Chain> chain;
    llvm::SmallVector<tfrt::AsyncValueRef<TensorType>, 4> result_refs;
  };

  DCHECK_EQ(results.size(), kernel_runner.op_kernel()->num_outputs());
  auto async_state = std::make_shared<AsyncState>(*run_state, results.size());
  async_state->chain = chain.CopyRef();
  async_state->result_refs = std::move(result_refs);

  auto* context_ptr = &async_state->context;

  auto done_callback = [async_state = std::move(async_state), exec_ctx]() {
    auto& context = async_state->context;

    if (!context.status().ok()) {
      auto diag = tfrt::EmitError(
          exec_ctx,
          {tfrt::StrCat("error running kernel fallback kernel ",
                        context.op_kernel().name(), ": ",
                        context.status().error_message())},
          tfrt::ConvertTfErrorCodeToTfrtErrorCode(context.status()));
      for (auto& result : async_state->result_refs) result.SetError(diag);
      async_state->chain.SetError(diag);
      return;
    }

    // Set payload and mark async values available in TFRT's thread.
    tfrt::EnqueueWork(exec_ctx, [async_state = std::move(async_state)]() {
      auto& context = async_state->context;
      for (int i = 0; i < context.num_outputs(); ++i) {
        async_state->result_refs[i].emplace(
            std::move(*context.mutable_output(i)));
      }
      async_state->chain.emplace();
    });
  };

  kernel_runner.RunAsync(context_ptr, std::move(done_callback));
}

// Execute a tensorflow::OpKernel synchronously. `kernel_runner` and
// `input_tf_tensors` are expected to be alive during the call to this
// function. Set result AsyncValues in `results` on successfully finishing
// the execution, and emit errors otherwise. TensorType is expected to be
// constructible from tensorflow::Tensor.
template <typename TensorType>
static void KernelFallbackExecuteCompatSyncInternal(
    const tfrt::ExecutionContext& exec_ctx, OpKernelRunState* run_state,
    const OpKernelRunner& kernel_runner,
    tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results) {
  DCHECK_EQ(results.size(), kernel_runner.op_kernel()->num_outputs());
  OpKernelContext context(&run_state->params, results.size());
  kernel_runner.Run(&context);

  if (!context.status().ok()) {
    KernelFallbackEmitError(exec_ctx, kernel_runner.op_kernel()->name(),
                            op_chain, results, context.status());
    return;
  }

  for (int i = 0; i < context.num_outputs(); ++i) {
    results[i] = tfrt::MakeAvailableAsyncValueRef<TensorType>(
        std::move(*context.mutable_output(i)));
  }

  if (op_chain) *op_chain = tfrt::MakeAvailableAsyncValueRef<tfrt::Chain>();
}

static std::string PrintTfrtOpAttrsToString(const tfrt::OpAttrsRef& attrs) {
  std::string str;
  llvm::raw_string_ostream ss(str);
  attrs.Print(ss);
  ss.flush();
  return str;
}

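// Looks up (or creates) the OpKernelRunner for `op_name` on `device_name`,
// converts the TFRT input tensors to tensorflow::Tensor, and dispatches the
// kernel synchronously or asynchronously depending on the kernel type.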
tfrt::AsyncValueRef<tfrt::Chain> KernelFallbackExecuteCompatCoreRuntimeDispatch(
    const tfrt::ExecutionContext& exec_ctx, tfrt::string_view op_name,
    tfrt::string_view device_name, llvm::ArrayRef<tfrt::Tensor*> arguments,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results,
    const tfrt::OpAttrsRef& attrs) {
  auto op_chain = tfrt::GetReadyChain(exec_ctx.host());
  tensorflow::Status status;

  const auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    status = tensorflow::errors::NotFound(
        "KernelFallbackCompatRequestState not found in RequestContext.");
    KernelFallbackEmitError(exec_ctx, op_name, &op_chain, results, status);
    return op_chain;
  }

  DCHECK(exec_ctx.location());

  DCHECK(exec_ctx.request_ctx()->resource_context());
  auto* runner_cache = exec_ctx.request_ctx()
                           ->resource_context()
                           ->GetOrCreateResource<OpKernelRunnerCache>(
                               kOpKernelRunnerCacheResourceName);

  auto kernel_runner_or_status = runner_cache->GetOrCreate(
      exec_ctx.location(), ToAbslStringView(op_name),
      ToAbslStringView(device_name), arguments.size(),
      [&attrs, host = exec_ctx.host()](
          tensorflow::AttrValueMap* attr_value_map) -> llvm::Error {
        VLOG(1) << "KernelFallbackExecuteCompat creating op from OpAttrs: "
                << PrintTfrtOpAttrsToString(attrs);
        return FillAttrValueMap(attrs, host, attr_value_map);
      },
      *fallback_request_state);

  if (!kernel_runner_or_status.ok()) {
    KernelFallbackEmitError(exec_ctx, op_name, &op_chain, results,
                            kernel_runner_or_status.status());
    return op_chain;
  }

  auto expected_input_tf_tensors = ConvertInputTensors(arguments, exec_ctx);
  if (!expected_input_tf_tensors) {
    status = tensorflow::errors::Internal(
        tfrt::StrCat(expected_input_tf_tensors.takeError()));
    KernelFallbackEmitError(exec_ctx, op_name, &op_chain, results, status);
    return op_chain;
  }

  auto& kernel_runner = kernel_runner_or_status.ValueOrDie();

  auto& run_state = GetThreadLocalOpKernelRunState();
  auto clean_up_inputs =
      gtl::MakeCleanup([&]() { run_state.input_tf_tensors.clear(); });

  auto& input_tf_tensors = run_state.input_tf_tensors;
  input_tf_tensors = std::move(expected_input_tf_tensors.get());

  // Check if input tensor dtypes are valid.
  status = ValidateInputTypes(op_name, input_tf_tensors,
                              kernel_runner->op_kernel()->input_types());

  // TODO(b/176997538): Skip checking dtypes for tf._BatchFunctionFallback op
  // due to b/176997538. Remove the skipping once the SavedModel lowering
  // problem is fixed.
  if (!status.ok() && !op_name.equals("_BatchFunctionFallback")) {
    KernelFallbackEmitError(exec_ctx, op_name, &op_chain, results, status);
    return op_chain;
  }

  auto& input_tf_tensor_values = run_state.input_tf_tensor_values;
  input_tf_tensor_values.resize(input_tf_tensors.size());
  for (int i = 0; i < input_tf_tensors.size(); ++i) {
    input_tf_tensor_values[i].tensor = &input_tf_tensors[i];
  }

  run_state.SetUpParams(*kernel_runner, *fallback_request_state);

  // TODO(b/166705169): Figure out how to properly fallback GPU kernels.
  if (kernel_runner->IsAsync()) {
    KernelFallbackExecuteCompatAsyncInternal<KernelFallbackTensor>(
        exec_ctx, &run_state, *kernel_runner, &op_chain, results);
  } else {
    KernelFallbackExecuteCompatSyncInternal<KernelFallbackTensor>(
        exec_ctx, &run_state, *kernel_runner, &op_chain, results);
  }

  return op_chain;
}

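// Executes a TF kernel synchronously, reading arguments from and writing
// results to the given SyncKernelFrame.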
Status KernelFallbackSyncExecuteCompat(const tfrt::ExecutionContext& exec_ctx,
                                       absl::string_view op_name,
                                       absl::string_view device_name,
                                       tfrt::SyncKernelFrame* frame,
                                       const tfrt::OpAttrsRef& attrs) {
  auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    return tensorflow::errors::Internal(
        "KernelFallbackCompatRequestState not found in RequestContext.");
  }

  DCHECK(exec_ctx.request_ctx()->resource_context());
  auto* runner_cache = exec_ctx.request_ctx()
                           ->resource_context()
                           ->GetOrCreateResource<OpKernelRunnerCache>(
                               kOpKernelRunnerCacheResourceName);

  TF_ASSIGN_OR_RETURN(
      auto kernel_runner,
      runner_cache->GetOrCreate(
          exec_ctx.location(), op_name, device_name, frame->GetNumArgs(),
          [&attrs, host = exec_ctx.host()](
              tensorflow::AttrValueMap* attr_value_map) -> llvm::Error {
            VLOG(1) << "KernelFallbackExecuteCompat creating op from OpAttrs: "
                    << PrintTfrtOpAttrsToString(attrs);
            return FillAttrValueMap(attrs, host, attr_value_map);
          },
          *fallback_request_state));

  gtl::InlinedVector<tensorflow::Tensor, 4> input_tf_tensors;
  input_tf_tensors.reserve(frame->GetNumArgs());
  for (int i = 0; i < frame->GetNumArgs(); ++i) {
    auto& tensor = frame->GetArgAt<tensorflow::Tensor>(i);
    input_tf_tensors.push_back(tensor);
  }

  // Check if input tensor dtypes are valid.
  TF_RETURN_IF_ERROR(ValidateInputTypes(
      tfrt::string_view(op_name.data(), op_name.size()), input_tf_tensors,
      kernel_runner->op_kernel()->input_types()));

  AsyncOpKernel* async = kernel_runner->op_kernel()->AsAsync();
  if (async) {
    LOG_EVERY_N_SEC(WARNING, 60)
        << "Async kernels are being executed in sync mode, which could affect "
           "performance. Consider async execution instead.";
  }

  // TODO(b/166705169): Figure out how to properly fallback GPU kernels.
  auto& run_state = GetThreadLocalOpKernelRunState();
  auto clean_up_inputs =
      gtl::MakeCleanup([&]() { run_state.input_tf_tensors.clear(); });

  run_state.input_tf_tensors = std::move(input_tf_tensors);

  auto& input_tf_tensor_values = run_state.input_tf_tensor_values;
  input_tf_tensor_values.resize(run_state.input_tf_tensors.size());
  for (int i = 0; i < run_state.input_tf_tensors.size(); ++i) {
    input_tf_tensor_values[i].tensor = &run_state.input_tf_tensors[i];
  }

  run_state.SetUpParams(*kernel_runner, *fallback_request_state);

  OpKernelContext context(&run_state.params);
  kernel_runner->Run(&context);

  if (!context.status().ok()) return context.status();

  DCHECK_EQ(context.num_outputs(), frame->GetNumResults());
  for (int i = 0; i < context.num_outputs(); ++i) {
    *frame->GetResultAt(i) = tfrt::Value(std::move(*context.mutable_output(i)));
  }
  return Status::OK();
}

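// Looks up the tensorflow::Device that corresponds to the given TFRT device in
// the fallback request state's device manager.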
llvm::Expected<Device*> GetTfDevice(const tfrt::ExecutionContext& exec_ctx,
                                    const tfrt::Device& device) {
  auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    return tfrt::MakeStringError(
        "KernelFallbackCompatRequestState not found in RequestContext.");
  }
  Device* tf_device;
  Status s = fallback_request_state->device_manager().LookupDevice(
      device.name().data(), &tf_device);
  if (!s.ok()) {
    return tfrt::MakeStringError(s.error_message());
  }
  return tf_device;
}

static absl::string_view StripTfPrefix(tfrt::string_view op_name) {
  return absl::StripPrefix(ToAbslStringView(op_name), "tf.");
}

// Generate metadata for an execution op event
std::string GetTracingMetadata(llvm::ArrayRef<tfrt::AsyncValue*> args,
                               const tfrt::ExecutionContext& exec_ctx,
                               const OpKernelRunner& kernel_runner) {
  auto request_id = exec_ctx.request_ctx()->id();
  auto current_tracing_level = tfrt::tracing::GetCurrentTracingLevel();

  if (current_tracing_level == tfrt::tracing::TracingLevel::Default) {
    return profiler::TraceMeEncode({{"id", request_id}});
  }

  // Get Long Name
  auto debug_info = exec_ctx.location().GetDebugInfo();
  auto long_name = debug_info.hasValue() ? debug_info.getValue().info : "";

  if (current_tracing_level == tfrt::tracing::TracingLevel::Verbose) {
    return profiler::TraceMeEncode(
        {{"id", request_id}, {"long_name", ToAbslStringView(long_name)}});
  }

  // Get Input Tensors
  std::string input_string;
  llvm::raw_string_ostream input_string_stream(input_string);

  for (size_t i = 0; i < args.size(); ++i) {
    const auto& tensor = args[i]->get<Tensor>();
    input_string_stream << DataTypeString(tensor.dtype())
                        << tensor.shape().DebugString();
    input_string_stream << ";";
  }

  // Get Attributes
  std::string attr_string;
  llvm::raw_string_ostream attr_string_stream(attr_string);

  for (const auto& entry : kernel_runner.op_kernel()->def().attr()) {
    attr_string_stream << entry.first << ": {" << entry.second.DebugString();
    if (!attr_string.empty() && attr_string[attr_string.size() - 1] == '\n') {
      attr_string[attr_string.size() - 1] = '}';
    }
    attr_string_stream << ";\n";
  }

  return profiler::TraceMeEncode({
      {"id", request_id},
      {"long_name", ToAbslStringView(long_name)},
      {"inputs", input_string},
      {"attributes", attr_string},
  });
}

namespace {

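// Provides typed access to the attributes of a fallback kernel invocation,
// which are stored at fixed positions in the AsyncKernelFrame.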
class FallbackKernelAttributeFrame {
 public:
  explicit FallbackKernelAttributeFrame(tfrt::AsyncKernelFrame* frame)
      : frame_(frame) {
    DCHECK(frame_);
  }

  tfrt::StringAttr device() const {
    return tfrt::StringAttr(frame_->GetAttribute(kDeviceAttrPosition));
  }

  tfrt::AggregateAttr op_attr() const {
    return tfrt::AggregateAttr(frame_->GetAttribute(kOpAttrPosition));
  }

  tfrt::AggregateAttr op_func_attr() const {
    return tfrt::AggregateAttr(frame_->GetAttribute(kOpFuncAttrPosition));
  }

  tfrt::I64Attr op_key() const {
    return tfrt::I64Attr(frame_->GetAttribute(kOpKeyAttrPosition));
  }

  tfrt::StringAttr op_name() const {
    return tfrt::StringAttr(frame_->GetAttribute(kOpNameAttrPosition));
  }

 private:
  static constexpr int kDeviceAttrPosition = 0;
  static constexpr int kOpAttrPosition = 1;
  static constexpr int kOpFuncAttrPosition = 2;
  static constexpr int kOpKeyAttrPosition = 3;
  static constexpr int kOpNameAttrPosition = 4;

  tfrt::AsyncKernelFrame* frame_ = nullptr;
};

// The BEF kernel for kernel fallback compat mode. The arguments and results
// are expected to be tensorflow::tfrt_stub::FallbackTensor.
TF_ATTRIBUTE_ALWAYS_INLINE static void KernelFallbackExecuteOp(
    llvm::ArrayRef<tfrt::AsyncValue*> args,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results,
    tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
    const FallbackKernelAttributeFrame& frame,
    const tfrt::ExecutionContext& exec_ctx) {
  tensorflow::profiler::TraceMe trace_me(
      [&]() { return ToAbslStringView(frame.op_name().GetValue()); });

  const auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    KernelFallbackEmitError(
        exec_ctx, frame.op_name().GetValue(), op_chain, results,
        tensorflow::errors::NotFound(
            "KernelFallbackCompatRequestState not found in RequestContext."));
    return;
  }

  auto* runner_table = fallback_request_state->runner_table();
  DCHECK(runner_table);

  auto* kernel_runner = runner_table->Get(frame.op_key().GetValue());
  DCHECK(kernel_runner);
  DCHECK_EQ(kernel_runner->op_kernel()->name(),
            StripTfPrefix(frame.op_name().GetValue()));

  trace_me.AppendMetadata(
      [&]() { return GetTracingMetadata(args, exec_ctx, *kernel_runner); });

  if (fallback_request_state->log_device_placement() || VLOG_IS_ON(1)) {
    string msg =
        strings::StrCat("Executing op ", frame.op_name().GetValue().str(),
                        " in device ", frame.device().GetValue().str());
    if (!logging::LogToListeners(msg)) {
      LOG(INFO) << msg;
    }
  }

  auto& run_state = GetThreadLocalOpKernelRunState();
  auto clean_up_inputs =
      gtl::MakeCleanup([&]() { run_state.input_tf_tensors.clear(); });

  // Prepare the input tensors.
  auto& input_tf_tensors = run_state.input_tf_tensors;
  auto& input_tf_tensor_values = run_state.input_tf_tensor_values;
  DCHECK(input_tf_tensors.empty());
  input_tf_tensor_values.resize(args.size());
  for (int i = 0; i < args.size(); ++i) {
    auto* arg = args[i];
    auto& fallback_tensor = arg->get<tensorflow::tfrt_stub::FallbackTensor>();
    // If the argument is immutable or unique, we can just keep the reference
    // without a copy that involves expensive atomic reference counting. And if
    // the argument is unique but mutable, then tensorflow optimizations like
    // buffer forwarding can be utilized. Otherwise, we conservatively copy the
    // tensor.
    if (!fallback_tensor.is_immutable() && !arg->IsUnique()) {
      input_tf_tensors.push_back(fallback_tensor.tensor());
    }
    input_tf_tensor_values[i].tensor = &fallback_tensor.tensor();
  }

  run_state.SetUpParams(*kernel_runner, *fallback_request_state);

  if (kernel_runner->IsAsync()) {
    KernelFallbackExecuteCompatAsyncInternal<
        tensorflow::tfrt_stub::FallbackTensor>(
        exec_ctx, &run_state, *kernel_runner, op_chain, results);
  } else {
    KernelFallbackExecuteCompatSyncInternal<
        tensorflow::tfrt_stub::FallbackTensor>(
        exec_ctx, &run_state, *kernel_runner, op_chain, results);
  }
}

// The BEF kernel for creating tensorflow::OpKernel to be used in kernel
// fallback compat mode.
llvm::Expected<tfrt::Chain> KernelFallbackCreateOp(
    const tfrt::Chain& in_ch, tfrt::StringAttr device, tfrt::I64Attr num_args,
    tfrt::AggregateAttr op_attr_array, tfrt::AggregateAttr op_func_attr_array,
    tfrt::I64Attr op_key, tfrt::StringAttr op_name_attr,
    const tfrt::ExecutionContext& exec_ctx) {
  const auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    return tfrt::MakeStringError(
        "KernelFallbackCompatRequestState not found in RequestContext.");
  }

  auto* runner_table = fallback_request_state->runner_table();
  DCHECK(runner_table);

  auto attr_builder =
      [op_attr_array, op_func_attr_array](
          tensorflow::AttrValueMap* attr_value_map) -> llvm::Error {
    auto status =
        SetUpAttrValueMap(op_attr_array, op_func_attr_array, attr_value_map);

    if (!status.ok()) return tfrt::MakeStringError(status.ToString());
    return llvm::Error::success();
  };

  auto op_name = StripTfPrefix(op_name_attr.GetValue());

  auto statusor_runner = OpKernelRunner::Create(
      op_name, ToAbslStringView(device.GetValue()), num_args.GetValue(),
      attr_builder, *fallback_request_state);
  if (!statusor_runner.ok())
    return tfrt::MakeStringError(statusor_runner.status().ToString());

  if (!runner_table->Insert(op_key.GetValue(),
                            std::move(statusor_runner).ValueOrDie())) {
    return tfrt::MakeStringError(
        absl::StrCat("KernelFallbackCreateOp: OpKernelRunner already exists: ",
                     op_name_attr.GetValue().str()));
  }

  return tfrt::Chain();
}

// FallbackSetResource is the fallback kernel that sets the tensor value in the
// fallback's resource array.
llvm::Expected<tfrt::Chain> FallbackSetResource(
    tfrt::Argument<tfrt::Chain> in_ch,
    tfrt::Argument<tensorflow::tfrt_stub::FallbackTensor> arg,
    tfrt::StringAttr device, tfrt::I64Attr index_attr,
    const tfrt::ExecutionContext& exec_ctx) {
  const auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    return tfrt::MakeStringError(
        "KernelFallbackCompatRequestState not found in RequestContext.");
  }

  auto* resource_array = fallback_request_state->resource_array();
  DCHECK(resource_array);

  int64_t index = index_attr.GetValue();

  // Set the resource tensor to be immutable, so that we don't need reference
  // counting on it and it cannot be buffer-forwarded.
  resource_array->SetResource(
      index,
      tensorflow::tfrt_stub::ImmutableTensor::Create(arg.get().tensor()));

  return tfrt::Chain();
}

// FallbackGetResource is the fallback kernel that retrieves the tensor value
// in the fallback's resource array.
void FallbackGetResource(tfrt::Argument<tfrt::Chain> in_ch,
                         tfrt::Result<tfrt::Chain> out_ch,
                         tfrt::RemainingResults results,
                         tfrt::StringAttr device, tfrt::ArrayAttr indices_attr,
                         const tfrt::ExecutionContext& exec_ctx) {
  tensorflow::profiler::TraceMe trace_me("tfrt_fallback_async.get_resource");
  trace_me.AppendMetadata([request_id = exec_ctx.request_ctx()->id()]() {
    return tensorflow::profiler::TraceMeEncode({{"id", request_id}});
  });

  const auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    tfrt::RCReference<tfrt::AsyncValue> error = tfrt::EmitErrorAsync(
        exec_ctx,
        "KernelFallbackCompatRequestState not found in RequestContext.");
    out_ch.Set(std::move(error));
    return;
  }

  auto* resource_array = fallback_request_state->resource_array();
  DCHECK(resource_array);

  llvm::ArrayRef<int64_t> indices = indices_attr.GetValue<int64_t>();

  for (int i = 0; i < indices.size(); ++i) {
    results[i] = tfrt::FormRef(resource_array->GetResource(indices[i]));
  }

  out_ch.Set(in_ch);
}

// The implementation of tfrt_fallback_async.executeop kernel. It executes a
// non-side-effecting TF op with the name of `op_name` in fallback. All
// relevant TF attributes are passed in `op_attr_array`.
void FallbackAsyncExecuteOp(tfrt::AsyncKernelFrame* frame) {
  FallbackKernelAttributeFrame attr_frame(frame);
#ifndef NDEBUG
  frame->GetExecutionContext()
      .host()
      ->GetOrCreateSharedContext<OpLogger>()
      .LogOp(attr_frame.op_name().GetValue());
#endif
  KernelFallbackExecuteOp(frame->GetArguments(), frame->GetResults(),
                          /*op_chain=*/nullptr, attr_frame,
                          frame->GetExecutionContext());
}

// The implementation of tfrt_fallback_async.executeop.seq kernel. It executes
// a side-effecting TF op with the name of `op_name` in fallback. All relevant
// TF attributes are passed in `op_attr_array`. `in_op_chain` and
// `out_op_chain` are used for side-effect visibility.
void FallbackAsyncExecuteOpSeq(tfrt::AsyncKernelFrame* frame) {
  auto all_args = frame->GetArguments();
  DCHECK_GT(all_args.size(), 0);
  tfrt::AsyncValueRef<tfrt::Chain> op_chain(tfrt::FormRef(all_args[0]));
  llvm::ArrayRef<tfrt::AsyncValue*> args = all_args.drop_front();

  auto all_results = frame->GetResults();
  DCHECK_GT(all_results.size(), 0);
  auto& out_op_chain = all_results[0];
  llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results =
      all_results.drop_front();

  KernelFallbackExecuteOp(args, results, &op_chain,
                          FallbackKernelAttributeFrame(frame),
                          frame->GetExecutionContext());
  out_op_chain = std::move(op_chain);
}

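// The implementation of tfrt_fallback_async.copy_if_small kernel. It fans the
// input fallback tensor out to all results, giving each result its own
// reference count so that downstream threads do not contend on a shared
// atomic counter.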
void FallbackCopyTensorIfSmall(
    tfrt::Argument<tensorflow::tfrt_stub::FallbackTensor> arg,
    tfrt::RemainingResults results) {
  const auto& fallback_tensor = arg.get();
  const auto& tensor = fallback_tensor.tensor();

  if (!fallback_tensor.is_immutable()) {
    // Create a new TensorBuffer which contains a new atomic counter for each
    // result, to avoid downstream threads contending the original atomic
    // counter.
    for (int i = 0; i < results.size(); ++i) {
      auto immutable_tensor =
          tensorflow::tfrt_stub::ImmutableTensor::Create(tensor);
      results[i] = tfrt::MakeAvailableAsyncValueRef<
          tensorflow::tfrt_stub::FallbackTensor>(
          std::move(immutable_tensor.tensor()));
    }
  } else {
    // For immutable tensors, we just need to copy the pointer. Note that we
    // still create a new AsyncValue in this case, to avoid atomic contention
    // on AsyncValue's refcount.
    for (int i = 0; i < results.size(); ++i) {
      results[i] = tfrt::MakeAvailableAsyncValueRef<
          tensorflow::tfrt_stub::FallbackTensor>(fallback_tensor);
    }
  }
}

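// The implementation of tfrt_fallback_async.const_tensor_proto kernel. It
// parses a tensorflow::TensorProto from the serialized string attribute and
// returns the resulting tensor as a fallback tensor.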
llvm::Expected<tensorflow::tfrt_stub::FallbackTensor> ConstTensorProto(
    tfrt::StringAttr serialized_tensor_proto) {
  tensorflow::TensorProto tensor_proto;
  if (!tensor_proto.ParseFromString(serialized_tensor_proto.GetValue().str())) {
    return tfrt::MakeStringError("Failed to parse const tensor proto");
  }

  tensorflow::Tensor tensor;
  if (!tensor.FromProto(tensor_proto)) {
    return tfrt::MakeStringError("Failed to create tensor from tensor proto: ",
                                 tensor_proto.ShortDebugString());
  }

  return tensorflow::tfrt_stub::FallbackTensor(std::move(tensor));
}

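// Registers all kernel fallback compat kernels with the given kernel registry.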
void RegisterKernelFallbackCompatKernels(tfrt::KernelRegistry* registry) {
  registry->AddKernel("tfrt_fallback_async.const_tensor_proto",
                      TFRT_KERNEL(ConstTensorProto));
  registry->AddKernel("tfrt_fallback_async.executeop", FallbackAsyncExecuteOp);
  registry->AddKernel("tfrt_fallback_async.executeop.seq",
                      FallbackAsyncExecuteOpSeq);
  registry->AddKernel("tfrt_fallback_async.copy_if_small",
                      TFRT_KERNEL(FallbackCopyTensorIfSmall));
  registry->AddKernel("tfrt_fallback_async.createop",
                      TFRT_KERNEL(KernelFallbackCreateOp));
  registry->AddKernel("tfrt_fallback_async.set_resource",
                      TFRT_KERNEL(FallbackSetResource));
  registry->AddKernel("tfrt_fallback_async.get_resource",
                      TFRT_KERNEL(FallbackGetResource));
}

TFRT_STATIC_KERNEL_REGISTRATION(RegisterKernelFallbackCompatKernels);

}  // namespace
}  // namespace tfd
}  // namespace tensorflow