/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tfrt/eager/function_cache.h"

#include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/tfrt/eager/transform_graph_function.h"
#include "tfrt/bef/bef_buffer.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime_op.h"  // from @tf_runtime
#include "tfrt/host_context/chain.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime

namespace tfrt {
namespace tf {

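// Removes all cached entries whose op_name matches `op_name`, across every
// device.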
void FunctionCache::RemoveFunction(string_view op_name) {
  mutex_lock l(cache_mu_);
  auto iter = cache_.begin();
  while (iter != cache_.end()) {
    if (iter->first.op_name == op_name) {
      iter = cache_.erase(iter);
    } else {
      ++iter;
    }
  }
}

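// Looks up the compiled function for the (op_name, device_name) pair. On a
// cache miss, compiles the FunctionDef down to BEF, wraps it in a
// CoreRuntimeOp, and caches the resulting FunctionState in `cache_`.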
tensorflow::Status FunctionCache::GetOrAddFunction(
    const std::string& op_name, const std::string& device_name,
    const tensorflow::DeviceSet& device_set,
    tensorflow::EagerContext* eager_ctx, tfrt::CoreRuntime* corert,
    RequestCtxBuilder request_ctx_fn, Location loc,
    tensorflow::TfrtFunctionCompileOptions compile_options,
    tfrt::ArrayRef<const Device*> input_devices,
    FunctionCache::FunctionCacheResult* result) {
  const CacheKey cache_key{op_name, device_name};
  {
    mutex_lock l(cache_mu_);
    auto& function_state = cache_[cache_key];
    if (function_state) {
      *result = FunctionCache::FunctionCacheResult{function_state, false};
      return ::tensorflow::OkStatus();
    }
  }

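  // Cache miss: compile the function below. `cache_mu_` is not held during
  // compilation, so concurrent callers may compile the same function; the
  // insertion at the end of this method resolves that race.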
  tensorflow::FunctionLibraryDefinition* func_lib_def = eager_ctx->FuncLibDef();
  const tensorflow::FunctionDef* fdef = func_lib_def->Find(op_name);
  if (fdef == nullptr)
    return tensorflow::errors::NotFound(
        "Cannot find function from FunctionLibraryDefinition ", op_name);

  // Run graph optimizations using the current runtime components before
  // converting the graph to an MLIR module.
  std::unique_ptr<tensorflow::FunctionBody> fbody;
  TF_RETURN_IF_ERROR(tensorflow::FunctionDefToBodyHelper(
      *fdef, tensorflow::AttrSlice(), func_lib_def, &fbody));

  // Transfer graph ownership out of `fbody`.
  auto graph = std::unique_ptr<tensorflow::Graph>(fbody->graph);
  fbody->graph = nullptr;

  tensorflow::GraphDef graph_def;
  graph->ToGraphDef(&graph_def);
  tensorflow::FunctionLibraryDefinition reachable_lib_def =
      func_lib_def->ReachableDefinitions(graph_def);

  TF_RETURN_IF_ERROR(tensorflow::TransformGraphFunction(
      op_name, *fdef, device_name, device_set, eager_ctx,
      compile_options.enable_grappler, &fbody, std::move(graph), input_devices,
      &reachable_lib_def));

  BefBuffer bef_buffer;

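  // Collect the names of all devices in `device_set`; they are handed to the
  // BEF converter below.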
  llvm::SmallVector<tfrt::string_view, 4> device_names;
  device_names.reserve(device_set.devices().size());
  for (auto& d : device_set.devices()) {
    device_names.push_back(d->name());
  }

  // Lower the FunctionDef to BEF.
  TF_RETURN_IF_ERROR(tensorflow::ConvertFunctionToBef(
      op_name, fbody.get(), reachable_lib_def, device_names, compile_options,
      &bef_buffer));

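  // Open the lowered BEF and look up the compiled function by name.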
  HostContext* host_ctx = corert->GetHostContext();
  auto bef_file =
      tfrt::BEFFile::Open(bef_buffer, host_ctx->GetKernelRegistry(),
                          host_ctx->diag_handler(), host_ctx->allocator());
  if (!bef_file)
    return tensorflow::errors::Internal(
        "Failed to open lowered BEF for function ", op_name, ".");

  const tfrt::Function* function = bef_file->GetFunction(op_name);
  if (!function)
    return tensorflow::errors::Internal(
        "Failed to get function from BEF for function ", op_name, ".");

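  // Wrap the BEF function in a CoreRuntimeOp so it can be invoked through the
  // core runtime.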
  auto expected_fn = corert->MakeCompositeOp(function);
  if (!expected_fn)
    return tensorflow::errors::Internal(
        StrCat("Failed to construct CoreRuntimeOp for ", op_name, ": ",
               expected_fn.takeError()));

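  // Record the argument dtypes (converted to TFRT dtypes) and the result
  // dtypes (kept as TF dtypes) for the cached FunctionState.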
  TfrtDataTypeVector tfrt_arg_types;
  tensorflow::DataTypeVector tf_ret_types;

  for (const auto& arg_type : fbody->arg_types) {
    tfrt_arg_types.push_back(ConvertTfDTypeToTfrtDType(arg_type));
  }

  for (const auto& ret_type : fbody->ret_types) {
    tf_ret_types.push_back(ret_type);
  }

  auto runner_table =
      std::make_unique<tensorflow::tfrt_stub::OpKernelRunnerTable>();
  RCReference<RequestContext> request_ctx;
  TF_RETURN_IF_ERROR(request_ctx_fn(runner_table.get(), &request_ctx));

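  // Run the BEF's "_tfrt_fallback_init" function to initialize TF fallback
  // state before the compiled function is first executed.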
  ExecutionContext exec_ctx{std::move(request_ctx), loc};
  TF_RETURN_IF_ERROR(
      RunRuntimeInitializer(exec_ctx, bef_file.get(), "_tfrt_fallback_init"));

  RCReference<FunctionState> entry = FunctionState::CreateFunctionState(
      tfrt_arg_types, tf_ret_types, std::move(bef_buffer), std::move(bef_file),
      std::move(expected_fn.get()), std::move(runner_table));

  mutex_lock l(cache_mu_);
  // Insert the new entry into the cache. If a racing caller has already
  // inserted an entry with the same key, overwrite it.
  cache_[cache_key] = entry;
  *result = FunctionCache::FunctionCacheResult{std::move(entry), true};
  return ::tensorflow::OkStatus();
}

}  // namespace tf
}  // namespace tfrt