• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h"
16 
17 #include <functional>
18 #include <memory>
19 #include <optional>
20 #include <utility>
21 
22 #include "tensorflow/core/common_runtime/eager/context.h"
23 #include "tensorflow/core/common_runtime/renamed_device.h"
24 #include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
25 #include "tensorflow/core/framework/device.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/resource_mgr.h"
28 #include "tensorflow/core/platform/threadpool_interface.h"
29 #include "tensorflow/core/platform/types.h"
30 #include "tensorflow/core/tfrt/utils/fallback_tensor.h"
31 #include "tfrt/support/pointer_util.h"  // from @tf_runtime
32 
33 namespace tensorflow {
34 namespace tfd {
35 
SetResource(int index,tensorflow::tfrt_stub::ImmutableTensor tensor)36 void FallbackResourceArray::SetResource(
37     int index, tensorflow::tfrt_stub::ImmutableTensor tensor) {
38   if (resource_async_values_.size() <= index) {
39     resource_async_values_.resize(index + 1);
40   }
41 
42   DCHECK(!resource_async_values_[index]);
43 
44   resources_.push_back(std::make_unique<tensorflow::tfrt_stub::ImmutableTensor>(
45       std::move(tensor)));
46 
47   resource_async_values_[index] = std::make_unique<
48       tfrt::UnRefCountedAsyncValue<tensorflow::tfrt_stub::FallbackTensor>>(
49       resources_.back().get());
50 }
51 
// Returns a pointer to a single shared runner that executes the closure it is
// given immediately, on the calling thread. The runner is created on first use
// and the same pointer is returned on every call; it is not deleted here.
static std::function<void(std::function<void()>)>* GetDefaultRunner() {
  using Closure = std::function<void()>;
  using Runner = std::function<void(Closure)>;

  static Runner* const inline_runner =
      new Runner([](const Closure& work) { work(); });
  return inline_runner;
}
58 
GetDefaultCancellationManager()59 static CancellationManager* GetDefaultCancellationManager() {
60   // TODO(b/167630926): Support cancellation by hooking up with TFRT's
61   // mechanism.
62   static auto* const default_cancellation_manager = new CancellationManager;
63   return default_cancellation_manager;
64 }
65 
KernelFallbackCompatRequestState(const tensorflow::DeviceMgr * device_manager,int64_t step_id,tfrt::OwnedOrUnownedPtr<ScopedStepContainer> step_container,std::unique_ptr<CollectiveExecutor::Handle> collective_executor_handle,core::RefCountPtr<Rendezvous> rendezvous,OpKernelRunnerTable * runner_table,FallbackResourceArray * resource_array,tensorflow::thread::ThreadPoolInterface * user_intra_op_threadpool,const absl::optional<tfrt::ModelMetadata> & model_metadata,const tensorflow::ProcessFunctionLibraryRuntime * pflr)66 KernelFallbackCompatRequestState::KernelFallbackCompatRequestState(
67     const tensorflow::DeviceMgr* device_manager, int64_t step_id,
68     tfrt::OwnedOrUnownedPtr<ScopedStepContainer> step_container,
69     std::unique_ptr<CollectiveExecutor::Handle> collective_executor_handle,
70     core::RefCountPtr<Rendezvous> rendezvous, OpKernelRunnerTable* runner_table,
71     FallbackResourceArray* resource_array,
72     tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool,
73     const absl::optional<tfrt::ModelMetadata>& model_metadata,
74     const tensorflow::ProcessFunctionLibraryRuntime* pflr)
75     : default_runner_(GetDefaultRunner()),
76       step_container_(std::move(step_container)),
77       collective_executor_handle_(std::move(collective_executor_handle)),
78       collective_executor_(collective_executor_handle_
79                                ? collective_executor_handle_->get()
80                                : nullptr),
81       rendezvous_(std::move(rendezvous)),
82       default_cancellation_manager_(GetDefaultCancellationManager()),
83       device_manager_(device_manager),
84       runner_table_(runner_table),
85       resource_array_(resource_array),
86       intra_op_threadpool_(user_intra_op_threadpool),
87       pflr_(pflr) {
88   DCHECK(device_manager_);
89   DCHECK(runner_table_);
90   DCHECK(resource_array_);
91   DCHECK(rendezvous_);
92 
93   // TODO(tfrt-devs): Support customizing non-CPU devices.
94   auto* device = device_manager_->HostCPU();
95   if (user_intra_op_threadpool != nullptr) {
96     custom_device_ = tensorflow::RenamedDevice::NewRenamedDevice(
97         device->name(), device, /*owns_underlying=*/false,
98         /*isolate_session_state=*/false, user_intra_op_threadpool);
99   }
100   if (model_metadata.has_value()) {
101     session_metadata_.set_name(model_metadata.value().name);
102     session_metadata_.set_version(model_metadata.value().version);
103   }
104 }
105 
KernelFallbackCompatRequestState(const tensorflow::DeviceMgr * device_manager,int64_t step_id,OpKernelRunnerTable * runner_table,FallbackResourceArray * resource_array,tensorflow::thread::ThreadPoolInterface * user_intra_op_threadpool,const absl::optional<tfrt::ModelMetadata> & model_metadata,const tensorflow::ProcessFunctionLibraryRuntime * pflr)106 KernelFallbackCompatRequestState::KernelFallbackCompatRequestState(
107     const tensorflow::DeviceMgr* device_manager, int64_t step_id,
108     OpKernelRunnerTable* runner_table, FallbackResourceArray* resource_array,
109     tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool,
110     const absl::optional<tfrt::ModelMetadata>& model_metadata,
111     const tensorflow::ProcessFunctionLibraryRuntime* pflr)
112     : KernelFallbackCompatRequestState(
113           device_manager, step_id,
114           // The following code is copied from
115           // third_party/tensorflow/core/common_runtime/direct_session.cc
116           tfrt::OwnedOrUnownedPtr<ScopedStepContainer>{
117               std::make_unique<ScopedStepContainer>(
118                   step_id,
119                   [step_id, device_manager](const std::string& name) {
120                     for (tensorflow::Device* device :
121                          device_manager->ListDevices()) {
122                       auto status = device->resource_manager()->Cleanup(name);
123                       (void)status;
124                       tensorflow::ScopedAllocatorMgr* sam =
125                           device->GetScopedAllocatorMgr();
126                       if (sam) sam->Cleanup(step_id);
127                     }
128                   })},
129           /*collective_executor=*/nullptr,
130           /*rendezvous=*/
131           core::RefCountPtr<RefCountedIntraProcessRendezvous>(
132               new RefCountedIntraProcessRendezvous(device_manager)),
133           runner_table, resource_array, user_intra_op_threadpool,
134           model_metadata, pflr) {}
135 
136 }  // namespace tfd
137 }  // namespace tensorflow
138