• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/runtime_fallback/util/fallback_test_util.h"
16 
17 #include "tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_request_context.h"
18 #include "tensorflow/core/framework/function.h"
19 #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h"
20 #include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h"
21 
22 namespace tensorflow {
23 namespace tfd {
24 
CreateFallbackTestExecutionContext(tfrt::HostContext * host,tfrt::ResourceContext * resource_context,tensorflow::thread::ThreadPoolInterface * user_intra_op_threadpool)25 tfrt::ExecutionContext CreateFallbackTestExecutionContext(
26     tfrt::HostContext* host, tfrt::ResourceContext* resource_context,
27     tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool) {
28   static std::atomic<int64> id{0};
29 
30   // We should better decouple eager context and resource context. In prod code,
31   // we shouldn't store eager context in resource context.
32   auto* eager_context_resource =
33       resource_context->GetOrCreateResource<EagerContextResource>(
34           tensorflow::tfd::kEagerContextResourceName);
35   assert(eager_context_resource);
36   auto expected_eager_context = eager_context_resource->GetTFEagerContext();
37   assert(expected_eager_context);
38   auto* eager_context = expected_eager_context.get();
39   assert(eager_context);
40 
41   // Add a dummy FunctionDef to test creating ops with function attributes.
42   const FunctionDef dummy_function_def = FunctionDefHelper::Define(
43       /*function_name=*/"dummy_fn",
44       /*arg_def=*/{},
45       /*return values=*/{},
46       /*attr def=*/{},
47       /*node_def=*/{});
48   tensorflow::Status status = eager_context->AddFunctionDef(dummy_function_def);
49   TF_DCHECK_OK(status);
50 
51   auto request_id = id.fetch_add(1);
52   tfrt::RequestContextBuilder request_context_builder(host, resource_context,
53                                                       request_id);
54   status = SetUpKernelFallbackCompatRequestContext(
55       &request_context_builder, eager_context->local_device_mgr(),
56       eager_context->pflr(), user_intra_op_threadpool);
57   TF_DCHECK_OK(status);
58 
59   status = SetUpTfCpuRtRequestContext(&request_context_builder);
60   TF_DCHECK_OK(status);
61 
62   auto request_context = std::move(request_context_builder).build();
63   assert(request_context);
64 
65   return tfrt::ExecutionContext{std::move(request_context.get())};
66 }
67 
68 }  // namespace tfd
69 }  // namespace tensorflow
70