1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_TFRT_EAGER_FUNCTION_CACHE_H_ 16 #define TENSORFLOW_CORE_TFRT_EAGER_FUNCTION_CACHE_H_ 17 18 #include "tensorflow/compiler/mlir/tfrt/function/function.h" 19 #include "tensorflow/core/common_runtime/eager/context.h" 20 #include "tensorflow/core/framework/function.h" 21 #include "tensorflow/core/framework/types.h" 22 #include "tensorflow/core/lib/core/errors.h" 23 #include "tensorflow/core/platform/status.h" 24 #include "tensorflow/core/runtime_fallback/kernel/op_kernel_runner.h" 25 #include "tensorflow/core/tfrt/utils/utils.h" 26 #include "tfrt/bef_executor/bef_file.h" // from @tf_runtime 27 #include "tfrt/core_runtime/core_runtime.h" // from @tf_runtime 28 #include "tfrt/core_runtime/core_runtime_op.h" // from @tf_runtime 29 #include "tfrt/core_runtime/op_attrs.h" // from @tf_runtime 30 #include "tfrt/host_context/function.h" // from @tf_runtime 31 #include "tfrt/host_context/host_context.h" // from @tf_runtime 32 #include "tfrt/support/aligned_buffer.h" // from @tf_runtime 33 #include "tfrt/support/error_util.h" // from @tf_runtime 34 #include "tfrt/support/mutex.h" // from @tf_runtime 35 #include "tfrt/support/ref_count.h" // from @tf_runtime 36 #include "tfrt/support/string_util.h" // from @tf_runtime 37 38 namespace tfrt { 39 namespace tf { 40 41 // A reference counted `state` object that contains a BEF file, which represents 42 // a lowered FunctionDef. The CoreRuntimeOp is a callable handle to the function 43 // to be called. 44 class FunctionState : public ReferenceCounted<FunctionState> { 45 public: CreateFunctionState(TfrtDataTypeSlice arg_types,BefBuffer bef_buffer,RCReference<BEFFile> bef_file,CoreRuntimeOp fn,std::unique_ptr<tensorflow::tfd::OpKernelRunnerTable> runner_table)46 static RCReference<FunctionState> CreateFunctionState( 47 TfrtDataTypeSlice arg_types, BefBuffer bef_buffer, 48 RCReference<BEFFile> bef_file, CoreRuntimeOp fn, 49 std::unique_ptr<tensorflow::tfd::OpKernelRunnerTable> runner_table) { 50 return TakeRef(new FunctionState(arg_types, std::move(bef_buffer), 51 std::move(bef_file), std::move(fn), 52 std::move(runner_table))); 53 } 54 GetFunc()55 const CoreRuntimeOp& GetFunc() const { return fn_; } 56 GetArgTypes()57 const TfrtDataTypeVector& GetArgTypes() { return arg_types_; } 58 GetRunnerTable()59 tensorflow::tfd::OpKernelRunnerTable* GetRunnerTable() { 60 return runner_table_.get(); 61 } 62 63 private: FunctionState(TfrtDataTypeSlice arg_types,BefBuffer bef_buffer,RCReference<BEFFile> bef_file,CoreRuntimeOp fn,std::unique_ptr<tensorflow::tfd::OpKernelRunnerTable> runner_table)64 FunctionState( 65 TfrtDataTypeSlice arg_types, BefBuffer bef_buffer, 66 RCReference<BEFFile> bef_file, CoreRuntimeOp fn, 67 std::unique_ptr<tensorflow::tfd::OpKernelRunnerTable> runner_table) 68 : arg_types_(arg_types.begin(), arg_types.end()), 69 bef_buffer_(std::move(bef_buffer)), 70 bef_file_(std::move(bef_file)), 71 fn_(std::move(fn)), 72 runner_table_(std::move(runner_table)) {} 73 74 TfrtDataTypeVector arg_types_; 75 BefBuffer bef_buffer_; 76 RCReference<BEFFile> bef_file_; 77 const CoreRuntimeOp fn_; 78 79 // This is the op_kernel cache used by kernel fallback compact mode. We will 80 // initialize this table right after lowering the function. 81 std::unique_ptr<tensorflow::tfd::OpKernelRunnerTable> runner_table_; 82 }; 83 84 // Cache for a single core runtime op or function (composite op). Thread safe. 85 class FunctionCache { 86 public: 87 // Iterate the cache and erase the op(s) with the specified op_name. 88 void RemoveFunction(string_view op_name) TFRT_EXCLUDES(cache_mu_); 89 90 struct FunctionCacheResult { 91 RCReference<FunctionState> function_state; 92 bool is_cache_miss; 93 }; 94 95 typedef std::function<tensorflow::Status( 96 tensorflow::tfd::OpKernelRunnerTable*, RCReference<RequestContext>*)> 97 RequestCtxBuilder; 98 99 // Helper function to look up the cache. If miss, insert the function to the 100 // cache. 101 // When the return status is OK, `result` is set. 102 tensorflow::Status GetOrAddFunction( 103 const std::string& op_name, const std::string& device_name, 104 const tensorflow::DeviceSet& device_set, 105 tensorflow::EagerContext* eager_ctx, tfrt::CoreRuntime* corert, 106 RequestCtxBuilder request_ctx_fn, Location loc, 107 tensorflow::TfrtFunctionCompileOptions compile_options, 108 tfrt::ArrayRef<const Device*> input_devices, FunctionCacheResult* result); 109 110 // The following helper functions are for debugging and testing only. Size()111 size_t Size() const { 112 mutex_lock l(cache_mu_); 113 return cache_.size(); 114 } 115 Contains(string_view op_name,string_view device_name)116 bool Contains(string_view op_name, string_view device_name) const { 117 const CacheKey& cache_key{op_name.str(), device_name.str()}; 118 mutex_lock l(cache_mu_); 119 return cache_.find(cache_key) != cache_.end(); 120 } 121 122 private: 123 // Note: Currently the key is a pair of op_name and device_name. New features 124 // may be added in the future. 125 struct CacheKey { 126 std::string op_name, device_name; 127 128 bool operator==(const CacheKey& other) const { 129 return (this->op_name == other.op_name && 130 this->device_name == other.device_name); 131 } 132 }; 133 134 struct CacheKeyHash { operatorCacheKeyHash135 size_t operator()(const CacheKey& pair) const { 136 return std::hash<std::string>()(pair.op_name) ^ 137 std::hash<std::string>()(pair.device_name); 138 } 139 }; 140 141 mutable mutex cache_mu_; 142 std::unordered_map<CacheKey, RCReference<FunctionState>, CacheKeyHash> cache_ 143 TFRT_GUARDED_BY(cache_mu_); 144 }; 145 146 } // namespace tf 147 } // namespace tfrt 148 149 #endif // TENSORFLOW_CORE_TFRT_EAGER_FUNCTION_CACHE_H_ 150