• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tfrt/eager/function_cache.h"
16 
17 #include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"
18 #include "tensorflow/core/framework/function.h"
19 #include "tensorflow/core/framework/graph_to_functiondef.h"
20 #include "tensorflow/core/graph/graph.h"
21 #include "tensorflow/core/tfrt/eager/transform_graph_function.h"
22 #include "tfrt/bef/bef_buffer.h"  // from @tf_runtime
23 #include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
24 #include "tfrt/core_runtime/core_runtime_op.h"  // from @tf_runtime
25 #include "tfrt/host_context/chain.h"  // from @tf_runtime
26 #include "tfrt/host_context/execution_context.h"  // from @tf_runtime
27 #include "tfrt/support/error_util.h"  // from @tf_runtime
28 #include "tfrt/support/forward_decls.h"  // from @tf_runtime
29 
30 namespace tfrt {
31 namespace tf {
32 
RemoveFunction(string_view op_name)33 void FunctionCache::RemoveFunction(string_view op_name) {
34   mutex_lock l(cache_mu_);
35   auto iter = cache_.begin();
36   while (iter != cache_.end()) {
37     if (iter->first.op_name == op_name) {
38       iter = cache_.erase(iter);
39     } else {
40       ++iter;
41     }
42   }
43 }
44 
GetOrAddFunction(const std::string & op_name,const std::string & device_name,const tensorflow::DeviceSet & device_set,tensorflow::EagerContext * eager_ctx,tfrt::CoreRuntime * corert,RequestCtxBuilder request_ctx_fn,Location loc,tensorflow::TfrtFunctionCompileOptions compile_options,tfrt::ArrayRef<const Device * > input_devices,FunctionCache::FunctionCacheResult * result)45 tensorflow::Status FunctionCache::GetOrAddFunction(
46     const std::string& op_name, const std::string& device_name,
47     const tensorflow::DeviceSet& device_set,
48     tensorflow::EagerContext* eager_ctx, tfrt::CoreRuntime* corert,
49     RequestCtxBuilder request_ctx_fn, Location loc,
50     tensorflow::TfrtFunctionCompileOptions compile_options,
51     tfrt::ArrayRef<const Device*> input_devices,
52     FunctionCache::FunctionCacheResult* result) {
53   const CacheKey cache_key{op_name, device_name};
54   {
55     mutex_lock l(cache_mu_);
56     auto& function_state = cache_[cache_key];
57     if (function_state) {
58       *result =
59           FunctionCache::FunctionCacheResult{function_state.CopyRef(), false};
60       return tensorflow::Status::OK();
61     }
62   }
63 
64   tensorflow::FunctionLibraryDefinition* func_lib_def = eager_ctx->FuncLibDef();
65   const tensorflow::FunctionDef* fdef = func_lib_def->Find(op_name);
66   if (fdef == nullptr)
67     return tensorflow::errors::NotFound(
68         "Cannot find function from FunctionLibraryDefinition ", op_name);
69 
70   // Run graph optimizations using current runtime components before converting
71   // the graph to MLIR module.
72   std::unique_ptr<tensorflow::FunctionBody> fbody;
73   TF_RETURN_IF_ERROR(tensorflow::FunctionDefToBodyHelper(
74       *fdef, tensorflow::AttrSlice(), func_lib_def, &fbody));
75 
76   // Transferring out the graph ownership from fbody.
77   auto graph = std::unique_ptr<tensorflow::Graph>(fbody->graph);
78   fbody->graph = nullptr;
79 
80   tensorflow::GraphDef graph_def;
81   graph->ToGraphDef(&graph_def);
82   tensorflow::FunctionLibraryDefinition reachable_lib_def =
83       func_lib_def->ReachableDefinitions(graph_def);
84 
85   TF_RETURN_IF_ERROR(tensorflow::TransformGraphFunction(
86       op_name, *fdef, device_name, device_set, eager_ctx,
87       compile_options.enable_grappler, &fbody, std::move(graph), input_devices,
88       &reachable_lib_def));
89 
90   BefBuffer bef_buffer;
91 
92   llvm::SmallVector<tfrt::string_view, 4> device_names;
93   device_names.reserve(device_set.devices().size());
94   for (auto& d : device_set.devices()) {
95     device_names.push_back(d->name());
96   }
97 
98   // Lower FunctionDef to BEF.
99   TF_RETURN_IF_ERROR(tensorflow::ConvertFunctionToBef(
100       op_name, fbody.get(), reachable_lib_def, device_names, compile_options,
101       &bef_buffer));
102 
103   HostContext* host_ctx = corert->GetHostContext();
104   auto bef_file =
105       tfrt::BEFFile::Open(bef_buffer, host_ctx->GetKernelRegistry(),
106                           host_ctx->diag_handler(), host_ctx->allocator());
107   if (!bef_file)
108     return tensorflow::errors::Internal(
109         "Failed to open lowered BEF for function ", op_name, ".");
110 
111   const tfrt::Function* function = bef_file->GetFunction(op_name);
112   if (!function)
113     return tensorflow::errors::Internal(
114         "Failed to get function from BEF for function ", op_name, ".");
115 
116   auto expected_fn = corert->MakeCompositeOp(function);
117   if (!expected_fn)
118     return tensorflow::errors::Internal(StrCat("Construct CoreRuntimeOp for ",
119                                                op_name.c_str(), " failed. ",
120                                                expected_fn.takeError()));
121 
122   TfrtDataTypeVector tfrt_arg_types;
123 
124   for (const auto& arg_type : fbody->arg_types) {
125     tfrt_arg_types.push_back(ConvertTfDTypeToTfrtDType(arg_type));
126   }
127 
128   auto runner_table = absl::make_unique<tensorflow::tfd::OpKernelRunnerTable>();
129   RCReference<RequestContext> request_ctx;
130   TF_RETURN_IF_ERROR(request_ctx_fn(runner_table.get(), &request_ctx));
131 
132   ExecutionContext exec_ctx{std::move(request_ctx), loc};
133   TF_RETURN_IF_ERROR(
134       RunRuntimeInitializer(exec_ctx, bef_file.get(), "_tfrt_fallback_init"));
135 
136   RCReference<FunctionState> entry = FunctionState::CreateFunctionState(
137       tfrt_arg_types, std::move(bef_buffer), std::move(bef_file),
138       std::move(expected_fn.get()), std::move(runner_table));
139 
140   mutex_lock l(cache_mu_);
141   // Insert the new entry to cache. If an entry with the same key is already
142   // present in the cache at this moment due to race condition, overwrites it.
143   cache_[cache_key] = entry.CopyRef();
144   *result = FunctionCache::FunctionCacheResult{std::move(entry), true};
145   return tensorflow::Status::OK();
146 }
147 
148 }  // namespace tf
149 }  // namespace tfrt
150