/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h"

#include <optional>
#include <string>

#include "llvm/ADT/StringRef.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/logging.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h"
#include "tensorflow/core/runtime_fallback/kernel/op_kernel_runner.h"
#include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h"
#include "tensorflow/core/runtime_fallback/runtime/op_logger.h"
#include "tensorflow/core/runtime_fallback/util/attr_util.h"
#include "tensorflow/core/tfrt/utils/error_util.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tensorflow/core/tfrt/utils/tensor_util.h"
#include "tfrt/core_runtime/execute_op_impl.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
#include "tfrt/host_context/async_dispatch.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/chain.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/kernel_registry.h"  // from @tf_runtime
#include "tfrt/host_context/sync_kernel_frame.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/pointer_util.h"  // from @tf_runtime
#include "tfrt/support/string_util.h"  // from @tf_runtime
#include "tfrt/tensor/tensor.h"  // from @tf_runtime
#include "tfrt/tracing/tracing.h"  // from @tf_runtime

namespace tensorflow {
namespace tfd {
namespace {

using ::tfrt::AsyncValue;
using ::tfrt::AsyncValueRef;
using ::tfrt::Chain;
using ::tfrt::OpAttrsRef;
using ::tfrt::RCReference;
using ::tfrt::SmallVector;
using ::tfrt::string_view;

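// The following names are the keys under which the per-request fallback
// resources are stored in the tfrt::ResourceContext; see the
// GetOrCreateResource calls below.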
constexpr char kOpKernelRunnerTableResourceName[] =
    "OpKernelRunnerTableResourceName";

constexpr char kOpKernelRunnerCacheResourceName[] =
    "OpKernelRunnerCacheResourceName";

constexpr char kFallbackResourceArray[] = "FallbackResourceArray";

void KernelFallbackEmitError(
    const tfrt::ExecutionContext& exec_ctx, tfrt::string_view op_name,
    tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results,
    const tensorflow::Status& status) {
  // Set all results to error, with the TFRT error code derived from the error
  // propagated from the runtime fallback execution.
  auto error =
      EmitErrorAsync(exec_ctx,
                     tfrt::StrCat("error running kernel fallback kernel ",
                                  op_name, ": ", status.error_message()),
                     tfrt::ConvertTfErrorCodeToTfrtErrorCode(status));
  for (auto& result : results) result = error.CopyRef();
  if (op_chain) *op_chain = std::move(error);
}

}  // namespace

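// A minimal usage sketch for the setup functions below (illustrative only;
// `host_ctx`, `resource_ctx`, `device_mgr`, and `pflr` stand for hypothetical
// caller-owned objects that are not defined in this file):
//
//   tfrt::RequestContextBuilder builder(host_ctx, resource_ctx);
//   TF_RETURN_IF_ERROR(SetUpKernelFallbackCompatRequestContext(
//       &builder, device_mgr, pflr,
//       /*user_intra_op_threadpool=*/nullptr,
//       /*model_metadata=*/absl::nullopt));
//   auto request_ctx = std::move(builder).build();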
Status SetUpKernelFallbackCompatRequestContext(
    tfrt::RequestContextBuilder* builder,
    const tensorflow::DeviceMgr* device_manager,
    const tensorflow::ProcessFunctionLibraryRuntime* pflr,
    tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool,
    const absl::optional<tfrt::ModelMetadata>& model_metadata) {
  DCHECK(builder);
  DCHECK(device_manager);
  DCHECK(pflr);

  auto* runner_table =
      builder->resource_context()->GetOrCreateResource<OpKernelRunnerTable>(
          kOpKernelRunnerTableResourceName);

  auto* resource_array =
      builder->resource_context()->GetOrCreateResource<FallbackResourceArray>(
          kFallbackResourceArray);

  builder->context_data().emplace<KernelFallbackCompatRequestState>(
      device_manager, builder->id(), runner_table, resource_array,
      user_intra_op_threadpool, model_metadata, pflr);

  return Status::OK();
}

Status SetUpKernelFallbackCompatRequestContext(
    tfrt::RequestContextBuilder* builder, OpKernelRunnerTable* runner_table,
    tensorflow::EagerContext* eager_context,
    tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool,
    const absl::optional<tfrt::ModelMetadata>& model_metadata) {
  auto* resource_array =
      builder->resource_context()->GetOrCreateResource<FallbackResourceArray>(
          kFallbackResourceArray);

  if (runner_table == nullptr)
    runner_table =
        builder->resource_context()->GetOrCreateResource<OpKernelRunnerTable>(
            kOpKernelRunnerTableResourceName);

  auto step_id = builder->id();

  auto& fallback_request_state =
      builder->context_data().emplace<KernelFallbackCompatRequestState>(
          eager_context->local_device_mgr(), step_id,
          tfrt::OwnedOrUnownedPtr<ScopedStepContainer>{
              eager_context->StepContainer()},
          eager_context->GetCollectiveExecutorHandle(),
          tensorflow::core::RefCountPtr<tensorflow::Rendezvous>(
              eager_context->RendezvousCreator()(step_id)),
          runner_table, resource_array, user_intra_op_threadpool,
          model_metadata, eager_context->pflr());

  fallback_request_state.set_log_device_placement(
      eager_context->LogDevicePlacement());

  return Status::OK();
}

static llvm::Expected<gtl::InlinedVector<tensorflow::Tensor, 4>>
ConvertInputTensors(llvm::ArrayRef<tfrt::Tensor*> arguments,
                    const tfrt::ExecutionContext& exec_ctx) {
  gtl::InlinedVector<tensorflow::Tensor, 4> input_tf_tensors;
  input_tf_tensors.reserve(arguments.size());
  for (tfrt::Tensor* argument : arguments) {
    auto expected_tf_tensor = TFRTTensorToTFTensor(*argument, exec_ctx.host());
    if (!expected_tf_tensor) {
      return tfrt::MakeStringError(
          tfrt::StrCat(expected_tf_tensor.takeError()));
    }
    input_tf_tensors.push_back(std::move(expected_tf_tensor.get()));
  }

  return input_tf_tensors;
}

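// Validates that the dtypes of `input_tf_tensors` match `input_types`. For
// example, feeding an int32 tensor to an op whose first declared input is
// DT_FLOAT yields an InvalidArgument error of the form:
//   "cannot compute Add as input #0(zero-based) was expected to be a float
//   tensor but is a int32 tensor".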
static Status ValidateInputTypes(
    tfrt::string_view op_name,
    const gtl::InlinedVector<tensorflow::Tensor, 4>& input_tf_tensors,
    const DataTypeVector& input_types) {
  const size_t n_inputs = input_tf_tensors.size();

  if (input_types.size() != n_inputs) {
    return tensorflow::errors::InvalidArgument("expected ", input_types.size(),
                                               " inputs, got ", n_inputs);
  }

  for (size_t i = 0; i < n_inputs; ++i) {
    if (input_tf_tensors[i].dtype() != input_types[i]) {
      return tensorflow::errors::InvalidArgument(
          "cannot compute ", op_name.str(), " as input #", i, "(zero-based)",
          " was expected to be a ", DataTypeString(input_types[i]),
          " tensor but is a ", DataTypeString(input_tf_tensors[i].dtype()),
          " tensor");
    }
  }

  return Status::OK();
}

namespace {

// OpKernelRunState keeps the states needed for per-kernel execution.
struct OpKernelRunState {
  gtl::InlinedVector<tensorflow::Tensor, 4> input_tf_tensors;
  gtl::InlinedVector<tensorflow::TensorValue, 4> input_tf_tensor_values;
  OpKernelContext::Params params;

  OpKernelRunState() = default;
  OpKernelRunState(
      const gtl::InlinedVector<tensorflow::TensorValue, 4>& tensor_values,
      const OpKernelContext::Params& p) {
    // `input_tf_tensor_values` holds references to all input tensors, while
    // `input_tf_tensors` holds only the tensors that need ownership, so their
    // sizes may not match. For this copy, we conservatively copy all tensors.
    input_tf_tensors.reserve(tensor_values.size());
    for (const auto& tensor_value : tensor_values) {
      input_tf_tensors.push_back(*tensor_value.tensor);
    }
    for (auto& tensor : input_tf_tensors) {
      input_tf_tensor_values.emplace_back(&tensor);
    }

    // Since `input_tf_tensor_values` and `params` contain pointers into
    // `input_tf_tensors`, those pointers must be updated to point at the
    // copies made above.
    params = p;
    params.inputs = &input_tf_tensor_values;
  }

  OpKernelRunState(const OpKernelRunState& other) = delete;
  OpKernelRunState& operator=(const OpKernelRunState& other) = delete;

  ~OpKernelRunState() = default;

  void SetUpParams(
      const OpKernelRunner& runner,
      const KernelFallbackCompatRequestState& fallback_request_state) {
    params.inputs = &input_tf_tensor_values;

    // Replace the thread-pool device if a custom device is specified.
    //
    // The device handling is copied from:
    // http://cs/?q=f:common_runtime%2Fexecutor.cc:692%20package:piper&rcl=351575626
    if (auto* custom_device = fallback_request_state.custom_device()) {
      params.device = custom_device;
    } else {
      params.device = runner.device();
    }

    params.op_kernel = runner.op_kernel();
    // Still use the original device's resource_manager.
    params.resource_manager = runner.resource_manager();
    params.input_alloc_attrs = &runner.input_alloc_attrs();
    params.output_attr_array = runner.output_alloc_attrs().data();
    params.step_container = fallback_request_state.step_container();
    // The following two parameters are used to support executing tf.data via
    // fallback.
    params.function_library = runner.function_library_runtime();
    params.runner = fallback_request_state.runner();
    params.collective_executor = fallback_request_state.collective_executor();
    params.rendezvous = fallback_request_state.rendezvous();
    params.session_metadata = &fallback_request_state.session_metadata();
    params.cancellation_manager =
        fallback_request_state.cancellation_manager();
  }
};

// Keep the state needed for kernel execution in thread-local storage to avoid
// repeatedly reallocating and destroying it.
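// Note that the asynchronous execution path below copies this state into a
// heap-allocated AsyncState, because the thread-local instance may be reused
// as soon as the current call returns.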
OpKernelRunState& GetThreadLocalOpKernelRunState() {
  thread_local OpKernelRunState run_state;
  return run_state;
}

}  // namespace

// Executes a tensorflow::OpKernel asynchronously. `kernel_runner` and
// `input_tf_tensors` are expected to be alive during the call to this
// function. Sets result AsyncValues in `results` and sets a Chain that
// indicates completion of the execution, or an error otherwise.
template <typename TensorType>
static void KernelFallbackExecuteCompatAsyncInternal(
    const tfrt::ExecutionContext& exec_ctx, OpKernelRunState* run_state,
    const OpKernelRunner& kernel_runner,
    tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results) {
  auto chain =
      tfrt::MakeUnconstructedAsyncValueRef<tfrt::Chain>(exec_ctx.host());
  if (op_chain) *op_chain = chain.CopyRef();

  // Allocate unconstructed result tensors and set them in the output
  // `results`.
  llvm::SmallVector<AsyncValueRef<TensorType>, 4> result_refs;
  result_refs.reserve(results.size());
  for (auto& result : results) {
    result_refs.emplace_back(
        tfrt::MakeUnconstructedAsyncValueRef<TensorType>(exec_ctx.host()));
    result = result_refs.back().CopyRef();
  }

  struct AsyncState {
    explicit AsyncState(const OpKernelRunState& rs, int num_outputs)
        : run_state(rs.input_tf_tensor_values, rs.params),
          context(&run_state.params, num_outputs) {}

    OpKernelRunState run_state;
    OpKernelContext context;

    tfrt::AsyncValueRef<tfrt::Chain> chain;
    llvm::SmallVector<tfrt::AsyncValueRef<TensorType>, 4> result_refs;
  };

  DCHECK_EQ(results.size(), kernel_runner.op_kernel()->num_outputs());
  auto async_state = std::make_shared<AsyncState>(*run_state, results.size());
  async_state->chain = chain.CopyRef();
  async_state->result_refs = std::move(result_refs);

  auto* context_ptr = &async_state->context;

  auto done_callback = [async_state = std::move(async_state), exec_ctx]() {
    auto& context = async_state->context;

    if (!context.status().ok()) {
      auto diag = tfrt::EmitError(
          exec_ctx,
          {tfrt::StrCat("error running kernel fallback kernel ",
                        context.op_kernel().name(), ": ",
                        context.status().error_message())},
          tfrt::ConvertTfErrorCodeToTfrtErrorCode(context.status()));
      for (auto& result : async_state->result_refs) result.SetError(diag);
      async_state->chain.SetError(diag);
      return;
    }

    // Set the payloads and mark the async values available on TFRT's thread.
    tfrt::EnqueueWork(exec_ctx, [async_state = std::move(async_state)]() {
      auto& context = async_state->context;
      for (int i = 0; i < context.num_outputs(); ++i) {
        async_state->result_refs[i].emplace(
            std::move(*context.mutable_output(i)));
      }
      async_state->chain.emplace();
    });
  };

  kernel_runner.RunAsync(context_ptr, std::move(done_callback));
}

// Executes a tensorflow::OpKernel synchronously. `kernel_runner` and
// `input_tf_tensors` are expected to be alive during the call to this
// function. Sets result AsyncValues in `results` on success, and emits errors
// to `results` and `op_chain` otherwise. TensorType is expected to be
// constructible from tensorflow::Tensor.
template <typename TensorType>
static void KernelFallbackExecuteCompatSyncInternal(
    const tfrt::ExecutionContext& exec_ctx, OpKernelRunState* run_state,
    const OpKernelRunner& kernel_runner,
    tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results) {
  DCHECK_EQ(results.size(), kernel_runner.op_kernel()->num_outputs());
  OpKernelContext context(&run_state->params, results.size());
  kernel_runner.Run(&context);

  if (!context.status().ok()) {
    KernelFallbackEmitError(exec_ctx, kernel_runner.op_kernel()->name(),
                            op_chain, results, context.status());
    return;
  }

  for (int i = 0; i < context.num_outputs(); ++i) {
    results[i] = tfrt::MakeAvailableAsyncValueRef<TensorType>(
        std::move(*context.mutable_output(i)));
  }

  if (op_chain) *op_chain = tfrt::MakeAvailableAsyncValueRef<tfrt::Chain>();
}

static std::string PrintTfrtOpAttrsToString(const tfrt::OpAttrsRef& attrs) {
  std::string str;
  llvm::raw_string_ostream ss(str);
  attrs.Print(ss);
  ss.flush();
  return str;
}

tfrt::AsyncValueRef<tfrt::Chain> KernelFallbackExecuteCompatCoreRuntimeDispatch(
    const tfrt::ExecutionContext& exec_ctx, tfrt::string_view op_name,
    tfrt::string_view device_name, llvm::ArrayRef<tfrt::Tensor*> arguments,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results,
    const tfrt::OpAttrsRef& attrs) {
  auto op_chain = tfrt::GetReadyChain(exec_ctx.host());
  tensorflow::Status status;

  const auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    status = tensorflow::errors::NotFound(
        "KernelFallbackCompatRequestState not found in RequestContext.");
    KernelFallbackEmitError(exec_ctx, op_name, &op_chain, results, status);
    return op_chain;
  }

  DCHECK(exec_ctx.location());

  DCHECK(exec_ctx.request_ctx()->resource_context());
  auto* runner_cache = exec_ctx.request_ctx()
                           ->resource_context()
                           ->GetOrCreateResource<OpKernelRunnerCache>(
                               kOpKernelRunnerCacheResourceName);

  auto kernel_runner_or_status = runner_cache->GetOrCreate(
      exec_ctx.location(), ToAbslStringView(op_name),
      ToAbslStringView(device_name), arguments.size(),
      [&attrs, host = exec_ctx.host()](
          tensorflow::AttrValueMap* attr_value_map) -> llvm::Error {
        VLOG(1) << "KernelFallbackExecuteCompat creating op from OpAttrs: "
                << PrintTfrtOpAttrsToString(attrs);
        return FillAttrValueMap(attrs, host, attr_value_map);
      },
      *fallback_request_state);

  if (!kernel_runner_or_status.ok()) {
    KernelFallbackEmitError(exec_ctx, op_name, &op_chain, results,
                            kernel_runner_or_status.status());
    return op_chain;
  }

  auto expected_input_tf_tensors = ConvertInputTensors(arguments, exec_ctx);
  if (!expected_input_tf_tensors) {
    status = tensorflow::errors::Internal(
        tfrt::StrCat(expected_input_tf_tensors.takeError()));
    KernelFallbackEmitError(exec_ctx, op_name, &op_chain, results, status);
    return op_chain;
  }

  auto& kernel_runner = kernel_runner_or_status.ValueOrDie();

  auto& run_state = GetThreadLocalOpKernelRunState();
  auto clean_up_inputs =
      gtl::MakeCleanup([&]() { run_state.input_tf_tensors.clear(); });

  auto& input_tf_tensors = run_state.input_tf_tensors;
  input_tf_tensors = std::move(expected_input_tf_tensors.get());

  // Check if input tensor dtypes are valid.
  status = ValidateInputTypes(op_name, input_tf_tensors,
                              kernel_runner->op_kernel()->input_types());

  // TODO(b/176997538): Skip checking dtypes for tf._BatchFunctionFallback op
  // due to b/176997538. Remove the skipping once the SavedModel lowering
  // problem is fixed.
  if (!status.ok() && !op_name.equals("_BatchFunctionFallback")) {
    KernelFallbackEmitError(exec_ctx, op_name, &op_chain, results, status);
    return op_chain;
  }

  auto& input_tf_tensor_values = run_state.input_tf_tensor_values;
  input_tf_tensor_values.resize(input_tf_tensors.size());
  for (int i = 0; i < input_tf_tensors.size(); ++i) {
    input_tf_tensor_values[i].tensor = &input_tf_tensors[i];
  }

  run_state.SetUpParams(*kernel_runner, *fallback_request_state);

  // TODO(b/166705169): Figure out how to properly fallback GPU kernels.
  if (kernel_runner->IsAsync()) {
    KernelFallbackExecuteCompatAsyncInternal<KernelFallbackTensor>(
        exec_ctx, &run_state, *kernel_runner, &op_chain, results);
  } else {
    KernelFallbackExecuteCompatSyncInternal<KernelFallbackTensor>(
        exec_ctx, &run_state, *kernel_runner, &op_chain, results);
  }

  return op_chain;
}

Status KernelFallbackSyncExecuteCompat(const tfrt::ExecutionContext& exec_ctx,
                                       absl::string_view op_name,
                                       absl::string_view device_name,
                                       tfrt::SyncKernelFrame* frame,
                                       const tfrt::OpAttrsRef& attrs) {
  auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    return tensorflow::errors::Internal(
        "KernelFallbackCompatRequestState not found in RequestContext.");
  }

  DCHECK(exec_ctx.request_ctx()->resource_context());
  auto* runner_cache = exec_ctx.request_ctx()
                           ->resource_context()
                           ->GetOrCreateResource<OpKernelRunnerCache>(
                               kOpKernelRunnerCacheResourceName);

  TF_ASSIGN_OR_RETURN(
      auto kernel_runner,
      runner_cache->GetOrCreate(
          exec_ctx.location(), op_name, device_name, frame->GetNumArgs(),
          [&attrs, host = exec_ctx.host()](
              tensorflow::AttrValueMap* attr_value_map) -> llvm::Error {
            VLOG(1) << "KernelFallbackExecuteCompat creating op from OpAttrs: "
                    << PrintTfrtOpAttrsToString(attrs);
            return FillAttrValueMap(attrs, host, attr_value_map);
          },
          *fallback_request_state));

  gtl::InlinedVector<tensorflow::Tensor, 4> input_tf_tensors;
  input_tf_tensors.reserve(frame->GetNumArgs());
  for (int i = 0; i < frame->GetNumArgs(); ++i) {
    auto& tensor = frame->GetArgAt<tensorflow::Tensor>(i);
    input_tf_tensors.push_back(tensor);
  }

  // Check if input tensor dtypes are valid.
  TF_RETURN_IF_ERROR(ValidateInputTypes(
      tfrt::string_view(op_name.data(), op_name.size()), input_tf_tensors,
      kernel_runner->op_kernel()->input_types()));

  AsyncOpKernel* async = kernel_runner->op_kernel()->AsAsync();
  if (async) {
    LOG_EVERY_N_SEC(WARNING, 60)
        << "Async kernels are being executed in sync mode, which could affect "
           "performance. Consider async execution instead.";
  }

  // TODO(b/166705169): Figure out how to properly fallback GPU kernels.
  auto& run_state = GetThreadLocalOpKernelRunState();
  auto clean_up_inputs =
      gtl::MakeCleanup([&]() { run_state.input_tf_tensors.clear(); });

  run_state.input_tf_tensors = std::move(input_tf_tensors);

  auto& input_tf_tensor_values = run_state.input_tf_tensor_values;
  input_tf_tensor_values.resize(run_state.input_tf_tensors.size());
  for (int i = 0; i < run_state.input_tf_tensors.size(); ++i) {
    input_tf_tensor_values[i].tensor = &run_state.input_tf_tensors[i];
  }

  run_state.SetUpParams(*kernel_runner, *fallback_request_state);

  OpKernelContext context(&run_state.params);
  kernel_runner->Run(&context);

  if (!context.status().ok()) return context.status();

  DCHECK_EQ(context.num_outputs(), frame->GetNumResults());
  for (int i = 0; i < context.num_outputs(); ++i) {
    *frame->GetResultAt(i) =
        tfrt::Value(std::move(*context.mutable_output(i)));
  }
  return Status::OK();
}

llvm::Expected<Device*> GetTfDevice(const tfrt::ExecutionContext& exec_ctx,
                                    const tfrt::Device& device) {
  auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    return tfrt::MakeStringError(
        "KernelFallbackCompatRequestState not found in RequestContext.");
  }
  Device* tf_device;
  Status s = fallback_request_state->device_manager().LookupDevice(
      device.name().data(), &tf_device);
  if (!s.ok()) {
    return tfrt::MakeStringError(s.error_message());
  }
  return tf_device;
}

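// Strips the "tf." dialect prefix from an op name, e.g. "tf.MatMul" becomes
// "MatMul"; names without the prefix are returned unchanged.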
static absl::string_view StripTfPrefix(tfrt::string_view op_name) {
  return absl::StripPrefix(ToAbslStringView(op_name), "tf.");
}

// Generates the tracing metadata for an op execution event.
std::string GetTracingMetadata(llvm::ArrayRef<tfrt::AsyncValue*> args,
                               const tfrt::ExecutionContext& exec_ctx,
                               const OpKernelRunner& kernel_runner) {
  auto request_id = exec_ctx.request_ctx()->id();
  auto current_tracing_level = tfrt::tracing::GetCurrentTracingLevel();

  if (current_tracing_level == tfrt::tracing::TracingLevel::Default) {
    return profiler::TraceMeEncode({{"id", request_id}});
  }

  // Get the long name.
  auto debug_info = exec_ctx.location().GetDebugInfo();
  auto long_name = debug_info.hasValue() ? debug_info.getValue().info : "";

  if (current_tracing_level == tfrt::tracing::TracingLevel::Verbose) {
    return profiler::TraceMeEncode(
        {{"id", request_id}, {"long_name", ToAbslStringView(long_name)}});
  }

  // Get the input tensors.
  std::string input_string;
  llvm::raw_string_ostream input_string_stream(input_string);

  for (size_t i = 0; i < args.size(); ++i) {
    const auto& tensor = args[i]->get<Tensor>();
    input_string_stream << DataTypeString(tensor.dtype())
                        << tensor.shape().DebugString();
    input_string_stream << ";";
  }

  // Get the attributes.
  std::string attr_string;
  llvm::raw_string_ostream attr_string_stream(attr_string);

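  // Each attribute is rendered as `name: {<AttrValue debug string>}`. The
  // AttrValue's DebugString() typically ends with a newline, which is folded
  // into the closing brace below.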
  for (const auto& entry : kernel_runner.op_kernel()->def().attr()) {
    attr_string_stream << entry.first << ": {" << entry.second.DebugString();
    if (!attr_string.empty() && attr_string[attr_string.size() - 1] == '\n') {
      attr_string[attr_string.size() - 1] = '}';
    }
    attr_string_stream << ";\n";
  }

  return profiler::TraceMeEncode({
      {"id", request_id},
      {"long_name", ToAbslStringView(long_name)},
      {"inputs", input_string},
      {"attributes", attr_string},
  });
}

namespace {

class FallbackKernelAttributeFrame {
 public:
  explicit FallbackKernelAttributeFrame(tfrt::AsyncKernelFrame* frame)
      : frame_(frame) {
    DCHECK(frame_);
  }

  tfrt::StringAttr device() const {
    return tfrt::StringAttr(frame_->GetAttribute(kDeviceAttrPosition));
  }

  tfrt::AggregateAttr op_attr() const {
    return tfrt::AggregateAttr(frame_->GetAttribute(kOpAttrPosition));
  }

  tfrt::AggregateAttr op_func_attr() const {
    return tfrt::AggregateAttr(frame_->GetAttribute(kOpFuncAttrPosition));
  }

  tfrt::I64Attr op_key() const {
    return tfrt::I64Attr(frame_->GetAttribute(kOpKeyAttrPosition));
  }

  tfrt::StringAttr op_name() const {
    return tfrt::StringAttr(frame_->GetAttribute(kOpNameAttrPosition));
  }

 private:
  static constexpr int kDeviceAttrPosition = 0;
  static constexpr int kOpAttrPosition = 1;
  static constexpr int kOpFuncAttrPosition = 2;
  static constexpr int kOpKeyAttrPosition = 3;
  static constexpr int kOpNameAttrPosition = 4;

  tfrt::AsyncKernelFrame* frame_ = nullptr;
};

// The BEF kernel for kernel fallback compat mode. The arguments and results
// are expected to be tensorflow::tfrt_stub::FallbackTensor.
TF_ATTRIBUTE_ALWAYS_INLINE static void KernelFallbackExecuteOp(
    llvm::ArrayRef<tfrt::AsyncValue*> args,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results,
    tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
    const FallbackKernelAttributeFrame& frame,
    const tfrt::ExecutionContext& exec_ctx) {
  tensorflow::profiler::TraceMe trace_me(
      [&]() { return ToAbslStringView(frame.op_name().GetValue()); });

  const auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    KernelFallbackEmitError(
        exec_ctx, frame.op_name().GetValue(), op_chain, results,
        tensorflow::errors::NotFound(
            "KernelFallbackCompatRequestState not found in RequestContext."));
    return;
  }

  auto* runner_table = fallback_request_state->runner_table();
  DCHECK(runner_table);

  auto* kernel_runner = runner_table->Get(frame.op_key().GetValue());
  DCHECK(kernel_runner);
  DCHECK_EQ(kernel_runner->op_kernel()->name(),
            StripTfPrefix(frame.op_name().GetValue()));

  trace_me.AppendMetadata(
      [&]() { return GetTracingMetadata(args, exec_ctx, *kernel_runner); });

  if (fallback_request_state->log_device_placement() || VLOG_IS_ON(1)) {
    string msg =
        strings::StrCat("Executing op ", frame.op_name().GetValue().str(),
                        " in device ", frame.device().GetValue().str());
    if (!logging::LogToListeners(msg)) {
      LOG(INFO) << msg;
    }
  }

  auto& run_state = GetThreadLocalOpKernelRunState();
  auto clean_up_inputs =
      gtl::MakeCleanup([&]() { run_state.input_tf_tensors.clear(); });

  // Prepare the input tensors.
  auto& input_tf_tensors = run_state.input_tf_tensors;
  auto& input_tf_tensor_values = run_state.input_tf_tensor_values;
  DCHECK(input_tf_tensors.empty());
  input_tf_tensor_values.resize(args.size());
  for (int i = 0; i < args.size(); ++i) {
    auto* arg = args[i];
    auto& fallback_tensor = arg->get<tensorflow::tfrt_stub::FallbackTensor>();
    // If the argument is immutable or unique, we can just keep the reference
    // without making a copy that involves expensive atomic reference counting.
    // And if the argument is unique but mutable, then tensorflow optimizations
    // like buffer forwarding can be utilized. Otherwise, we conservatively
    // copy the tensor.
    if (!fallback_tensor.is_immutable() && !arg->IsUnique()) {
      input_tf_tensors.push_back(fallback_tensor.tensor());
    }
    input_tf_tensor_values[i].tensor = &fallback_tensor.tensor();
  }

  run_state.SetUpParams(*kernel_runner, *fallback_request_state);

  if (kernel_runner->IsAsync()) {
    KernelFallbackExecuteCompatAsyncInternal<
        tensorflow::tfrt_stub::FallbackTensor>(
        exec_ctx, &run_state, *kernel_runner, op_chain, results);
  } else {
    KernelFallbackExecuteCompatSyncInternal<
        tensorflow::tfrt_stub::FallbackTensor>(
        exec_ctx, &run_state, *kernel_runner, op_chain, results);
  }
}

// The BEF kernel for creating a tensorflow::OpKernel to be used in kernel
// fallback compat mode.
llvm::Expected<tfrt::Chain> KernelFallbackCreateOp(
    const tfrt::Chain& in_ch, tfrt::StringAttr device, tfrt::I64Attr num_args,
    tfrt::AggregateAttr op_attr_array, tfrt::AggregateAttr op_func_attr_array,
    tfrt::I64Attr op_key, tfrt::StringAttr op_name_attr,
    const tfrt::ExecutionContext& exec_ctx) {
  const auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    return tfrt::MakeStringError(
        "KernelFallbackCompatRequestState not found in RequestContext.");
  }

  auto* runner_table = fallback_request_state->runner_table();
  DCHECK(runner_table);

  auto attr_builder =
      [op_attr_array, op_func_attr_array](
          tensorflow::AttrValueMap* attr_value_map) -> llvm::Error {
    auto status =
        SetUpAttrValueMap(op_attr_array, op_func_attr_array, attr_value_map);

    if (!status.ok()) return tfrt::MakeStringError(status.ToString());
    return llvm::Error::success();
  };

  auto op_name = StripTfPrefix(op_name_attr.GetValue());

  auto statusor_runner = OpKernelRunner::Create(
      op_name, ToAbslStringView(device.GetValue()), num_args.GetValue(),
      attr_builder, *fallback_request_state);
  if (!statusor_runner.ok())
    return tfrt::MakeStringError(statusor_runner.status().ToString());

  if (!runner_table->Insert(op_key.GetValue(),
                            std::move(statusor_runner).ValueOrDie())) {
    return tfrt::MakeStringError(
        absl::StrCat("KernelFallbackCreateOp: OpKernelRunner already exists: ",
                     op_name_attr.GetValue().str()));
  }

  return tfrt::Chain();
}

// FallbackSetResource is the fallback kernel that sets the tensor value in the
// fallback's resource array.
llvm::Expected<tfrt::Chain> FallbackSetResource(
    tfrt::Argument<tfrt::Chain> in_ch,
    tfrt::Argument<tensorflow::tfrt_stub::FallbackTensor> arg,
    tfrt::StringAttr device, tfrt::I64Attr index_attr,
    const tfrt::ExecutionContext& exec_ctx) {
  const auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    return tfrt::MakeStringError(
        "KernelFallbackCompatRequestState not found in RequestContext.");
  }

  auto* resource_array = fallback_request_state->resource_array();
  DCHECK(resource_array);

  int64_t index = index_attr.GetValue();

  // Make the resource tensor immutable, so that we don't need reference
  // counting on it and so that it cannot be buffer-forwarded.
  resource_array->SetResource(
      index,
      tensorflow::tfrt_stub::ImmutableTensor::Create(arg.get().tensor()));

  return tfrt::Chain();
}

// FallbackGetResource is the fallback kernel that retrieves the tensor value
// in the fallback's resource array.
void FallbackGetResource(tfrt::Argument<tfrt::Chain> in_ch,
                         tfrt::Result<tfrt::Chain> out_ch,
                         tfrt::RemainingResults results,
                         tfrt::StringAttr device, tfrt::ArrayAttr indices_attr,
                         const tfrt::ExecutionContext& exec_ctx) {
  tensorflow::profiler::TraceMe trace_me("tfrt_fallback_async.get_resource");
  trace_me.AppendMetadata([request_id = exec_ctx.request_ctx()->id()]() {
    return tensorflow::profiler::TraceMeEncode({{"id", request_id}});
  });

  const auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    tfrt::RCReference<tfrt::AsyncValue> error = tfrt::EmitErrorAsync(
        exec_ctx,
        "KernelFallbackCompatRequestState not found in RequestContext.");
    out_ch.Set(std::move(error));
    return;
  }

  auto* resource_array = fallback_request_state->resource_array();
  DCHECK(resource_array);

  llvm::ArrayRef<int64_t> indices = indices_attr.GetValue<int64_t>();

  for (int i = 0; i < indices.size(); ++i) {
    results[i] = tfrt::FormRef(resource_array->GetResource(indices[i]));
  }

  out_ch.Set(in_ch);
}

// The implementation of the tfrt_fallback_async.executeop kernel. It executes
// a non-side-effecting TF op with the name of `op_name` in fallback. All
// relevant TF attributes are passed in `op_attr_array`.
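//
// A rough sketch of an invocation in the tfrt_fallback_async dialect
// (illustrative only; the exact textual syntax is defined by the dialect, and
// the op key must match a prior tfrt_fallback_async.createop):
//
//   %res = tfrt_fallback_async.executeop key(0) device("/device:CPU:0")
//       "tf.MatMul"(%a, %b) {transpose_a = false} : 1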
void FallbackAsyncExecuteOp(tfrt::AsyncKernelFrame* frame) {
  FallbackKernelAttributeFrame attr_frame(frame);
#ifndef NDEBUG
  frame->GetExecutionContext()
      .host()
      ->GetOrCreateSharedContext<OpLogger>()
      .LogOp(attr_frame.op_name().GetValue());
#endif
  KernelFallbackExecuteOp(frame->GetArguments(), frame->GetResults(),
                          /*op_chain=*/nullptr, attr_frame,
                          frame->GetExecutionContext());
}

// The implementation of the tfrt_fallback_async.executeop.seq kernel. It
// executes a side-effecting TF op with the name of `op_name` in fallback. All
// relevant TF attributes are passed in `op_attr_array`. `in_op_chain` and
// `out_op_chain` are used for side-effect visibility.
void FallbackAsyncExecuteOpSeq(tfrt::AsyncKernelFrame* frame) {
  auto all_args = frame->GetArguments();
  DCHECK_GT(all_args.size(), 0);
  tfrt::AsyncValueRef<tfrt::Chain> op_chain(tfrt::FormRef(all_args[0]));
  llvm::ArrayRef<tfrt::AsyncValue*> args = all_args.drop_front();

  auto all_results = frame->GetResults();
  DCHECK_GT(all_results.size(), 0);
  auto& out_op_chain = all_results[0];
  llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results =
      all_results.drop_front();

  KernelFallbackExecuteOp(args, results, &op_chain,
                          FallbackKernelAttributeFrame(frame),
                          frame->GetExecutionContext());
  out_op_chain = std::move(op_chain);
}

void FallbackCopyTensorIfSmall(
    tfrt::Argument<tensorflow::tfrt_stub::FallbackTensor> arg,
    tfrt::RemainingResults results) {
  const auto& fallback_tensor = arg.get();
  const auto& tensor = fallback_tensor.tensor();

  if (!fallback_tensor.is_immutable()) {
    // Create a new TensorBuffer, which contains a new atomic counter, for
    // each result, to avoid downstream threads contending on the original
    // atomic counter.
    for (int i = 0; i < results.size(); ++i) {
      auto immutable_tensor =
          tensorflow::tfrt_stub::ImmutableTensor::Create(tensor);
      results[i] = tfrt::MakeAvailableAsyncValueRef<
          tensorflow::tfrt_stub::FallbackTensor>(
          std::move(immutable_tensor.tensor()));
    }
  } else {
    // For immutable tensors, we just need to copy the pointer. Note that we
    // still create a new AsyncValue in this case, to avoid atomic contention
    // on the AsyncValue's refcount.
    for (int i = 0; i < results.size(); ++i) {
      results[i] = tfrt::MakeAvailableAsyncValueRef<
          tensorflow::tfrt_stub::FallbackTensor>(fallback_tensor);
    }
  }
}

llvm::Expected<tensorflow::tfrt_stub::FallbackTensor> ConstTensorProto(
    tfrt::StringAttr serialized_tensor_proto) {
  tensorflow::TensorProto tensor_proto;
  if (!tensor_proto.ParseFromString(
          serialized_tensor_proto.GetValue().str())) {
    return tfrt::MakeStringError("Failed to parse const tensor proto");
  }

  tensorflow::Tensor tensor;
  if (!tensor.FromProto(tensor_proto)) {
    return tfrt::MakeStringError("Failed to create tensor from tensor proto: ",
                                 tensor_proto.ShortDebugString());
  }

  return tensorflow::tfrt_stub::FallbackTensor(std::move(tensor));
}

void RegisterKernelFallbackCompatKernels(tfrt::KernelRegistry* registry) {
  registry->AddKernel("tfrt_fallback_async.const_tensor_proto",
                      TFRT_KERNEL(ConstTensorProto));
  registry->AddKernel("tfrt_fallback_async.executeop", FallbackAsyncExecuteOp);
  registry->AddKernel("tfrt_fallback_async.executeop.seq",
                      FallbackAsyncExecuteOpSeq);
  registry->AddKernel("tfrt_fallback_async.copy_if_small",
                      TFRT_KERNEL(FallbackCopyTensorIfSmall));
  registry->AddKernel("tfrt_fallback_async.createop",
                      TFRT_KERNEL(KernelFallbackCreateOp));
  registry->AddKernel("tfrt_fallback_async.set_resource",
                      TFRT_KERNEL(FallbackSetResource));
  registry->AddKernel("tfrt_fallback_async.get_resource",
                      TFRT_KERNEL(FallbackGetResource));
}

TFRT_STATIC_KERNEL_REGISTRATION(RegisterKernelFallbackCompatKernels);

}  // namespace
}  // namespace tfd
}  // namespace tensorflow