• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_TFRT_EAGER_FUNCTION_CACHE_H_
16 #define TENSORFLOW_CORE_TFRT_EAGER_FUNCTION_CACHE_H_
17 
18 #include "tensorflow/compiler/mlir/tfrt/function/function.h"
19 #include "tensorflow/core/common_runtime/eager/context.h"
20 #include "tensorflow/core/framework/function.h"
21 #include "tensorflow/core/framework/types.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/platform/status.h"
24 #include "tensorflow/core/tfrt/fallback/op_kernel_runner.h"
25 #include "tensorflow/core/tfrt/utils/utils.h"
26 #include "tfrt/bef_executor/bef_file.h"  // from @tf_runtime
27 #include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
28 #include "tfrt/core_runtime/core_runtime_op.h"  // from @tf_runtime
29 #include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
30 #include "tfrt/host_context/function.h"  // from @tf_runtime
31 #include "tfrt/host_context/host_context.h"  // from @tf_runtime
32 #include "tfrt/support/aligned_buffer.h"  // from @tf_runtime
33 #include "tfrt/support/error_util.h"  // from @tf_runtime
34 #include "tfrt/support/mutex.h"  // from @tf_runtime
35 #include "tfrt/support/ref_count.h"  // from @tf_runtime
36 #include "tfrt/support/string_util.h"  // from @tf_runtime
37 
38 namespace tfrt {
39 namespace tf {
40 
41 // A reference counted `state` object that contains a BEF file, which represents
42 // a lowered FunctionDef. The CoreRuntimeOp is a callable handle to the function
43 // to be called.
44 class FunctionState : public ReferenceCounted<FunctionState> {
45  public:
CreateFunctionState(TfrtDataTypeSlice arg_types,tensorflow::DataTypeSlice ret_types,BefBuffer bef_buffer,RCReference<BEFFile> bef_file,CoreRuntimeOp fn,std::unique_ptr<tensorflow::tfrt_stub::OpKernelRunnerTable> runner_table)46   static RCReference<FunctionState> CreateFunctionState(
47       TfrtDataTypeSlice arg_types, tensorflow::DataTypeSlice ret_types,
48       BefBuffer bef_buffer, RCReference<BEFFile> bef_file, CoreRuntimeOp fn,
49       std::unique_ptr<tensorflow::tfrt_stub::OpKernelRunnerTable>
50           runner_table) {
51     return TakeRef(new FunctionState(arg_types, ret_types,
52                                      std::move(bef_buffer), std::move(bef_file),
53                                      std::move(fn), std::move(runner_table)));
54   }
55 
GetFunc()56   const CoreRuntimeOp& GetFunc() const { return fn_; }
57 
GetArgTypes()58   const TfrtDataTypeVector& GetArgTypes() { return arg_types_; }
59 
GetRetTypes()60   const tensorflow::DataTypeVector& GetRetTypes() { return ret_types_; }
61 
GetRunnerTable()62   tensorflow::tfrt_stub::OpKernelRunnerTable* GetRunnerTable() {
63     return runner_table_.get();
64   }
65 
66  private:
FunctionState(TfrtDataTypeSlice arg_types,tensorflow::DataTypeSlice ret_types,BefBuffer bef_buffer,RCReference<BEFFile> bef_file,CoreRuntimeOp fn,std::unique_ptr<tensorflow::tfrt_stub::OpKernelRunnerTable> runner_table)67   FunctionState(
68       TfrtDataTypeSlice arg_types, tensorflow::DataTypeSlice ret_types,
69       BefBuffer bef_buffer, RCReference<BEFFile> bef_file, CoreRuntimeOp fn,
70       std::unique_ptr<tensorflow::tfrt_stub::OpKernelRunnerTable> runner_table)
71       : arg_types_(arg_types.begin(), arg_types.end()),
72         ret_types_(ret_types.begin(), ret_types.end()),
73         bef_buffer_(std::move(bef_buffer)),
74         bef_file_(std::move(bef_file)),
75         fn_(std::move(fn)),
76         runner_table_(std::move(runner_table)) {}
77 
78   TfrtDataTypeVector arg_types_;
79   tensorflow::DataTypeVector ret_types_;
80   BefBuffer bef_buffer_;
81   RCReference<BEFFile> bef_file_;
82   const CoreRuntimeOp fn_;
83 
84   // This is the op_kernel cache used by kernel fallback compact mode. We will
85   // initialize this table right after lowering the function.
86   std::unique_ptr<tensorflow::tfrt_stub::OpKernelRunnerTable> runner_table_;
87 };
88 
89 // Cache for a single core runtime op or function (composite op). Thread safe.
90 class FunctionCache {
91  public:
92   // Iterate the cache and erase the op(s) with the specified op_name.
93   void RemoveFunction(string_view op_name) TFRT_EXCLUDES(cache_mu_);
94 
95   struct FunctionCacheResult {
96     RCReference<FunctionState> function_state;
97     bool is_cache_miss;
98   };
99 
100   typedef std::function<tensorflow::Status(
101       tensorflow::tfrt_stub::OpKernelRunnerTable*,
102       RCReference<RequestContext>*)>
103       RequestCtxBuilder;
104 
105   // Helper function to look up the cache. If miss, insert the function to the
106   // cache.
107   // When the return status is OK, `result` is set.
108   tensorflow::Status GetOrAddFunction(
109       const std::string& op_name, const std::string& device_name,
110       const tensorflow::DeviceSet& device_set,
111       tensorflow::EagerContext* eager_ctx, tfrt::CoreRuntime* corert,
112       RequestCtxBuilder request_ctx_fn, Location loc,
113       tensorflow::TfrtFunctionCompileOptions compile_options,
114       tfrt::ArrayRef<const Device*> input_devices, FunctionCacheResult* result);
115 
116   // The following helper functions are for debugging and testing only.
Size()117   size_t Size() const {
118     mutex_lock l(cache_mu_);
119     return cache_.size();
120   }
121 
Contains(string_view op_name,string_view device_name)122   bool Contains(string_view op_name, string_view device_name) const {
123     const CacheKey& cache_key{op_name.str(), device_name.str()};
124     mutex_lock l(cache_mu_);
125     return cache_.find(cache_key) != cache_.end();
126   }
127 
128  private:
129   // Note: Currently the key is a pair of op_name and device_name. New features
130   // may be added in the future.
131   struct CacheKey {
132     std::string op_name, device_name;
133 
134     bool operator==(const CacheKey& other) const {
135       return (this->op_name == other.op_name &&
136               this->device_name == other.device_name);
137     }
138   };
139 
140   struct CacheKeyHash {
operatorCacheKeyHash141     size_t operator()(const CacheKey& pair) const {
142       return std::hash<std::string>()(pair.op_name) ^
143              std::hash<std::string>()(pair.device_name);
144     }
145   };
146 
147   mutable mutex cache_mu_;
148   std::unordered_map<CacheKey, RCReference<FunctionState>, CacheKeyHash> cache_
149       TFRT_GUARDED_BY(cache_mu_);
150 };
151 
152 }  // namespace tf
153 }  // namespace tfrt
154 
155 #endif  // TENSORFLOW_CORE_TFRT_EAGER_FUNCTION_CACHE_H_
156