/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements RuntimeFallbackOpHandler, which is responsible for
// running TFRT ops on TensorFlow.

#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_op_handler.h"

#include "llvm/Support/Casting.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Error.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h"
#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_kernels.h"
#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.h"
#include "tensorflow/core/runtime_fallback/util/type_util.h"
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/core_runtime/dispatch_utils.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_handler.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_invocation.h"  // from @tf_runtime
#include "tfrt/core_runtime/tensor_handle.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/device.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/host_context/kernel_utils.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime
#include "tfrt/support/string_util.h"  // from @tf_runtime
#include "tfrt/tensor/conversion_registry.h"  // from @tf_runtime
#include "tfrt/tensor/dense_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/scalar_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/string_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_metadata.h"  // from @tf_runtime
// TODO(b/160798174): Avoid CUDA/ROCM macro.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tfrt/gpu/device/device.h"  // from @tf_runtime
#include "tfrt/gpu/device/device_util.h"  // from @tf_runtime
#include "tfrt/gpu/tensor/dense_gpu_tensor.h"  // from @tf_runtime
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {
namespace tfd {
// TODO(tfrt-devs): Rename it.
class RuntimeFallbackOpHandler : public tfrt::OpHandler {
 public:
  ~RuntimeFallbackOpHandler() override;

  llvm::Expected<tfrt::CoreRuntimeOp> MakeOp(
      tfrt::string_view op_name) override;

  tfrt::string_view DeviceName() const { return device_->name(); }

  const std::string& TfDeviceName() const { return tf_device_name_; }

  tfrt::RCReference<tfrt::Device> GetDeviceRef() { return device_.CopyRef(); }

 private:
  explicit RuntimeFallbackOpHandler(tfrt::CoreRuntime* runtime,
                                    tfrt::RCReference<tfrt::Device> device,
                                    const std::string& tf_device_name);

  llvm::Error Initialize();

  friend llvm::Expected<tfrt::OpHandler*> CreateRuntimeFallbackOpHandler(
      tfrt::CoreRuntime* runtime, tfrt::string_view tf_device_name);

  tfrt::RCReference<tfrt::Device> device_;
  // TensorFlow device name, e.g., /device:CPU:0.
  std::string tf_device_name_;
};

namespace {

using tfrt::AsyncValue;
using tfrt::AsyncValueRef;
using tfrt::Chain;
using tfrt::CoreRuntime;
using tfrt::CoreRuntimeOp;
using tfrt::DenseHostTensor;
using tfrt::ExecutionContext;
using tfrt::Expected;
using tfrt::OpAttrsRef;
using tfrt::OpHandler;
using tfrt::OpInvocation;
using tfrt::OpMetadataFn;
using tfrt::raw_ostream;
using tfrt::RCReference;
using tfrt::SmallVector;
using tfrt::string_view;
using tfrt::Tensor;
using tfrt::TensorMetadata;

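// Dispatch function shared by all runtime fallback ops: it runs the named TF
// op on the given TF device with the given arguments and attributes, fills
// `results` with async values, and returns a chain that is fulfilled once the
// op has finished executing.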
using RuntimeFallbackDispatchFn = AsyncValueRef<Chain> (*)(
    const ExecutionContext& exec_ctx, const char* op_name,
    const char* device_name, llvm::ArrayRef<Tensor*> arguments,
    const OpAttrsRef& attrs,
    llvm::MutableArrayRef<RCReference<AsyncValue>> results);

struct RuntimeFallbackOpEntry {
  std::string op_name;
  OpMetadataFn metadata_fn = nullptr;
  // All ops use the same dispatch function.
  RuntimeFallbackDispatchFn dispatch_fn = &RuntimeFallbackExecute;
};

static Expected<tfrt::RCReference<tfrt::Device>> GetDeviceFromFallbackTensor(
    const RuntimeFallbackTensor& result_tensor,
    const ExecutionContext& exec_ctx) {
  tensorflow::Status status;
  // Obtain the device. Note that this device is probably not the device that
  // the TensorHandle is physically located on, e.g. the device of a GPU
  // resource handle is GPU even though the handle itself physically lives on
  // CPU. We use this device because upper layers (e.g. distributed
  // strategies) may use it for colocation, while the actual device is not
  // widely used in those layers. In the future, if we need BackingDevice in
  // higher layers as well, we can update the c_api_tfrt layer to get it
  // directly from tensorflow::TensorHandle.
  const char* tf_device_name =
      result_tensor.GetTensorHandle()->DeviceName(&status);
  if (!status.ok()) {
    return tfrt::MakeStringError(status.error_message());
  }

  // TODO(b/165872892): Unify device name for tests.
  auto device = exec_ctx.host()->GetDeviceManager()->GetDeviceRef<tfrt::Device>(
      tf_device_name);
  if (!device) {
    // Convert device name to the short form, e.g. "GPU:0".
    const char* tfrt_device_name =
        ConvertTfDeviceNameToTfrtDefault(tf_device_name);
    device = exec_ctx.host()->GetDeviceManager()->GetDeviceRef<tfrt::Device>(
        tfrt_device_name);
  }
  assert(device);
  return std::move(device);
}

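// Traits consumed by tfrt::ExecuteOnOpHandler<> below: they tell the generic
// dispatch machinery how to invoke a runtime fallback op and how to determine
// the device of each result tensor.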
struct RuntimeFallbackOpHandlerTraits {
  using InputTensorTy = Tensor;
  using OpEntryTy = RuntimeFallbackOpEntry;
  using OpHandlerInfoTy = RuntimeFallbackOpHandler*;

  static void Dispatch(const RuntimeFallbackOpEntry& op_entry,
                       RuntimeFallbackOpHandler* tf_op_handler,
                       llvm::ArrayRef<Tensor*> inputs, const OpAttrsRef& attrs,
                       llvm::ArrayRef<TensorMetadata> result_mds,
                       llvm::MutableArrayRef<RCReference<AsyncValue>> results,
                       AsyncValueRef<Chain>* chain,
                       const ExecutionContext& exec_ctx) {
    // Call RuntimeFallbackExecute.
    auto ch = op_entry.dispatch_fn(exec_ctx, op_entry.op_name.c_str(),
                                   tf_op_handler->TfDeviceName().c_str(),
                                   inputs, attrs, results);

    if (chain) *chain = std::move(ch);
  }

  static tfrt::Variant<tfrt::RCReference<tfrt::Device>,
                       tfrt::AsyncValueRef<tfrt::RCReference<tfrt::Device>>>
  GetResultDevice(RuntimeFallbackOpHandler* tf_op_handler,
                  const tfrt::AsyncValueRef<tfrt::Tensor>& result_tensor_av,
                  const ExecutionContext& exec_ctx) {
    if (result_tensor_av.IsAvailable()) {
      if (result_tensor_av.IsError()) {
        return tfrt::AsyncValueRef<tfrt::RCReference<tfrt::Device>>(
            result_tensor_av.CopyRCRef());
      }
      auto expected_device = GetDeviceFromFallbackTensor(
          result_tensor_av.get<RuntimeFallbackTensor>(), exec_ctx);
      if (!expected_device) {
        return tfrt::AsyncValueRef<tfrt::RCReference<tfrt::Device>>(
            tfrt::MakeErrorAsyncValueRef(
                exec_ctx.host(), tfrt::StrCat(expected_device.takeError())));
      }
      return std::move(expected_device.get());
    }

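    // The result tensor is not available yet. Return an unconstructed
    // AsyncValueRef for the device and fulfill it once the tensor resolves.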
    auto result_device =
        tfrt::MakeUnconstructedAsyncValueRef<tfrt::RCReference<tfrt::Device>>(
            exec_ctx.host());

    result_tensor_av.AndThen([result_tensor_av_ref = result_tensor_av.CopyRef(),
                              result_device = result_device.CopyRef(),
                              exec_ctx] {
      assert(result_tensor_av_ref.IsAvailable());
      if (result_tensor_av_ref.IsError()) {
        result_device.SetError(result_tensor_av_ref.GetError());
        return;
      }
      result_device.emplace(GetDeviceFromFallbackTensor(
          result_tensor_av_ref.get<RuntimeFallbackTensor>(), exec_ctx));
    });
    return std::move(result_device);
  }
};

}  // namespace

Expected<CoreRuntimeOp> RuntimeFallbackOpHandler::MakeOp(string_view op_name) {
  // NOTE(fishx): Copying the op name string here adds extra overhead in graph
  // execution, because the current implementation has to prepare the op before
  // each execution.
  // TODO(fishx): Avoid this heap allocation by getting op registration
  // information from current TF.
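  // Op names arrive with a "tf." prefix (e.g. "tf.MatMul"); build the op entry
  // with the prefix stripped so the TensorFlow runtime sees the plain op name.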
  RuntimeFallbackOpEntry op_entry;
  if (!op_name.consume_front("tf."))
    return tfrt::MakeStringError(op_name, " does not start with 'tf.'");
  op_entry.op_name.assign(op_name.begin(), op_name.end());
  return CoreRuntimeOp(
      [op_entry = std::move(op_entry), this](const OpInvocation& invocation) {
        // If the op does not have outputs, then it is expected to output an
        // out chain.
        bool update_chain = invocation.results.empty();

        // Convert the argument tensors to RuntimeFallbackTensors.
        for (auto& argument : invocation.arguments) {
          argument = argument.TransferToSameDevice(
              invocation.exec_ctx, RuntimeFallbackTensor::kTensorType);
        }

        tfrt::ExecuteOnOpHandler<RuntimeFallbackOpHandlerTraits>(
            update_chain, invocation, std::move(op_entry), this);

// TODO(b/160798174): Avoid CUDA/ROCM macro.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
        // If the RuntimeFallbackTensor contains a tensorflow::TensorHandle
        // that holds a GPU tensor, convert it to tfrt::DenseGpuTensor, and
        // populate the correct device name into the result tfrt::TensorHandle.
        //
        // Note that if the GPU tensor holds a DataType that is not natively
        // supported by TFRT, e.g. the Resource DataType, we skip the
        // conversion.
        //
        // If the RuntimeFallbackTensor's tensorflow::TensorHandle holds a CPU
        // tensor, do not convert it to DenseHostTensor here (it will be
        // converted lazily) for performance reasons.
        for (auto& result : invocation.results) {
          auto* host_ctx = invocation.exec_ctx.host();
          auto* result_tensor_av = result.GetAsyncTensor();

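          // The conversion below needs a concrete tensorflow::TensorHandle,
          // so block until the fallback kernel has produced this result.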
          if (!result_tensor_av->IsAvailable())
            host_ctx->Await(FormRef(result_tensor_av));

          if (result_tensor_av->IsError()) continue;

          auto result_tensor_tf_th =
              result_tensor_av->get<RuntimeFallbackTensor>().GetTensorHandle();

          // Check if we need to convert the RuntimeFallbackTensor.
          if (!(IsGpuTensorHandle(*result_tensor_tf_th) &&
                IsSupportedByTFRTGpu(result_tensor_tf_th->DataType())))
            continue;

          result = result.TransferToSameDevice(
              invocation.exec_ctx, tfrt::gpu::DenseGpuTensor::kTensorType);
        }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
      },
      // device and arg_tensor_type are not used in runtime fallback ops.
      /*is_fallback=*/true, /*device=*/device_.CopyRef());
}

llvm::Expected<tfrt::OpHandler*> CreateRuntimeFallbackOpHandler(
    tfrt::CoreRuntime* runtime, tfrt::string_view tf_device_name) {
  // TODO(fishx): Remove the device field from fallback op handler.
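  // The constructor is private, so this friend factory has to call `new`
  // directly instead of using std::make_unique.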
  std::unique_ptr<RuntimeFallbackOpHandler> op_handler(
      new RuntimeFallbackOpHandler(
          runtime, runtime->GetHostContext()->GetHostDeviceRef(),
          tf_device_name.str()));
  if (auto error = op_handler->Initialize()) {
    return std::move(error);
  }
  auto op_handler_ptr = op_handler.get();
  runtime->TakeOpHandler(std::move(op_handler));
  return op_handler_ptr;
}

RuntimeFallbackOpHandler::RuntimeFallbackOpHandler(
    CoreRuntime* runtime, tfrt::RCReference<tfrt::Device> device,
    const std::string& tf_device_name)
    : OpHandler("tf", runtime, nullptr),
      device_(std::move(device)),
      tf_device_name_(tf_device_name) {}

RuntimeFallbackOpHandler::~RuntimeFallbackOpHandler() {}

llvm::Error RuntimeFallbackOpHandler::Initialize() {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  Status status = InjectTfGpuResources();
  if (!status.ok()) {
    return tfrt::MakeStringError(tfrt::StrCat("error injecting GPU resources: ",
                                              status.error_message()));
  }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

  return llvm::Error::success();
}

}  // namespace tfd
}  // namespace tensorflow