1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_OP_KERNEL_RUNNER_H_ 16 #define TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_OP_KERNEL_RUNNER_H_ 17 18 #include <assert.h> 19 #include <stddef.h> 20 21 #include <memory> 22 #include <string> 23 #include <type_traits> 24 #include <utility> 25 26 #include "absl/container/flat_hash_map.h" 27 #include "absl/container/inlined_vector.h" 28 #include "absl/meta/type_traits.h" 29 #include "absl/types/span.h" 30 #include "llvm/ADT/SmallVector.h" 31 #include "llvm/ADT/StringRef.h" 32 #include "tensorflow/core/common_runtime/eager/attr_builder.h" 33 #include "tensorflow/core/framework/allocator.h" 34 #include "tensorflow/core/framework/cancellation.h" 35 #include "tensorflow/core/framework/device.h" 36 #include "tensorflow/core/framework/function.h" 37 #include "tensorflow/core/framework/node_def.pb.h" 38 #include "tensorflow/core/framework/node_properties.h" 39 #include "tensorflow/core/framework/op_def.pb.h" 40 #include "tensorflow/core/framework/op_kernel.h" 41 #include "tensorflow/core/framework/tensor.h" 42 #include "tensorflow/core/framework/types.h" 43 #include "tensorflow/core/platform/errors.h" 44 #include "tensorflow/core/platform/mutex.h" 45 #include "tensorflow/core/platform/status.h" 46 #include "tensorflow/core/platform/thread_annotations.h" 47 #include "tensorflow/core/platform/types.h" 48 #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h" 49 #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_tensor.h" 50 #include "tensorflow/core/runtime_fallback/util/attr_util.h" 51 #include "tensorflow/core/tfrt/utils/statusor.h" 52 #include "tfrt/core_runtime/op_attrs.h" // from @tf_runtime 53 #include "tfrt/host_context/async_dispatch.h" // from @tf_runtime 54 #include "tfrt/host_context/async_value_ref.h" // from @tf_runtime 55 #include "tfrt/host_context/chain.h" // from @tf_runtime 56 #include "tfrt/host_context/execution_context.h" // from @tf_runtime 57 #include "tfrt/host_context/host_context.h" // from @tf_runtime 58 #include "tfrt/host_context/sync_kernel_frame.h" // from @tf_runtime 59 #include "tfrt/support/error_util.h" // from @tf_runtime 60 #include "tfrt/support/forward_decls.h" // from @tf_runtime 61 #include "tfrt/tensor/tensor.h" // from @tf_runtime 62 63 namespace tensorflow { 64 namespace tfd { 65 66 class OpKernelRunner { 67 public: 68 static tfrt::StatusOr<OpKernelRunner> Create( 69 absl::string_view op_name, absl::string_view device_name, int num_args, 70 const std::function<llvm::Error(tensorflow::AttrValueMap*)>& attr_builder, 71 const KernelFallbackCompatRequestState& fallback_request_state); 72 Run(OpKernelContext * context)73 void Run(OpKernelContext* context) const { 74 DVLOG(1) << "KernelFallbackExecuteCompat Running Op: " 75 << op_kernel_->def().DebugString() 76 << ", on Device: " << device_->name(); 77 78 op_kernel_->Compute(context); 79 } 80 81 void RunAsync(OpKernelContext* context, 82 AsyncOpKernel::DoneCallback done_callback) const; 83 IsAsync()84 bool IsAsync() const { return is_async_; } 85 op_kernel()86 tensorflow::OpKernel* op_kernel() const { return op_kernel_.get(); } device()87 tensorflow::Device* device() const { return device_; } function_library_runtime()88 tensorflow::FunctionLibraryRuntime* function_library_runtime() const { 89 return function_library_runtime_; 90 } resource_manager()91 tensorflow::ResourceMgr* resource_manager() const { 92 return resource_manager_; 93 } 94 input_alloc_attrs()95 const gtl::InlinedVector<AllocatorAttributes, 4>& input_alloc_attrs() const { 96 return input_alloc_attrs_; 97 } output_alloc_attrs()98 const gtl::InlinedVector<AllocatorAttributes, 1>& output_alloc_attrs() const { 99 return output_alloc_attrs_; 100 } 101 102 private: 103 explicit OpKernelRunner( 104 tensorflow::Device* device, 105 tensorflow::FunctionLibraryRuntime* function_library_runtime, 106 std::unique_ptr<OpKernel> op_kernel); 107 108 tensorflow::Device* device_ = nullptr; 109 tensorflow::FunctionLibraryRuntime* function_library_runtime_ = nullptr; 110 tensorflow::ResourceMgr* resource_manager_ = nullptr; 111 std::unique_ptr<OpKernel> op_kernel_; 112 bool is_async_ = false; 113 gtl::InlinedVector<AllocatorAttributes, 4> input_alloc_attrs_; 114 gtl::InlinedVector<AllocatorAttributes, 1> output_alloc_attrs_; 115 }; 116 117 class OpLocationKey { 118 public: OpLocationKey(tfrt::Location loc)119 explicit OpLocationKey(tfrt::Location loc) : loc_(loc) {} 120 121 template <typename H> AbslHashValue(H h,const OpLocationKey & key)122 friend H AbslHashValue(H h, const OpLocationKey& key) { 123 // NOTE: Each BEF file has its own LocationHandler. Using LocationHandler 124 // as part of cache key here can avoid cache collision between different 125 // BEF file. 126 return H::combine(std::move(h), key.loc_.data, key.loc_.GetHandler()); 127 } 128 129 friend bool operator==(const OpLocationKey& x, const OpLocationKey& y) { 130 return x.loc_.data == y.loc_.data && 131 x.loc_.GetHandler() == y.loc_.GetHandler(); 132 } 133 134 private: 135 tfrt::Location loc_; 136 }; 137 138 // OpKernelRunnerTable for keeping OpKernelRunner instances to avoid expensive 139 // reinstantiation of OpKernel and other repeated setup per kernel execution. 140 // OpKernelRunnerTable is thread-compatible. 141 class OpKernelRunnerTable { 142 public: 143 OpKernelRunnerTable() = default; 144 145 // Return true if it successfully inserts `runner`. `index` is supposed to be 146 // dense. Insert(int64_t index,OpKernelRunner runner)147 bool Insert(int64_t index, OpKernelRunner runner) { 148 if (runners_.size() <= index) runners_.resize(index + 1); 149 if (runners_[index].has_value()) return false; 150 runners_[index] = std::move(runner); 151 return true; 152 } 153 154 // Return the OpKernelRunner at the corresponding `index` in the table. The 155 // result can never be nullptr. It is a fatal error to use an index that is 156 // not in the table. Note that the returned pointer will be invalidated if 157 // Insert() is called. Get(int64_t index)158 const OpKernelRunner* Get(int64_t index) const { 159 DCHECK_GT(runners_.size(), index); 160 auto& result = runners_.at(index); 161 DCHECK(result.has_value()); 162 return &(*result); 163 } 164 165 private: 166 std::vector<absl::optional<OpKernelRunner>> runners_; 167 }; 168 169 // OpKernelRunnerCache is similar to OpKernelRunnerTable but thread-safe. 170 class OpKernelRunnerCache { 171 public: 172 OpKernelRunnerCache(); 173 174 tfrt::StatusOr<OpKernelRunner*> GetOrCreate( 175 tfrt::Location loc, absl::string_view op_name, 176 absl::string_view device_name, int num_args, 177 const std::function<llvm::Error(tensorflow::AttrValueMap*)>& attr_builder, 178 const KernelFallbackCompatRequestState& fallback_request_state); 179 180 private: 181 mutable mutex mu_; 182 absl::flat_hash_map<OpLocationKey, std::unique_ptr<OpKernelRunner>> map_ 183 TF_GUARDED_BY(mu_); 184 }; 185 186 } // namespace tfd 187 } // namespace tensorflow 188 189 #endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_OP_KERNEL_RUNNER_H_ 190