/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_OP_KERNEL_RUNNER_H_
#define TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_OP_KERNEL_RUNNER_H_

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/meta/type_traits.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_properties.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_tensor.h"
#include "tensorflow/core/runtime_fallback/util/attr_util.h"
#include "tensorflow/core/tfrt/utils/statusor.h"
#include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
#include "tfrt/host_context/async_dispatch.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/chain.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/host_context/sync_kernel_frame.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/tensor/tensor.h"  // from @tf_runtime

namespace tensorflow {
namespace tfd {

// OpKernelRunner bundles an instantiated tensorflow::OpKernel with the
// device and function library it was created against, so the kernel can be
// executed repeatedly without repeating the expensive instantiation path.
// Copyable/movable value type; only Create() can build a populated instance.
class OpKernelRunner {
 public:
  // Instantiates the kernel for op `op_name` on device `device_name` with
  // `num_args` inputs. `attr_builder` is called to fill in the op's
  // AttrValueMap; `fallback_request_state` supplies the per-request state
  // (devices, function library, etc.) needed for instantiation. Returns an
  // error on failure.
  static tfrt::StatusOr<OpKernelRunner> Create(
      absl::string_view op_name, absl::string_view device_name, int num_args,
      const std::function<llvm::Error(tensorflow::AttrValueMap*)>& attr_builder,
      const KernelFallbackCompatRequestState& fallback_request_state);

  // Synchronously runs the wrapped kernel via OpKernel::Compute(). The
  // caller owns `context` and must have set up its inputs and params.
  // NOTE(review): presumably only valid when !IsAsync() — confirm with the
  // call sites.
  void Run(OpKernelContext* context) const {
    DVLOG(1) << "KernelFallbackExecuteCompat Running Op: "
             << op_kernel_->def().DebugString()
             << ", on Device: " << device_->name();

    op_kernel_->Compute(context);
  }

  // Runs the kernel asynchronously; `done_callback` fires on completion.
  // Defined in the .cc file.
  void RunAsync(OpKernelContext* context,
                AsyncOpKernel::DoneCallback done_callback) const;

  // True if the wrapped kernel is asynchronous (use RunAsync, not Run).
  bool IsAsync() const { return is_async_; }

  // Accessors for state captured at Create() time. All returned pointers
  // are non-owning except op_kernel_, which this runner owns.
  tensorflow::OpKernel* op_kernel() const { return op_kernel_.get(); }
  tensorflow::Device* device() const { return device_; }
  tensorflow::FunctionLibraryRuntime* function_library_runtime() const {
    return function_library_runtime_;
  }
  tensorflow::ResourceMgr* resource_manager() const {
    return resource_manager_;
  }

  // Allocator attributes to use when wiring this kernel's inputs/outputs
  // into an OpKernelContext.
  const gtl::InlinedVector<AllocatorAttributes, 4>& input_alloc_attrs() const {
    return input_alloc_attrs_;
  }
  const gtl::InlinedVector<AllocatorAttributes, 1>& output_alloc_attrs() const {
    return output_alloc_attrs_;
  }

 private:
  // Used by Create(); remaining members are filled in there / in the .cc.
  explicit OpKernelRunner(
      tensorflow::Device* device,
      tensorflow::FunctionLibraryRuntime* function_library_runtime,
      std::unique_ptr<OpKernel> op_kernel);

  tensorflow::Device* device_ = nullptr;
  tensorflow::FunctionLibraryRuntime* function_library_runtime_ = nullptr;
  tensorflow::ResourceMgr* resource_manager_ = nullptr;
  std::unique_ptr<OpKernel> op_kernel_;
  bool is_async_ = false;
  gtl::InlinedVector<AllocatorAttributes, 4> input_alloc_attrs_;
  gtl::InlinedVector<AllocatorAttributes, 1> output_alloc_attrs_;
};
// Hashable/equality-comparable key wrapping a tfrt::Location, for use as
// the key type of an absl::flat_hash_map (see OpKernelRunnerCache).
class OpLocationKey {
 public:
  explicit OpLocationKey(tfrt::Location loc) : loc_(loc) {}

  template <typename H>
  friend H AbslHashValue(H h, const OpLocationKey& key) {
    // NOTE: Each BEF file has its own LocationHandler. Using LocationHandler
    // as part of cache key here can avoid cache collision between different
    // BEF file.
    return H::combine(std::move(h), key.loc_.data, key.loc_.GetHandler());
  }

  // Equality mirrors AbslHashValue above: both the opaque location data and
  // the owning LocationHandler must match.
  friend bool operator==(const OpLocationKey& x, const OpLocationKey& y) {
    return x.loc_.data == y.loc_.data &&
           x.loc_.GetHandler() == y.loc_.GetHandler();
  }

 private:
  tfrt::Location loc_;
};

// OpKernelRunnerTable for keeping OpKernelRunner instances to avoid expensive
// reinstantiation of OpKernel and other repeated setup per kernel execution.
// OpKernelRunnerTable is thread-compatible.
141 class OpKernelRunnerTable {
142  public:
143   OpKernelRunnerTable() = default;
144 
145   // Return true if it successfully inserts `runner`. `index` is supposed to be
146   // dense.
Insert(int64_t index,OpKernelRunner runner)147   bool Insert(int64_t index, OpKernelRunner runner) {
148     if (runners_.size() <= index) runners_.resize(index + 1);
149     if (runners_[index].has_value()) return false;
150     runners_[index] = std::move(runner);
151     return true;
152   }
153 
154   // Return the OpKernelRunner at the corresponding `index` in the table. The
155   // result can never be nullptr. It is a fatal error to use an index that is
156   // not in the table. Note that the returned pointer will be invalidated if
157   // Insert() is called.
Get(int64_t index)158   const OpKernelRunner* Get(int64_t index) const {
159     DCHECK_GT(runners_.size(), index);
160     auto& result = runners_.at(index);
161     DCHECK(result.has_value());
162     return &(*result);
163   }
164 
165  private:
166   std::vector<absl::optional<OpKernelRunner>> runners_;
167 };

// OpKernelRunnerCache is similar to OpKernelRunnerTable but thread-safe.
class OpKernelRunnerCache {
 public:
  OpKernelRunnerCache();

  // Returns the cached runner keyed by `loc`, creating and caching a new
  // one on a miss — presumably via OpKernelRunner::Create with the
  // remaining arguments (implemented in the .cc; confirm there). Returns
  // an error status if creation fails.
  tfrt::StatusOr<OpKernelRunner*> GetOrCreate(
      tfrt::Location loc, absl::string_view op_name,
      absl::string_view device_name, int num_args,
      const std::function<llvm::Error(tensorflow::AttrValueMap*)>& attr_builder,
      const KernelFallbackCompatRequestState& fallback_request_state);

 private:
  mutable mutex mu_;
  // Values are held by unique_ptr so the OpKernelRunner* handed out by
  // GetOrCreate() stays stable across map rehashes.
  absl::flat_hash_map<OpLocationKey, std::unique_ptr<OpKernelRunner>> map_
      TF_GUARDED_BY(mu_);
};

}  // namespace tfd
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_OP_KERNEL_RUNNER_H_