1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_TFRT_EAGER_FUNCTION_CACHE_H_ 16 #define TENSORFLOW_CORE_TFRT_EAGER_FUNCTION_CACHE_H_ 17 18 #include "tensorflow/compiler/mlir/tfrt/function/function.h" 19 #include "tensorflow/core/common_runtime/eager/context.h" 20 #include "tensorflow/core/framework/function.h" 21 #include "tensorflow/core/framework/types.h" 22 #include "tensorflow/core/lib/core/errors.h" 23 #include "tensorflow/core/platform/status.h" 24 #include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" 25 #include "tensorflow/core/tfrt/utils/utils.h" 26 #include "tfrt/bef_executor/bef_file.h" // from @tf_runtime 27 #include "tfrt/core_runtime/core_runtime.h" // from @tf_runtime 28 #include "tfrt/core_runtime/core_runtime_op.h" // from @tf_runtime 29 #include "tfrt/core_runtime/op_attrs.h" // from @tf_runtime 30 #include "tfrt/host_context/function.h" // from @tf_runtime 31 #include "tfrt/host_context/host_context.h" // from @tf_runtime 32 #include "tfrt/support/aligned_buffer.h" // from @tf_runtime 33 #include "tfrt/support/error_util.h" // from @tf_runtime 34 #include "tfrt/support/mutex.h" // from @tf_runtime 35 #include "tfrt/support/ref_count.h" // from @tf_runtime 36 #include "tfrt/support/string_util.h" // from @tf_runtime 37 38 namespace tfrt { 39 namespace tf { 40 41 // A reference counted `state` object that contains a BEF file, which represents 42 // a lowered FunctionDef. The CoreRuntimeOp is a callable handle to the function 43 // to be called. 44 class FunctionState : public ReferenceCounted<FunctionState> { 45 public: CreateFunctionState(TfrtDataTypeSlice arg_types,tensorflow::DataTypeSlice ret_types,BefBuffer bef_buffer,RCReference<BEFFile> bef_file,CoreRuntimeOp fn,std::unique_ptr<tensorflow::tfrt_stub::OpKernelRunnerTable> runner_table)46 static RCReference<FunctionState> CreateFunctionState( 47 TfrtDataTypeSlice arg_types, tensorflow::DataTypeSlice ret_types, 48 BefBuffer bef_buffer, RCReference<BEFFile> bef_file, CoreRuntimeOp fn, 49 std::unique_ptr<tensorflow::tfrt_stub::OpKernelRunnerTable> 50 runner_table) { 51 return TakeRef(new FunctionState(arg_types, ret_types, 52 std::move(bef_buffer), std::move(bef_file), 53 std::move(fn), std::move(runner_table))); 54 } 55 GetFunc()56 const CoreRuntimeOp& GetFunc() const { return fn_; } 57 GetArgTypes()58 const TfrtDataTypeVector& GetArgTypes() { return arg_types_; } 59 GetRetTypes()60 const tensorflow::DataTypeVector& GetRetTypes() { return ret_types_; } 61 GetRunnerTable()62 tensorflow::tfrt_stub::OpKernelRunnerTable* GetRunnerTable() { 63 return runner_table_.get(); 64 } 65 66 private: FunctionState(TfrtDataTypeSlice arg_types,tensorflow::DataTypeSlice ret_types,BefBuffer bef_buffer,RCReference<BEFFile> bef_file,CoreRuntimeOp fn,std::unique_ptr<tensorflow::tfrt_stub::OpKernelRunnerTable> runner_table)67 FunctionState( 68 TfrtDataTypeSlice arg_types, tensorflow::DataTypeSlice ret_types, 69 BefBuffer bef_buffer, RCReference<BEFFile> bef_file, CoreRuntimeOp fn, 70 std::unique_ptr<tensorflow::tfrt_stub::OpKernelRunnerTable> runner_table) 71 : arg_types_(arg_types.begin(), arg_types.end()), 72 ret_types_(ret_types.begin(), ret_types.end()), 73 bef_buffer_(std::move(bef_buffer)), 74 bef_file_(std::move(bef_file)), 75 fn_(std::move(fn)), 76 runner_table_(std::move(runner_table)) {} 77 78 TfrtDataTypeVector arg_types_; 79 tensorflow::DataTypeVector ret_types_; 80 BefBuffer bef_buffer_; 81 RCReference<BEFFile> bef_file_; 82 const CoreRuntimeOp fn_; 83 84 // This is the op_kernel cache used by kernel fallback compact mode. We will 85 // initialize this table right after lowering the function. 86 std::unique_ptr<tensorflow::tfrt_stub::OpKernelRunnerTable> runner_table_; 87 }; 88 89 // Cache for a single core runtime op or function (composite op). Thread safe. 90 class FunctionCache { 91 public: 92 // Iterate the cache and erase the op(s) with the specified op_name. 93 void RemoveFunction(string_view op_name) TFRT_EXCLUDES(cache_mu_); 94 95 struct FunctionCacheResult { 96 RCReference<FunctionState> function_state; 97 bool is_cache_miss; 98 }; 99 100 typedef std::function<tensorflow::Status( 101 tensorflow::tfrt_stub::OpKernelRunnerTable*, 102 RCReference<RequestContext>*)> 103 RequestCtxBuilder; 104 105 // Helper function to look up the cache. If miss, insert the function to the 106 // cache. 107 // When the return status is OK, `result` is set. 108 tensorflow::Status GetOrAddFunction( 109 const std::string& op_name, const std::string& device_name, 110 const tensorflow::DeviceSet& device_set, 111 tensorflow::EagerContext* eager_ctx, tfrt::CoreRuntime* corert, 112 RequestCtxBuilder request_ctx_fn, Location loc, 113 tensorflow::TfrtFunctionCompileOptions compile_options, 114 tfrt::ArrayRef<const Device*> input_devices, FunctionCacheResult* result); 115 116 // The following helper functions are for debugging and testing only. Size()117 size_t Size() const { 118 mutex_lock l(cache_mu_); 119 return cache_.size(); 120 } 121 Contains(string_view op_name,string_view device_name)122 bool Contains(string_view op_name, string_view device_name) const { 123 const CacheKey& cache_key{op_name.str(), device_name.str()}; 124 mutex_lock l(cache_mu_); 125 return cache_.find(cache_key) != cache_.end(); 126 } 127 128 private: 129 // Note: Currently the key is a pair of op_name and device_name. New features 130 // may be added in the future. 131 struct CacheKey { 132 std::string op_name, device_name; 133 134 bool operator==(const CacheKey& other) const { 135 return (this->op_name == other.op_name && 136 this->device_name == other.device_name); 137 } 138 }; 139 140 struct CacheKeyHash { operatorCacheKeyHash141 size_t operator()(const CacheKey& pair) const { 142 return std::hash<std::string>()(pair.op_name) ^ 143 std::hash<std::string>()(pair.device_name); 144 } 145 }; 146 147 mutable mutex cache_mu_; 148 std::unordered_map<CacheKey, RCReference<FunctionState>, CacheKeyHash> cache_ 149 TFRT_GUARDED_BY(cache_mu_); 150 }; 151 152 } // namespace tf 153 } // namespace tfrt 154 155 #endif // TENSORFLOW_CORE_TFRT_EAGER_FUNCTION_CACHE_H_ 156