/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h"

#include <functional>
#include <memory>
#include <optional>
#include <utility>

#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tfrt/support/pointer_util.h"  // from @tf_runtime

namespace tensorflow {
namespace tfd {

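// Stores `tensor` at position `index`: the ImmutableTensor itself is kept
// alive in `resources_`, and a non-ref-counted AsyncValue wrapping it is
// cached in `resource_async_values_` for later lookup.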
void FallbackResourceArray::SetResource(
    int index, tensorflow::tfrt_stub::ImmutableTensor tensor) {
  if (resource_async_values_.size() <= index) {
    resource_async_values_.resize(index + 1);
  }

  DCHECK(!resource_async_values_[index]);

  resources_.push_back(std::make_unique<tensorflow::tfrt_stub::ImmutableTensor>(
      std::move(tensor)));

  resource_async_values_[index] = std::make_unique<
      tfrt::UnRefCountedAsyncValue<tensorflow::tfrt_stub::FallbackTensor>>(
      resources_.back().get());
}

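// Returns a process-wide default runner that simply executes the given
// closure inline on the calling thread.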
static std::function<void(std::function<void()>)>* GetDefaultRunner() {
  static auto* const default_runner =
      new std::function<void(std::function<void()>)>(
          [](const std::function<void()>& f) { f(); });
  return default_runner;
}

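// Returns a lazily created, process-wide default CancellationManager.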
static CancellationManager* GetDefaultCancellationManager() {
  // TODO(b/167630926): Support cancellation by hooking up with TFRT's
  // mechanism.
  static auto* const default_cancellation_manager = new CancellationManager;
  return default_cancellation_manager;
}

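// Primary constructor: the caller supplies the step container, collective
// executor handle, and rendezvous explicitly.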
KernelFallbackCompatRequestState::KernelFallbackCompatRequestState(
    const tensorflow::DeviceMgr* device_manager, int64_t step_id,
    tfrt::OwnedOrUnownedPtr<ScopedStepContainer> step_container,
    std::unique_ptr<CollectiveExecutor::Handle> collective_executor_handle,
    core::RefCountPtr<Rendezvous> rendezvous, OpKernelRunnerTable* runner_table,
    FallbackResourceArray* resource_array,
    tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool,
    const absl::optional<tfrt::ModelMetadata>& model_metadata,
    const tensorflow::ProcessFunctionLibraryRuntime* pflr)
    : default_runner_(GetDefaultRunner()),
      step_container_(std::move(step_container)),
      collective_executor_handle_(std::move(collective_executor_handle)),
      collective_executor_(collective_executor_handle_
                               ? collective_executor_handle_->get()
                               : nullptr),
      rendezvous_(std::move(rendezvous)),
      default_cancellation_manager_(GetDefaultCancellationManager()),
      device_manager_(device_manager),
      runner_table_(runner_table),
      resource_array_(resource_array),
      intra_op_threadpool_(user_intra_op_threadpool),
      pflr_(pflr) {
  DCHECK(device_manager_);
  DCHECK(runner_table_);
  DCHECK(resource_array_);
  DCHECK(rendezvous_);

  // TODO(tfrt-devs): Support customizing non-CPU devices.
  auto* device = device_manager_->HostCPU();
  if (user_intra_op_threadpool != nullptr) {
    custom_device_ = tensorflow::RenamedDevice::NewRenamedDevice(
        device->name(), device, /*owns_underlying=*/false,
        /*isolate_session_state=*/false, user_intra_op_threadpool);
  }
  if (model_metadata.has_value()) {
    session_metadata_.set_name(model_metadata.value().name);
    session_metadata_.set_version(model_metadata.value().version);
  }
}

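// Convenience overload: builds a per-step ScopedStepContainer whose cleanup
// callback clears each device's resource manager and ScopedAllocatorMgr, plus
// an intra-process Rendezvous, then delegates to the constructor above.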
KernelFallbackCompatRequestState::KernelFallbackCompatRequestState(
    const tensorflow::DeviceMgr* device_manager, int64_t step_id,
    OpKernelRunnerTable* runner_table, FallbackResourceArray* resource_array,
    tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool,
    const absl::optional<tfrt::ModelMetadata>& model_metadata,
    const tensorflow::ProcessFunctionLibraryRuntime* pflr)
    : KernelFallbackCompatRequestState(
          device_manager, step_id,
          // The following code is copied from
          // third_party/tensorflow/core/common_runtime/direct_session.cc
          tfrt::OwnedOrUnownedPtr<ScopedStepContainer>{
              std::make_unique<ScopedStepContainer>(
                  step_id,
                  [step_id, device_manager](const std::string& name) {
                    for (tensorflow::Device* device :
                         device_manager->ListDevices()) {
                      auto status = device->resource_manager()->Cleanup(name);
                      (void)status;
                      tensorflow::ScopedAllocatorMgr* sam =
                          device->GetScopedAllocatorMgr();
                      if (sam) sam->Cleanup(step_id);
                    }
                  })},
          /*collective_executor=*/nullptr,
          /*rendezvous=*/
          core::RefCountPtr<RefCountedIntraProcessRendezvous>(
              new RefCountedIntraProcessRendezvous(device_manager)),
          runner_table, resource_array, user_intra_op_threadpool,
          model_metadata, pflr) {}

}  // namespace tfd
}  // namespace tensorflow