• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_TFRT_EAGER_FUNCTION_CACHE_H_
16 #define TENSORFLOW_CORE_TFRT_EAGER_FUNCTION_CACHE_H_
17 
18 #include "tensorflow/compiler/mlir/tfrt/function/function.h"
19 #include "tensorflow/core/common_runtime/eager/context.h"
20 #include "tensorflow/core/framework/function.h"
21 #include "tensorflow/core/framework/types.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/platform/status.h"
24 #include "tensorflow/core/runtime_fallback/kernel/op_kernel_runner.h"
25 #include "tensorflow/core/tfrt/utils/utils.h"
26 #include "tfrt/bef_executor/bef_file.h"  // from @tf_runtime
27 #include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
28 #include "tfrt/core_runtime/core_runtime_op.h"  // from @tf_runtime
29 #include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
30 #include "tfrt/host_context/function.h"  // from @tf_runtime
31 #include "tfrt/host_context/host_context.h"  // from @tf_runtime
32 #include "tfrt/support/aligned_buffer.h"  // from @tf_runtime
33 #include "tfrt/support/error_util.h"  // from @tf_runtime
34 #include "tfrt/support/mutex.h"  // from @tf_runtime
35 #include "tfrt/support/ref_count.h"  // from @tf_runtime
36 #include "tfrt/support/string_util.h"  // from @tf_runtime
37 
38 namespace tfrt {
39 namespace tf {
40 
41 // A reference counted `state` object that contains a BEF file, which represents
42 // a lowered FunctionDef. The CoreRuntimeOp is a callable handle to the function
43 // to be called.
44 class FunctionState : public ReferenceCounted<FunctionState> {
45  public:
CreateFunctionState(TfrtDataTypeSlice arg_types,BefBuffer bef_buffer,RCReference<BEFFile> bef_file,CoreRuntimeOp fn,std::unique_ptr<tensorflow::tfd::OpKernelRunnerTable> runner_table)46   static RCReference<FunctionState> CreateFunctionState(
47       TfrtDataTypeSlice arg_types, BefBuffer bef_buffer,
48       RCReference<BEFFile> bef_file, CoreRuntimeOp fn,
49       std::unique_ptr<tensorflow::tfd::OpKernelRunnerTable> runner_table) {
50     return TakeRef(new FunctionState(arg_types, std::move(bef_buffer),
51                                      std::move(bef_file), std::move(fn),
52                                      std::move(runner_table)));
53   }
54 
GetFunc()55   const CoreRuntimeOp& GetFunc() const { return fn_; }
56 
GetArgTypes()57   const TfrtDataTypeVector& GetArgTypes() { return arg_types_; }
58 
GetRunnerTable()59   tensorflow::tfd::OpKernelRunnerTable* GetRunnerTable() {
60     return runner_table_.get();
61   }
62 
63  private:
FunctionState(TfrtDataTypeSlice arg_types,BefBuffer bef_buffer,RCReference<BEFFile> bef_file,CoreRuntimeOp fn,std::unique_ptr<tensorflow::tfd::OpKernelRunnerTable> runner_table)64   FunctionState(
65       TfrtDataTypeSlice arg_types, BefBuffer bef_buffer,
66       RCReference<BEFFile> bef_file, CoreRuntimeOp fn,
67       std::unique_ptr<tensorflow::tfd::OpKernelRunnerTable> runner_table)
68       : arg_types_(arg_types.begin(), arg_types.end()),
69         bef_buffer_(std::move(bef_buffer)),
70         bef_file_(std::move(bef_file)),
71         fn_(std::move(fn)),
72         runner_table_(std::move(runner_table)) {}
73 
74   TfrtDataTypeVector arg_types_;
75   BefBuffer bef_buffer_;
76   RCReference<BEFFile> bef_file_;
77   const CoreRuntimeOp fn_;
78 
79   // This is the op_kernel cache used by kernel fallback compact mode. We will
80   // initialize this table right after lowering the function.
81   std::unique_ptr<tensorflow::tfd::OpKernelRunnerTable> runner_table_;
82 };
83 
84 // Cache for a single core runtime op or function (composite op). Thread safe.
85 class FunctionCache {
86  public:
87   // Iterate the cache and erase the op(s) with the specified op_name.
88   void RemoveFunction(string_view op_name) TFRT_EXCLUDES(cache_mu_);
89 
90   struct FunctionCacheResult {
91     RCReference<FunctionState> function_state;
92     bool is_cache_miss;
93   };
94 
95   typedef std::function<tensorflow::Status(
96       tensorflow::tfd::OpKernelRunnerTable*, RCReference<RequestContext>*)>
97       RequestCtxBuilder;
98 
99   // Helper function to look up the cache. If miss, insert the function to the
100   // cache.
101   // When the return status is OK, `result` is set.
102   tensorflow::Status GetOrAddFunction(
103       const std::string& op_name, const std::string& device_name,
104       const tensorflow::DeviceSet& device_set,
105       tensorflow::EagerContext* eager_ctx, tfrt::CoreRuntime* corert,
106       RequestCtxBuilder request_ctx_fn, Location loc,
107       tensorflow::TfrtFunctionCompileOptions compile_options,
108       tfrt::ArrayRef<const Device*> input_devices, FunctionCacheResult* result);
109 
110   // The following helper functions are for debugging and testing only.
Size()111   size_t Size() const {
112     mutex_lock l(cache_mu_);
113     return cache_.size();
114   }
115 
Contains(string_view op_name,string_view device_name)116   bool Contains(string_view op_name, string_view device_name) const {
117     const CacheKey& cache_key{op_name.str(), device_name.str()};
118     mutex_lock l(cache_mu_);
119     return cache_.find(cache_key) != cache_.end();
120   }
121 
122  private:
123   // Note: Currently the key is a pair of op_name and device_name. New features
124   // may be added in the future.
125   struct CacheKey {
126     std::string op_name, device_name;
127 
128     bool operator==(const CacheKey& other) const {
129       return (this->op_name == other.op_name &&
130               this->device_name == other.device_name);
131     }
132   };
133 
134   struct CacheKeyHash {
operatorCacheKeyHash135     size_t operator()(const CacheKey& pair) const {
136       return std::hash<std::string>()(pair.op_name) ^
137              std::hash<std::string>()(pair.device_name);
138     }
139   };
140 
141   mutable mutex cache_mu_;
142   std::unordered_map<CacheKey, RCReference<FunctionState>, CacheKeyHash> cache_
143       TFRT_GUARDED_BY(cache_mu_);
144 };
145 
146 }  // namespace tf
147 }  // namespace tfrt
148 
149 #endif  // TENSORFLOW_CORE_TFRT_EAGER_FUNCTION_CACHE_H_
150