/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_COMPAT_REQUEST_STATE_H__
#define TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_COMPAT_REQUEST_STATE_H__

#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tensorflow/core/tfrt/utils/model_metadata.h"
#include "tfrt/host_context/async_value.h"  // from @tf_runtime
#include "tfrt/support/pointer_util.h"  // from @tf_runtime

namespace tensorflow {
namespace tfd {

class OpKernelRunnerCache;
class OpKernelRunnerTable;

// FallbackResourceArray holds tensors that are computed only once during
// initialization and are read-only afterwards.
class FallbackResourceArray {
 public:
  // Sets `tensor` in the array at `index`. Indices should be dense, and
  // duplicate indices are not allowed.
  void SetResource(int index, tensorflow::tfrt_stub::ImmutableTensor tensor);

  // Returns the resource tensor at `index`, wrapped in an AsyncValue.
  tfrt::UnRefCountedAsyncValue<tensorflow::tfrt_stub::FallbackTensor>*
  GetResource(int index) const {
    return resource_async_values_.at(index).get();
  }

 private:
  // `resources_` holds the ownership of all the resource tensors. Note that
  // there may not be a one-to-one mapping between `resources_` and
  // `resource_async_values_`.
  std::vector<std::unique_ptr<tensorflow::tfrt_stub::ImmutableTensor>>
      resources_;
  // `resource_async_values_` holds the UnRefCountedAsyncValues of the fallback
  // tensors that can be directly used by fallback kernels in the graph.
  std::vector<std::unique_ptr<
      tfrt::UnRefCountedAsyncValue<tensorflow::tfrt_stub::FallbackTensor>>>
      resource_async_values_;
};
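
// A minimal usage sketch (illustrative only; `resources` and `init_tensor`
// are assumed names, not part of this header). In practice the array is
// populated and read through the tfrt_fallback_async.set_resource and
// tfrt_fallback_async.get_resource kernels rather than being called directly:
//
//   FallbackResourceArray resources;
//   // Store a tensor produced during initialization at a dense index.
//   resources.SetResource(/*index=*/0, std::move(init_tensor));
//   // Fallback kernels can later read it back as an AsyncValue pointer.
//   auto* resource_av = resources.GetResource(/*index=*/0);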

// Per-request state in kernel fallback compat mode.
class KernelFallbackCompatRequestState {
 public:
  // NOTE: This is the constructor for training.
  KernelFallbackCompatRequestState(
      const tensorflow::DeviceMgr* device_manager, int64_t step_id,
      tfrt::OwnedOrUnownedPtr<ScopedStepContainer> step_container,
      std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
      core::RefCountPtr<Rendezvous> rendezvous,
      OpKernelRunnerTable* runner_table, FallbackResourceArray* resource_array,
      tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool,
      const absl::optional<tfrt::ModelMetadata>& model_metadata,
      const tensorflow::ProcessFunctionLibraryRuntime* pflr);

  // NOTE: This is the constructor for inference.
  KernelFallbackCompatRequestState(
      const tensorflow::DeviceMgr* device_manager, int64_t step_id,
      OpKernelRunnerTable* runner_table, FallbackResourceArray* resource_array,
      tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool,
      const absl::optional<tfrt::ModelMetadata>& model_metadata,
      const tensorflow::ProcessFunctionLibraryRuntime* pflr);

  // Returns the user-specified custom device for this request. It is currently
  // only used to configure the per-request intra-op threadpool.
  tensorflow::Device* custom_device() const { return custom_device_.get(); }

  ScopedStepContainer* step_container() const { return step_container_.get(); }

  const tensorflow::DeviceMgr& device_manager() const {
    return *device_manager_;
  }

  const tensorflow::ProcessFunctionLibraryRuntime&
  process_function_library_runtime() const {
    return *pflr_;
  }

  CollectiveExecutor* collective_executor() const {
    return collective_executor_;
  }

  OpKernelRunnerTable* runner_table() const { return runner_table_; }

  FallbackResourceArray* resource_array() const { return resource_array_; }

  std::function<void(std::function<void()>)>* runner() const {
    return default_runner_;
  }

  CancellationManager* cancellation_manager() const {
    return default_cancellation_manager_;
  }

  RendezvousInterface* rendezvous() const { return rendezvous_.get(); }

  void set_log_device_placement(bool log) { log_device_placement_ = log; }
  bool log_device_placement() const { return log_device_placement_; }

  tensorflow::thread::ThreadPoolInterface* intra_op_threadpool() const {
    return intra_op_threadpool_;
  }

  const SessionMetadata& session_metadata() const { return session_metadata_; }

 private:
  // Below are resources needed by current TensorFlow.
  std::function<void(std::function<void()>)>* default_runner_ = nullptr;
  ::tfrt::OwnedOrUnownedPtr<ScopedStepContainer> step_container_;
  std::unique_ptr<tensorflow::Device> custom_device_;
  std::unique_ptr<CollectiveExecutor::Handle> collective_executor_handle_;
  CollectiveExecutor* collective_executor_ = nullptr;
  core::RefCountPtr<Rendezvous> rendezvous_;
  CancellationManager* default_cancellation_manager_ = nullptr;

  const tensorflow::DeviceMgr* device_manager_ = nullptr;

  // `runner_table_` holds the prepopulated tensorflow::OpKernel instances for
  // kernel fallback compat mode.
  OpKernelRunnerTable* runner_table_ = nullptr;

  // The resource array is used for keeping static values in the runtime. It is
  // accessed through the tfrt_fallback_async.set_resource and
  // tfrt_fallback_async.get_resource kernels.
  FallbackResourceArray* resource_array_ = nullptr;

  tensorflow::thread::ThreadPoolInterface* intra_op_threadpool_;

  // Model metadata used for monitoring and tracing purposes.
  SessionMetadata session_metadata_;

  const tensorflow::ProcessFunctionLibraryRuntime* pflr_;

  bool log_device_placement_ = false;
};
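
// A construction sketch for the inference path (illustrative only; the
// arguments `device_manager`, `runner_table`, `resource_array`, and `pflr`
// are assumed to be owned by the caller and to outlive the request):
//
//   tensorflow::tfd::KernelFallbackCompatRequestState request_state(
//       device_manager, /*step_id=*/0, runner_table, resource_array,
//       /*user_intra_op_threadpool=*/nullptr,
//       /*model_metadata=*/absl::nullopt, pflr);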

}  // namespace tfd
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_COMPAT_REQUEST_STATE_H__