• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tfrt/eager/function_cache.h"
16 
17 #include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"
18 #include "tensorflow/core/framework/function.h"
19 #include "tensorflow/core/framework/graph_to_functiondef.h"
20 #include "tensorflow/core/graph/graph.h"
21 #include "tensorflow/core/tfrt/eager/transform_graph_function.h"
22 #include "tfrt/bef/bef_buffer.h"  // from @tf_runtime
23 #include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
24 #include "tfrt/core_runtime/core_runtime_op.h"  // from @tf_runtime
25 #include "tfrt/host_context/chain.h"  // from @tf_runtime
26 #include "tfrt/host_context/execution_context.h"  // from @tf_runtime
27 #include "tfrt/support/error_util.h"  // from @tf_runtime
28 #include "tfrt/support/forward_decls.h"  // from @tf_runtime
29 
30 namespace tfrt {
31 namespace tf {
32 
RemoveFunction(string_view op_name)33 void FunctionCache::RemoveFunction(string_view op_name) {
34   mutex_lock l(cache_mu_);
35   auto iter = cache_.begin();
36   while (iter != cache_.end()) {
37     if (iter->first.op_name == op_name) {
38       iter = cache_.erase(iter);
39     } else {
40       ++iter;
41     }
42   }
43 }
44 
GetOrAddFunction(const std::string & op_name,const std::string & device_name,const tensorflow::DeviceSet & device_set,tensorflow::EagerContext * eager_ctx,tfrt::CoreRuntime * corert,RequestCtxBuilder request_ctx_fn,Location loc,tensorflow::TfrtFunctionCompileOptions compile_options,tfrt::ArrayRef<const Device * > input_devices,FunctionCache::FunctionCacheResult * result)45 tensorflow::Status FunctionCache::GetOrAddFunction(
46     const std::string& op_name, const std::string& device_name,
47     const tensorflow::DeviceSet& device_set,
48     tensorflow::EagerContext* eager_ctx, tfrt::CoreRuntime* corert,
49     RequestCtxBuilder request_ctx_fn, Location loc,
50     tensorflow::TfrtFunctionCompileOptions compile_options,
51     tfrt::ArrayRef<const Device*> input_devices,
52     FunctionCache::FunctionCacheResult* result) {
53   const CacheKey cache_key{op_name, device_name};
54   {
55     mutex_lock l(cache_mu_);
56     auto& function_state = cache_[cache_key];
57     if (function_state) {
58       *result = FunctionCache::FunctionCacheResult{function_state, false};
59       return ::tensorflow::OkStatus();
60     }
61   }
62 
63   tensorflow::FunctionLibraryDefinition* func_lib_def = eager_ctx->FuncLibDef();
64   const tensorflow::FunctionDef* fdef = func_lib_def->Find(op_name);
65   if (fdef == nullptr)
66     return tensorflow::errors::NotFound(
67         "Cannot find function from FunctionLibraryDefinition ", op_name);
68 
69   // Run graph optimizations using current runtime components before converting
70   // the graph to MLIR module.
71   std::unique_ptr<tensorflow::FunctionBody> fbody;
72   TF_RETURN_IF_ERROR(tensorflow::FunctionDefToBodyHelper(
73       *fdef, tensorflow::AttrSlice(), func_lib_def, &fbody));
74 
75   // Transferring out the graph ownership from fbody.
76   auto graph = std::unique_ptr<tensorflow::Graph>(fbody->graph);
77   fbody->graph = nullptr;
78 
79   tensorflow::GraphDef graph_def;
80   graph->ToGraphDef(&graph_def);
81   tensorflow::FunctionLibraryDefinition reachable_lib_def =
82       func_lib_def->ReachableDefinitions(graph_def);
83 
84   TF_RETURN_IF_ERROR(tensorflow::TransformGraphFunction(
85       op_name, *fdef, device_name, device_set, eager_ctx,
86       compile_options.enable_grappler, &fbody, std::move(graph), input_devices,
87       &reachable_lib_def));
88 
89   BefBuffer bef_buffer;
90 
91   llvm::SmallVector<tfrt::string_view, 4> device_names;
92   device_names.reserve(device_set.devices().size());
93   for (auto& d : device_set.devices()) {
94     device_names.push_back(d->name());
95   }
96 
97   // Lower FunctionDef to BEF.
98   TF_RETURN_IF_ERROR(tensorflow::ConvertFunctionToBef(
99       op_name, fbody.get(), reachable_lib_def, device_names, compile_options,
100       &bef_buffer));
101 
102   HostContext* host_ctx = corert->GetHostContext();
103   auto bef_file =
104       tfrt::BEFFile::Open(bef_buffer, host_ctx->GetKernelRegistry(),
105                           host_ctx->diag_handler(), host_ctx->allocator());
106   if (!bef_file)
107     return tensorflow::errors::Internal(
108         "Failed to open lowered BEF for function ", op_name, ".");
109 
110   const tfrt::Function* function = bef_file->GetFunction(op_name);
111   if (!function)
112     return tensorflow::errors::Internal(
113         "Failed to get function from BEF for function ", op_name, ".");
114 
115   auto expected_fn = corert->MakeCompositeOp(function);
116   if (!expected_fn)
117     return tensorflow::errors::Internal(StrCat("Construct CoreRuntimeOp for ",
118                                                op_name.c_str(), " failed. ",
119                                                expected_fn.takeError()));
120 
121   TfrtDataTypeVector tfrt_arg_types;
122   tensorflow::DataTypeVector tf_ret_types;
123 
124   for (const auto& arg_type : fbody->arg_types) {
125     tfrt_arg_types.push_back(ConvertTfDTypeToTfrtDType(arg_type));
126   }
127 
128   for (const auto& ret_type : fbody->ret_types) {
129     tf_ret_types.push_back(ret_type);
130   }
131 
132   auto runner_table =
133       std::make_unique<tensorflow::tfrt_stub::OpKernelRunnerTable>();
134   RCReference<RequestContext> request_ctx;
135   TF_RETURN_IF_ERROR(request_ctx_fn(runner_table.get(), &request_ctx));
136 
137   ExecutionContext exec_ctx{std::move(request_ctx), loc};
138   TF_RETURN_IF_ERROR(
139       RunRuntimeInitializer(exec_ctx, bef_file.get(), "_tfrt_fallback_init"));
140 
141   RCReference<FunctionState> entry = FunctionState::CreateFunctionState(
142       tfrt_arg_types, tf_ret_types, std::move(bef_buffer), std::move(bef_file),
143       std::move(expected_fn.get()), std::move(runner_table));
144 
145   mutex_lock l(cache_mu_);
146   // Insert the new entry to cache. If an entry with the same key is already
147   // present in the cache at this moment due to race condition, overwrites it.
148   cache_[cache_key] = entry;
149   *result = FunctionCache::FunctionCacheResult{std::move(entry), true};
150   return ::tensorflow::OkStatus();
151 }
152 
153 }  // namespace tf
154 }  // namespace tfrt
155