/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tfrt/eager/function_cache.h"

#include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/tfrt/eager/transform_graph_function.h"
#include "tfrt/bef/bef_buffer.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime_op.h"  // from @tf_runtime
#include "tfrt/host_context/chain.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime

namespace tfrt {
namespace tf {

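// Removes every cached entry whose key matches `op_name`, regardless of the
// device it was compiled for.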
void FunctionCache::RemoveFunction(string_view op_name) {
  mutex_lock l(cache_mu_);
  auto iter = cache_.begin();
  while (iter != cache_.end()) {
    if (iter->first.op_name == op_name) {
      iter = cache_.erase(iter);
    } else {
      ++iter;
    }
  }
}

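// Looks up the FunctionState cached for (op_name, device_name). On a miss,
// compiles the function to BEF, wraps it in a CoreRuntimeOp, caches the
// result, and returns the new entry; the boolean in `result` reports whether
// this call was a cache miss.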
tensorflow::Status FunctionCache::GetOrAddFunction(
    const std::string& op_name, const std::string& device_name,
    const tensorflow::DeviceSet& device_set,
    tensorflow::EagerContext* eager_ctx, tfrt::CoreRuntime* corert,
    RequestCtxBuilder request_ctx_fn, Location loc,
    tensorflow::TfrtFunctionCompileOptions compile_options,
    tfrt::ArrayRef<const Device*> input_devices,
    FunctionCache::FunctionCacheResult* result) {
  const CacheKey cache_key{op_name, device_name};
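  // Fast path: return a new reference to the cached FunctionState on a hit.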
  {
    mutex_lock l(cache_mu_);
    auto& function_state = cache_[cache_key];
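    // Note: operator[] default-constructs a null entry on a miss; it is only
    // populated at the end of this function, after compilation succeeds.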
    if (function_state) {
      *result =
          FunctionCache::FunctionCacheResult{function_state.CopyRef(), false};
      return tensorflow::Status::OK();
    }
  }

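  // Resolve the FunctionDef for `op_name` from the eager context's function
  // library.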
  tensorflow::FunctionLibraryDefinition* func_lib_def = eager_ctx->FuncLibDef();
  const tensorflow::FunctionDef* fdef = func_lib_def->Find(op_name);
  if (fdef == nullptr)
    return tensorflow::errors::NotFound(
        "Cannot find function ", op_name, " in FunctionLibraryDefinition");

  // Run graph optimizations using current runtime components before
  // converting the graph to an MLIR module.
  std::unique_ptr<tensorflow::FunctionBody> fbody;
  TF_RETURN_IF_ERROR(tensorflow::FunctionDefToBodyHelper(
      *fdef, tensorflow::AttrSlice(), func_lib_def, &fbody));

  // Transfer graph ownership out of `fbody`.
  auto graph = std::unique_ptr<tensorflow::Graph>(fbody->graph);
  fbody->graph = nullptr;

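  // Restrict the function library to the definitions reachable from this
  // graph so the compiler only sees what it needs.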
  tensorflow::GraphDef graph_def;
  graph->ToGraphDef(&graph_def);
  tensorflow::FunctionLibraryDefinition reachable_lib_def =
      func_lib_def->ReachableDefinitions(graph_def);

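  // Apply TF graph transformations (including device placement and, when
  // `enable_grappler` is set, Grappler) to specialize the function body for
  // the target devices.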
  TF_RETURN_IF_ERROR(tensorflow::TransformGraphFunction(
      op_name, *fdef, device_name, device_set, eager_ctx,
      compile_options.enable_grappler, &fbody, std::move(graph), input_devices,
      &reachable_lib_def));

  BefBuffer bef_buffer;

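  // Collect the names of all devices in the device set; they are passed to
  // the compiler so the lowered function can target the available devices.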
  llvm::SmallVector<tfrt::string_view, 4> device_names;
  device_names.reserve(device_set.devices().size());
  for (auto& d : device_set.devices()) {
    device_names.push_back(d->name());
  }

  // Lower FunctionDef to BEF.
  TF_RETURN_IF_ERROR(tensorflow::ConvertFunctionToBef(
      op_name, fbody.get(), reachable_lib_def, device_names, compile_options,
      &bef_buffer));

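  // Open the generated BEF buffer so its functions can be resolved and
  // invoked on this host.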
  HostContext* host_ctx = corert->GetHostContext();
  auto bef_file =
      tfrt::BEFFile::Open(bef_buffer, host_ctx->GetKernelRegistry(),
                          host_ctx->diag_handler(), host_ctx->allocator());
  if (!bef_file)
    return tensorflow::errors::Internal(
        "Failed to open lowered BEF for function ", op_name, ".");

  const tfrt::Function* function = bef_file->GetFunction(op_name);
  if (!function)
    return tensorflow::errors::Internal(
        "Failed to get function from BEF for function ", op_name, ".");

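  // Wrap the BEF function in a CoreRuntimeOp so it can be invoked like any
  // other core runtime op.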
  auto expected_fn = corert->MakeCompositeOp(function);
  if (!expected_fn)
    return tensorflow::errors::Internal(
        StrCat("Failed to construct CoreRuntimeOp for ", op_name, ". ",
               expected_fn.takeError()));

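  // Record the TFRT dtype of each function argument; FunctionState keeps
  // these alongside the compiled artifacts.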
  TfrtDataTypeVector tfrt_arg_types;

  for (const auto& arg_type : fbody->arg_types) {
    tfrt_arg_types.push_back(ConvertTfDTypeToTfrtDType(arg_type));
  }

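  // Build the fallback OpKernelRunnerTable and a request context, then run
  // the generated "_tfrt_fallback_init" function to initialize fallback state
  // for this function.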
  auto runner_table = absl::make_unique<tensorflow::tfd::OpKernelRunnerTable>();
  RCReference<RequestContext> request_ctx;
  TF_RETURN_IF_ERROR(request_ctx_fn(runner_table.get(), &request_ctx));

  ExecutionContext exec_ctx{std::move(request_ctx), loc};
  TF_RETURN_IF_ERROR(
      RunRuntimeInitializer(exec_ctx, bef_file.get(), "_tfrt_fallback_init"));

  RCReference<FunctionState> entry = FunctionState::CreateFunctionState(
      tfrt_arg_types, std::move(bef_buffer), std::move(bef_file),
      std::move(expected_fn.get()), std::move(runner_table));

  mutex_lock l(cache_mu_);
  // Insert the new entry into the cache. If another thread raced us and
  // inserted an entry with the same key in the meantime, overwrite it.
  cache_[cache_key] = entry.CopyRef();
  *result = FunctionCache::FunctionCacheResult{std::move(entry), true};
  return tensorflow::Status::OK();
}

}  // namespace tf
}  // namespace tfrt