/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file declares kernel utils.

#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_UTILS_H_
#define TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_UTILS_H_

#include <cassert>
#include <cstring>
#include <memory>
#include <vector>

#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tfrt/core_runtime/core_runtime_op.h"  // from @tf_runtime
#include "tfrt/dtype/dtype.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_shape.h"  // from @tf_runtime

namespace tensorflow {
namespace tfd {

template <typename T>
struct AutoReleaser {
  void operator()(T* p) const { p->Release(); }
};
template <typename T>
using AutoReleasePtr = std::unique_ptr<T, AutoReleaser<T>>;

using OwnedEagerContext = AutoReleasePtr<EagerContext>;
using OwnedEagerOperation = AutoReleasePtr<EagerOperation>;
using OwnedTensorHandle = AutoReleasePtr<TensorHandle>;
using OwnedAbstractTensorInterface = AutoReleasePtr<AbstractTensorInterface>;
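
// Example (illustrative sketch, not used by the code below): AutoReleasePtr
// calls Release() instead of `delete`, so wrapping a raw pointer whose
// ownership has been transferred to the caller releases it automatically at
// scope exit.  `handle` is a hypothetical TensorHandle* owned by the caller.
//
//   OwnedTensorHandle owned_handle{handle};
//   // ... use owned_handle.get() / owned_handle-> as with any smart pointer.
//   // owned_handle->Release() runs when `owned_handle` goes out of scope.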

// Check if a TensorHandle physically resides on GPU.
inline bool IsGpuTensorHandle(const tensorflow::TensorHandle& handle) {
  tensorflow::Status dummy_status;
  // BackingDeviceName is where the tensor is physically located, not where the
  // op that produces the tensor is.
  // Note that dummy_status is never set in TensorHandle::BackingDeviceName.
  absl::string_view device_name = handle.BackingDeviceName(&dummy_status);
  return absl::StrContains(device_name, "GPU");
}

// TODO(zhangqiaorjc): Allowlist more dtypes as tfrt GPU supports more.
// A RuntimeFallbackTensor of one of the supported dtypes below will be eagerly
// converted to a tfrt::DenseGpuTensor after each
// RuntimeFallbackOpHandler::Execute.
inline bool IsSupportedByTFRTGpu(DataType dtype) {
  switch (dtype) {
    default:
      return false;
    case DataType::DT_FLOAT:
    case DataType::DT_DOUBLE:
    case DataType::DT_INT32:
      return true;
  }
}

// TODO(b/165872892): Remove this method.
// This method is needed because we use different device names in the TF-TFRT
// integration and in mlir tests. In the TF-TFRT integration, we reuse the full
// device name (e.g. /job:localhost/replica:0/task:0/device:GPU:0) from TF. But
// in mlir tests, we use the simplified device name "GPU:0". A lot of the
// fallback code needs to work in both cases, so we need to be able to look up
// a device under both naming schemes.
inline const char* ConvertTfDeviceNameToTfrtDefault(const char* device_name) {
  assert(strlen(device_name) >= 5);
  return &device_name[strlen(device_name) - 5];
}
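
// Example (sketch): for the full TF device name from the comment above,
//
//   ConvertTfDeviceNameToTfrtDefault(
//       "/job:localhost/replica:0/task:0/device:GPU:0")
//
// returns a pointer to the last five characters of the input, i.e. "GPU:0".
// No copy is made, so the returned pointer is only valid while the input
// string is alive.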

// Create and initialize EagerContext.
tfrt::Expected<OwnedEagerContext> InitEagerContext();

tfrt::Expected<OwnedEagerContext> InitEagerContext(
    DynamicDeviceMgr* device_mgr, const SessionOptions& session_opts,
    ContextDevicePlacementPolicy default_device_placement_policy,
    bool is_async);

// Obtain EagerContext from ExecutionContext.
tfrt::Expected<EagerContext*> GetEagerContext(tfrt::ExecutionContext exec_ctx);

// Return the CoreRuntimeOp for `op_name` using fallback op_handler.
llvm::Expected<tfrt::CoreRuntimeOp> GetFallbackOp(tfrt::string_view op_name,
                                                  tfrt::HostContext* host);
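
// A minimal usage sketch (illustrative only, not called by the fallback
// runtime): looks up the fallback CoreRuntimeOp for a hypothetical op name and
// propagates any lookup failure as an llvm::Error.
inline llvm::Error ExampleGetFallbackOp(tfrt::HostContext* host) {
  llvm::Expected<tfrt::CoreRuntimeOp> op = GetFallbackOp("tf.MatMul", host);
  if (!op) return op.takeError();
  // `*op` can now be invoked like any other tfrt::CoreRuntimeOp.
  return llvm::Error::success();
}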

constexpr char kEagerContextResourceName[] = "EagerContextResourceName";

class EagerContextResource {
 public:
  explicit EagerContextResource()
      : device_mgr_(std::make_unique<DynamicDeviceMgr>()),
        ctx_{InitEagerContext(
            device_mgr_.get(), tensorflow::SessionOptions(),
            tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
            /*is_async=*/false)} {}
  explicit EagerContextResource(
      const SessionOptions& session_opts,
      ContextDevicePlacementPolicy default_device_placement_policy,
      bool is_async)
      : device_mgr_(std::make_unique<DynamicDeviceMgr>()),
        ctx_{InitEagerContext(device_mgr_.get(), session_opts,
                              default_device_placement_policy, is_async)} {}

  tfrt::Expected<EagerContext*> GetTFEagerContext() {
    if (!ctx_) return ctx_.takeError();
    return ctx_.get().get();
  }

  DynamicDeviceMgr* GetDeviceMgr() { return device_mgr_.get(); }

  llvm::Error AddDevices(std::vector<std::unique_ptr<Device>> devices) {
    if (!ctx_) return ctx_.takeError();
    Status s = dynamic_cast<tensorflow::DynamicDeviceMgr*>(
                   ctx_.get()->local_device_mgr())
                   ->AddDevices(std::move(devices));
    if (!s.ok()) return tfrt::MakeStringError(s.error_message());
    ctx_.get()->InitPrioritizedDeviceTypeList();
    ctx_.get()->pflr()->InitializeDeviceAndFlr();
    return llvm::Error::success();
  }

 private:
  // EagerContext uses this device_mgr_ as its local_device_mgr. We manage
  // device_mgr_ here so that TFRT can add new devices after EagerContext
  // initialization.
  // Today, TFRT only adds TPU devices after EagerContext initialization.
  std::unique_ptr<DynamicDeviceMgr> device_mgr_;

  tfrt::Expected<OwnedEagerContext> ctx_;
};
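
// A minimal usage sketch (illustrative only, not part of the fallback API):
// constructs the resource with default options and fetches the wrapped
// EagerContext, forwarding any initialization error.  In the fallback kernels
// the resource is presumably stored in and retrieved from a TFRT resource
// context under kEagerContextResourceName rather than constructed directly.
inline llvm::Error ExampleUseEagerContextResource() {
  EagerContextResource eager_context_resource;
  tfrt::Expected<EagerContext*> ctx =
      eager_context_resource.GetTFEagerContext();
  if (!ctx) return ctx.takeError();
  // The EagerContext* remains owned by `eager_context_resource`.
  return llvm::Error::success();
}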

}  // namespace tfd
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_UTILS_H_