/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file declares utilities for runtime fallback kernels.

#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_UTILS_H_
#define TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_UTILS_H_

#include <cassert>
#include <cstring>
#include <memory>
#include <vector>

#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tfrt/core_runtime/core_runtime_op.h"  // from @tf_runtime
#include "tfrt/dtype/dtype.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_shape.h"  // from @tf_runtime
namespace tensorflow {
namespace tfd {

// Deleter for std::unique_ptr that releases the object via its Release()
// method rather than deleting it.
template <typename T>
struct AutoReleaser {
  void operator()(T* p) const { p->Release(); }
};
template <typename T>
using AutoReleasePtr = std::unique_ptr<T, AutoReleaser<T>>;

using OwnedEagerContext = AutoReleasePtr<EagerContext>;
using OwnedEagerOperation = AutoReleasePtr<EagerOperation>;
using OwnedTensorHandle = AutoReleasePtr<TensorHandle>;
using OwnedAbstractTensorInterface = AutoReleasePtr<AbstractTensorInterface>;
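
// Example usage (a sketch; the `tensor` variable is illustrative): wrapping a
// raw TensorHandle* in an OwnedTensorHandle ensures Release() is called when
// the pointer goes out of scope.
//
//   OwnedTensorHandle handle{
//       tensorflow::TensorHandle::CreateLocalHandle(std::move(tensor))};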

// Checks whether a TensorHandle physically resides on a GPU.
inline bool IsGpuTensorHandle(const tensorflow::TensorHandle& handle) {
  tensorflow::Status dummy_status;
  // BackingDeviceName is the device where the tensor is physically located,
  // not the device of the op that produced the tensor.
  // Note that dummy_status is never set in TensorHandle::BackingDeviceName.
  absl::string_view device_name = handle.BackingDeviceName(&dummy_status);
  return absl::StrContains(device_name, "GPU");
}

// TODO(zhangqiaorjc): Allowlist more dtypes as TFRT GPU support grows.
// A RuntimeFallbackTensor with one of the dtypes below is eagerly converted to
// a tfrt::DenseGpuTensor after each RuntimeFallbackOpHandler::Execute.
inline bool IsSupportedByTFRTGpu(DataType dtype) {
  switch (dtype) {
    case DataType::DT_FLOAT:
    case DataType::DT_DOUBLE:
    case DataType::DT_INT32:
      return true;
    default:
      return false;
  }
}

// TODO(b/165872892): Remove this method.
// This method is needed because we use different device names in the TF-TFRT
// integration and in mlir tests. In the TF-TFRT integration, we reuse the full
// device name (e.g. /job:localhost/replica:0/task:0/device:GPU:0) from TF, but
// in mlir tests we use the simplified device name "GPU:0". A lot of the
// fallback code must work in both cases, so we look devices up by both names.
inline const char* ConvertTfDeviceNameToTfrtDefault(const char* device_name) {
  // The simplified name is the 5-character suffix, e.g. "GPU:0" or "CPU:0".
  assert(strlen(device_name) >= 5);
  return &device_name[strlen(device_name) - 5];
}
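
// Example (a sketch): both spellings of a device name map to the same
// simplified form.
//
//   ConvertTfDeviceNameToTfrtDefault(
//       "/job:localhost/replica:0/task:0/device:GPU:0");  // -> "GPU:0"
//   ConvertTfDeviceNameToTfrtDefault("GPU:0");            // -> "GPU:0"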

// Creates and initializes an EagerContext.
tfrt::Expected<OwnedEagerContext> InitEagerContext();

tfrt::Expected<OwnedEagerContext> InitEagerContext(
    DynamicDeviceMgr* device_mgr, const SessionOptions& session_opts,
    ContextDevicePlacementPolicy default_device_placement_policy,
    bool is_async);

// Obtains the EagerContext from the ExecutionContext.
tfrt::Expected<EagerContext*> GetEagerContext(tfrt::ExecutionContext exec_ctx);

// Returns the CoreRuntimeOp for `op_name` using the fallback op_handler.
llvm::Expected<tfrt::CoreRuntimeOp> GetFallbackOp(tfrt::string_view op_name,
                                                  tfrt::HostContext* host);
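
// Example usage (a sketch, assuming a tfrt::HostContext* `host` and that the
// op was registered under the "tf." prefix):
//
//   llvm::Expected<tfrt::CoreRuntimeOp> op = GetFallbackOp("tf.MatMul", host);
//   if (!op) return op.takeError();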

constexpr char kEagerContextResourceName[] = "EagerContextResourceName";

// Wraps an owned EagerContext and the DynamicDeviceMgr backing it so that the
// two can be stored together as a single resource.
class EagerContextResource {
 public:
  explicit EagerContextResource()
      : device_mgr_(std::make_unique<DynamicDeviceMgr>()),
        ctx_{InitEagerContext(
            device_mgr_.get(), tensorflow::SessionOptions(),
            tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
            /*is_async=*/false)} {}
  explicit EagerContextResource(
      const SessionOptions& session_opts,
      ContextDevicePlacementPolicy default_device_placement_policy,
      bool is_async)
      : device_mgr_(std::make_unique<DynamicDeviceMgr>()),
        ctx_{InitEagerContext(device_mgr_.get(), session_opts,
                              default_device_placement_policy, is_async)} {}

  tfrt::Expected<EagerContext*> GetTFEagerContext() {
    if (!ctx_) return ctx_.takeError();
    return ctx_.get().get();
  }

  DynamicDeviceMgr* GetDeviceMgr() { return device_mgr_.get(); }

  llvm::Error AddDevices(std::vector<std::unique_ptr<Device>> devices) {
    if (!ctx_) return ctx_.takeError();
    Status s = dynamic_cast<tensorflow::DynamicDeviceMgr*>(
                   ctx_.get()->local_device_mgr())
                   ->AddDevices(std::move(devices));
    if (!s.ok()) return tfrt::MakeStringError(s.error_message());
    ctx_.get()->InitPrioritizedDeviceTypeList();
    ctx_.get()->pflr()->InitializeDeviceAndFlr();
    return llvm::Error::success();
  }

 private:
  // EagerContext uses this device_mgr_ as its local_device_mgr. We manage
  // device_mgr_ here so that TFRT can add new devices after the EagerContext
  // has been initialized.
  // Today, TFRT only adds TPU devices after EagerContext initialization.
  std::unique_ptr<DynamicDeviceMgr> device_mgr_;

  tfrt::Expected<OwnedEagerContext> ctx_;
};
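
// Example usage (a sketch, assuming a tfrt::ResourceContext* `resource_ctx`;
// GetOrCreateResource default-constructs the resource on first lookup):
//
//   EagerContextResource* resource =
//       resource_ctx->GetOrCreateResource<EagerContextResource>(
//           kEagerContextResourceName);
//   tfrt::Expected<EagerContext*> ctx = resource->GetTFEagerContext();
//   if (!ctx) return ctx.takeError();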

}  // namespace tfd
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_UTILS_H_