/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements RuntimeFallbackOpHandler, responsible for running TFRT
// ops on Tensorflow.

#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_op_handler.h"

#include "llvm/Support/Casting.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Error.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h"
#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_kernels.h"
#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.h"
#include "tensorflow/core/runtime_fallback/util/type_util.h"
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/core_runtime/dispatch_utils.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_handler.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_invocation.h"  // from @tf_runtime
#include "tfrt/core_runtime/tensor_handle.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/device.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/host_context/kernel_utils.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime
#include "tfrt/support/string_util.h"  // from @tf_runtime
#include "tfrt/tensor/conversion_registry.h"  // from @tf_runtime
#include "tfrt/tensor/dense_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/scalar_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/string_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_metadata.h"  // from @tf_runtime
// TODO(b/160798174): Avoid CUDA/ROCM macro.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tfrt/gpu/device/device.h"  // from @tf_runtime
#include "tfrt/gpu/device/device_util.h"  // from @tf_runtime
#include "tfrt/gpu/tensor/dense_gpu_tensor.h"  // from @tf_runtime
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {
namespace tfd {
// TODO(tfrt-devs): Rename it.
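// RuntimeFallbackOpHandler is an OpHandler that executes ops by falling back
// to the existing TensorFlow runtime (via RuntimeFallbackExecute) instead of
// using native TFRT kernels.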
class RuntimeFallbackOpHandler : public tfrt::OpHandler {
 public:
  ~RuntimeFallbackOpHandler() override;

  llvm::Expected<tfrt::CoreRuntimeOp> MakeOp(
      tfrt::string_view op_name) override;

  tfrt::string_view DeviceName() const { return device_->name(); }

  const std::string& TfDeviceName() const { return tf_device_name_; }

  tfrt::RCReference<tfrt::Device> GetDeviceRef() { return device_.CopyRef(); }

 private:
  explicit RuntimeFallbackOpHandler(tfrt::CoreRuntime* runtime,
                                    tfrt::RCReference<tfrt::Device> device,
                                    const std::string& tf_device_name);

  llvm::Error Initialize();

  friend llvm::Expected<tfrt::OpHandler*> CreateRuntimeFallbackOpHandler(
      tfrt::CoreRuntime* runtime, tfrt::string_view tf_device_name);

  tfrt::RCReference<tfrt::Device> device_;
  // Tensorflow device name, e.g., /device:CPU:0.
  std::string tf_device_name_;
};

namespace {

using tfrt::AsyncValue;
using tfrt::AsyncValueRef;
using tfrt::Chain;
using tfrt::CoreRuntime;
using tfrt::CoreRuntimeOp;
using tfrt::DenseHostTensor;
using tfrt::ExecutionContext;
using tfrt::Expected;
using tfrt::OpAttrsRef;
using tfrt::OpHandler;
using tfrt::OpInvocation;
using tfrt::OpMetadataFn;
using tfrt::raw_ostream;
using tfrt::RCReference;
using tfrt::SmallVector;
using tfrt::string_view;
using tfrt::Tensor;
using tfrt::TensorMetadata;

using RuntimeFallbackDispatchFn = AsyncValueRef<Chain> (*)(
    const ExecutionContext& exec_ctx, const char* op_name,
    const char* device_name, llvm::ArrayRef<Tensor*> arguments,
    const OpAttrsRef& attrs,
    llvm::MutableArrayRef<RCReference<AsyncValue>> results);

struct RuntimeFallbackOpEntry {
  std::string op_name;
  OpMetadataFn metadata_fn = nullptr;
  // All ops use the same dispatch function.
  RuntimeFallbackDispatchFn dispatch_fn = &RuntimeFallbackExecute;
};

static Expected<tfrt::RCReference<tfrt::Device>> GetDeviceFromFallbackTensor(
    const RuntimeFallbackTensor& result_tensor,
    const ExecutionContext& exec_ctx) {
  tensorflow::Status status;
  // Obtain the device. Please note that this device is probably not
  // the device that the TensorHandle is located on. E.g. for a GPU resource
  // its device is GPU but it is physically located on CPU.
  // We use this device because upper layers (e.g. distributed strategy) may
  // use it for colocation. On the other hand, the actual device is not widely
  // used in upper layers.
  // In the future, if we need BackingDevice in higher layers as well, we can
  // update the c_api_tfrt layer to get it directly from
  // tensorflow::TensorHandle.
  const char* tf_device_name =
      result_tensor.GetTensorHandle()->DeviceName(&status);
  if (!status.ok()) {
    return tfrt::MakeStringError(status.error_message());
  }

  // TODO(b/165872892): Unify device name for tests.
  auto device = exec_ctx.host()->GetDeviceManager()->GetDeviceRef<tfrt::Device>(
      tf_device_name);
  if (!device) {
    // Convert device name to the short form, e.g. "GPU:0".
    const char* tfrt_device_name =
        ConvertTfDeviceNameToTfrtDefault(tf_device_name);
    device = exec_ctx.host()->GetDeviceManager()->GetDeviceRef<tfrt::Device>(
        tfrt_device_name);
  }
  assert(device);
  return std::move(device);
}

struct RuntimeFallbackOpHandlerTraits {
  using InputTensorTy = Tensor;
  using OpEntryTy = RuntimeFallbackOpEntry;
  using OpHandlerInfoTy = RuntimeFallbackOpHandler*;

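  // Dispatch is invoked by tfrt::ExecuteOnOpHandler (see MakeOp below). It
  // forwards the argument tensors to the entry's dispatch function
  // (RuntimeFallbackExecute) on the handler's TF device, and fills in the
  // out chain when the caller requests one.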
  static void Dispatch(const RuntimeFallbackOpEntry& op_entry,
                       RuntimeFallbackOpHandler* tf_op_handler,
                       llvm::ArrayRef<Tensor*> inputs, const OpAttrsRef& attrs,
                       llvm::ArrayRef<TensorMetadata> result_mds,
                       llvm::MutableArrayRef<RCReference<AsyncValue>> results,
                       AsyncValueRef<Chain>* chain,
                       const ExecutionContext& exec_ctx) {
    // Call RuntimeFallbackExecute.
    auto ch = op_entry.dispatch_fn(exec_ctx, op_entry.op_name.c_str(),
                                   tf_op_handler->TfDeviceName().c_str(),
                                   inputs, attrs, results);

    if (chain) *chain = std::move(ch);
  }

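  // Determines the TFRT device of a result tensor. If the result is already
  // available, the device is returned directly; otherwise an unconstructed
  // AsyncValueRef is returned and fulfilled once the tensor resolves.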
  static tfrt::Variant<tfrt::RCReference<tfrt::Device>,
                       tfrt::AsyncValueRef<tfrt::RCReference<tfrt::Device>>>
  GetResultDevice(RuntimeFallbackOpHandler* tf_op_handler,
                  const tfrt::AsyncValueRef<tfrt::Tensor>& result_tensor_av,
                  const ExecutionContext& exec_ctx) {
    if (result_tensor_av.IsAvailable()) {
      if (result_tensor_av.IsError()) {
        return tfrt::AsyncValueRef<tfrt::RCReference<tfrt::Device>>(
            result_tensor_av.CopyRCRef());
      }
      auto expected_device = GetDeviceFromFallbackTensor(
          result_tensor_av.get<RuntimeFallbackTensor>(), exec_ctx);
      if (!expected_device) {
        return tfrt::AsyncValueRef<tfrt::RCReference<tfrt::Device>>(
            tfrt::MakeErrorAsyncValueRef(
                exec_ctx.host(), tfrt::StrCat(expected_device.takeError())));
      }
      return std::move(expected_device.get());
    }

    auto result_device =
        tfrt::MakeUnconstructedAsyncValueRef<tfrt::RCReference<tfrt::Device>>(
            exec_ctx.host());

    result_tensor_av.AndThen([result_tensor_av_ref = result_tensor_av.CopyRef(),
                              result_device = result_device.CopyRef(),
                              exec_ctx] {
      assert(result_tensor_av_ref.IsAvailable());
      if (result_tensor_av_ref.IsError()) {
        result_device.SetError(result_tensor_av_ref.GetError());
        return;
      }
      // Resolve the device once; propagate a lookup failure as an error on
      // the device AsyncValue.
      auto expected_device = GetDeviceFromFallbackTensor(
          result_tensor_av_ref.get<RuntimeFallbackTensor>(), exec_ctx);
      if (!expected_device) {
        result_device.SetError(tfrt::StrCat(expected_device.takeError()));
        return;
      }
      result_device.emplace(std::move(expected_device.get()));
    });
    return std::move(result_device);
  }
};

}  // namespace

Expected<CoreRuntimeOp> RuntimeFallbackOpHandler::MakeOp(string_view op_name) {
  // NOTE(fishx): Copying the string here adds extra overhead in graph
  // execution, because in the current implementation we need to prepare the op
  // before each execution.
  // TODO(fishx): Avoid this heap allocation by getting op registration
  // information from current TF.
  RuntimeFallbackOpEntry op_entry;
  if (!op_name.consume_front("tf."))
    return tfrt::MakeStringError(op_name, " does not start with 'tf.'");
  op_entry.op_name.assign(op_name.begin(), op_name.end());
  return CoreRuntimeOp(
      [op_entry = std::move(op_entry), this](const OpInvocation& invocation) {
        // If the op does not have outputs, then it is expected to output an
        // out chain.
        bool update_chain = invocation.results.empty();

        // Convert the argument tensors to RuntimeFallbackTensors.
        for (auto& argument : invocation.arguments) {
          argument = argument.TransferToSameDevice(
              invocation.exec_ctx, RuntimeFallbackTensor::kTensorType);
        }

        tfrt::ExecuteOnOpHandler<RuntimeFallbackOpHandlerTraits>(
            update_chain, invocation, std::move(op_entry), this);

// TODO(b/160798174): Avoid CUDA/ROCM macro.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
        // If the RuntimeFallbackTensor contains a tensorflow::TensorHandle
        // that holds a GPU tensor, convert it to tfrt::DenseGpuTensor, and
        // populate the correct device name into the result tfrt::TensorHandle.
        //
        // Note that if the GPU tensor contains a DataType that is not natively
        // supported by TFRT, e.g. Resource DataType, we skip the conversion.
        //
        // If the RuntimeFallbackTensor's tensorflow::TensorHandle holds a CPU
        // tensor, do not convert it to DenseHostTensor (it will be lazily
        // converted) for performance reasons.
        for (auto& result : invocation.results) {
          auto* host_ctx = invocation.exec_ctx.host();
          auto* result_tensor_av = result.GetAsyncTensor();

          if (!result_tensor_av->IsAvailable())
            host_ctx->Await(FormRef(result_tensor_av));

          if (result_tensor_av->IsError()) continue;

          auto result_tensor_tf_th =
              result_tensor_av->get<RuntimeFallbackTensor>().GetTensorHandle();

          // Check if we need to convert the RuntimeFallbackTensor.
          if (!(IsGpuTensorHandle(*result_tensor_tf_th) &&
                IsSupportedByTFRTGpu(result_tensor_tf_th->DataType())))
            continue;

          result = result.TransferToSameDevice(
              invocation.exec_ctx, tfrt::gpu::DenseGpuTensor::kTensorType);
        }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
      },
      // device and arg_tensor_type are not used in runtime fallback ops.
      /*is_fallback=*/true, /*device=*/device_.CopyRef());
}

llvm::Expected<tfrt::OpHandler*> CreateRuntimeFallbackOpHandler(
    tfrt::CoreRuntime* runtime, tfrt::string_view tf_device_name) {
  // TODO(fishx): Remove the device field from fallback op handler.
  std::unique_ptr<RuntimeFallbackOpHandler> op_handler(
      new RuntimeFallbackOpHandler(
          runtime, runtime->GetHostContext()->GetHostDeviceRef(),
          tf_device_name.str()));
  if (auto error = op_handler->Initialize()) {
    return std::move(error);
  }
  auto op_handler_ptr = op_handler.get();
  runtime->TakeOpHandler(std::move(op_handler));
  return op_handler_ptr;
}

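// Example usage (a minimal sketch, not part of this file; the CoreRuntime
// instance and the op name "tf.MatMul" are assumptions made by the caller):
//
//   tfrt::CoreRuntime* core_runtime = ...;  // obtained from TFRT setup
//   llvm::Expected<tfrt::OpHandler*> fallback =
//       CreateRuntimeFallbackOpHandler(core_runtime, "/device:CPU:0");
//   if (!fallback) { /* handle fallback.takeError() */ }
//   // MakeOp expects the "tf." prefix and strips it before dispatch.
//   llvm::Expected<tfrt::CoreRuntimeOp> op = (*fallback)->MakeOp("tf.MatMul");
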
RuntimeFallbackOpHandler::RuntimeFallbackOpHandler(
    CoreRuntime* runtime, tfrt::RCReference<tfrt::Device> device,
    const std::string& tf_device_name)
    : OpHandler("tf", runtime, nullptr),
      device_(std::move(device)),
      tf_device_name_(tf_device_name) {}

RuntimeFallbackOpHandler::~RuntimeFallbackOpHandler() {}

llvm::Error RuntimeFallbackOpHandler::Initialize() {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  Status status = InjectTfGpuResources();
  if (!status.ok()) {
    return tfrt::MakeStringError(tfrt::StrCat("error injecting GPU resources: ",
                                              status.error_message()));
  }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

  return llvm::Error::success();
}

}  // namespace tfd
}  // namespace tensorflow