/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_COMPAT_REQUEST_STATE_H__
#define TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_COMPAT_REQUEST_STATE_H__

#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tensorflow/core/tfrt/utils/model_metadata.h"
#include "tfrt/host_context/async_value.h"  // from @tf_runtime
#include "tfrt/support/pointer_util.h"  // from @tf_runtime

namespace tensorflow {
namespace tfd {

class OpKernelRunnerCache;
class OpKernelRunnerTable;

39 // FallbackResourceArray holds the tensors that are computed only once during
40 // initialization and read-only afterwards.
41 class FallbackResourceArray {
42  public:
43   // Set `tensor` in the array at `index`. `index` should be dense and duplicate
44   // indices are not allowed.
45   void SetResource(int index, tensorflow::tfrt_stub::ImmutableTensor tensor);
46 
47   // Get the resource tensor wrapped in AsyncValue value at `index`.
48   tfrt::UnRefCountedAsyncValue<tensorflow::tfrt_stub::FallbackTensor>*
GetResource(int index)49   GetResource(int index) const {
50     return resource_async_values_.at(index).get();
51   }
52 
53  private:
54   // `resources_` holds the ownership of all the resource tensors. Note that it
55   // may not be a one-to-one mapping between `resources_` and
56   // `resource_async_values_`.
57   std::vector<std::unique_ptr<tensorflow::tfrt_stub::ImmutableTensor>>
58       resources_;
59   // `resource_async_values_` holds the UnRefCountedAsyncValue of the fallback
60   // tensors that can be directly used by fallback kernels in the graph.
61   std::vector<std::unique_ptr<
62       tfrt::UnRefCountedAsyncValue<tensorflow::tfrt_stub::FallbackTensor>>>
63       resource_async_values_;
64 };
65 
66 // Per-request state in kernel falllback compat mode.
67 class KernelFallbackCompatRequestState {
68  public:
69   // NOTE: This is the constructor for training.
70   KernelFallbackCompatRequestState(
71       const tensorflow::DeviceMgr* device_manager, int64_t step_id,
72       tfrt::OwnedOrUnownedPtr<ScopedStepContainer> step_container,
73       std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
74       core::RefCountPtr<Rendezvous> rendezvous,
75       OpKernelRunnerTable* runner_table, FallbackResourceArray* resource_array,
76       tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool,
77       const absl::optional<tfrt::ModelMetadata>& model_metadata,
78       const tensorflow::ProcessFunctionLibraryRuntime* pflr);
79 
80   // NOTE: This is the constructor for inference.
81   KernelFallbackCompatRequestState(
82       const tensorflow::DeviceMgr* device_manager, int64_t step_id,
83       OpKernelRunnerTable* runner_table, FallbackResourceArray* resource_array,
84       tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool,
85       const absl::optional<tfrt::ModelMetadata>& model_metadata,
86       const tensorflow::ProcessFunctionLibraryRuntime* pflr);
87 
88   // Returns the user-specified custom device for this request. It is currently
89   // only used for configure per-request intra op threadpool.
custom_device()90   tensorflow::Device* custom_device() const { return custom_device_.get(); }
91 
step_container()92   ScopedStepContainer* step_container() const { return step_container_.get(); }
93 
device_manager()94   const tensorflow::DeviceMgr& device_manager() const {
95     return *device_manager_;
96   }
97 
98   const tensorflow::ProcessFunctionLibraryRuntime&
process_function_library_runtime()99   process_function_library_runtime() const {
100     return *pflr_;
101   }
102 
collective_executor()103   CollectiveExecutor* collective_executor() const {
104     return collective_executor_;
105   }
106 
runner_table()107   OpKernelRunnerTable* runner_table() const { return runner_table_; }
108 
resource_array()109   FallbackResourceArray* resource_array() const { return resource_array_; }
110 
runner()111   std::function<void(std::function<void()>)>* runner() const {
112     return default_runner_;
113   }
114 
cancellation_manager()115   CancellationManager* cancellation_manager() const {
116     return default_cancellation_manager_;
117   }
118 
rendezvous()119   RendezvousInterface* rendezvous() const { return rendezvous_.get(); }
120 
set_log_device_placement(bool log)121   void set_log_device_placement(bool log) { log_device_placement_ = log; }
log_device_placement()122   bool log_device_placement() const { return log_device_placement_; }
123 
intra_op_threadpool()124   tensorflow::thread::ThreadPoolInterface* intra_op_threadpool() const {
125     return intra_op_threadpool_;
126   }
127 
session_metadata()128   const SessionMetadata& session_metadata() const { return session_metadata_; }
129 
130  private:
131   // Below are resources needed by current tensorflow.
132   std::function<void(std::function<void()>)>* default_runner_ = nullptr;
133   ::tfrt::OwnedOrUnownedPtr<ScopedStepContainer> step_container_;
134   std::unique_ptr<tensorflow::Device> custom_device_;
135   std::unique_ptr<CollectiveExecutor::Handle> collective_executor_handle_;
136   CollectiveExecutor* collective_executor_ = nullptr;
137   core::RefCountPtr<Rendezvous> rendezvous_;
138   CancellationManager* default_cancellation_manager_ = nullptr;
139 
140   const tensorflow::DeviceMgr* device_manager_ = nullptr;
141 
142   // `runner_table` holds the prepopulated tensorflow::OpKernel instances for
143   // kernel fallback compat mode.
144   OpKernelRunnerTable* runner_table_ = nullptr;
145 
146   // Resource array is used for keeping static values in the runtime. It is
147   // accessed through tfrt_fallback_async.set_resource and
148   // tfrt_fallback_async.get_resource kernels.
149   FallbackResourceArray* resource_array_ = nullptr;
150 
151   tensorflow::thread::ThreadPoolInterface* intra_op_threadpool_;
152 
153   // Model metadata used for monitoring and tracing purpose.
154   SessionMetadata session_metadata_;
155 
156   const tensorflow::ProcessFunctionLibraryRuntime* pflr_;
157 
158   bool log_device_placement_ = false;
159 };

}  // namespace tfd
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_COMPAT_REQUEST_STATE_H__