/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"

#include <cstddef>
#include <memory>
#include <vector>

#include "tensorflow/c/eager/abstract_function.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/mlir/tfrt/function/function.h"
#include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h"
#include "tensorflow/core/runtime_fallback/runtime/op_logger.h"
#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_op_handler.h"
#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.h"
#include "tensorflow/core/runtime_fallback/util/attr_util.h"
#include "tensorflow/core/runtime_fallback/util/tensor_util.h"
#include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed_interface.h"
#include "tensorflow/core/tfrt/eager/core_runtime/op_handler_registry.h"
#include "tensorflow/core/tfrt/eager/core_runtime/op_handler_selector.h"
#include "tensorflow/core/tfrt/eager/virtual_device.h"
#include "tensorflow/core/tfrt/utils/error_util.h"
#include "tensorflow/core/tfrt/utils/utils.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tfrt/common/compat/eigen/eigen_dtype.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime_op.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attr_type.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_handler.h"  // from @tf_runtime
#include "tfrt/core_runtime/tensor_handle.h"  // from @tf_runtime
#include "tfrt/dtype/dtype.h"  // from @tf_runtime
#include "tfrt/host_context/async_value.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/attribute_utils.h"  // from @tf_runtime
#include "tfrt/host_context/chain.h"  // from @tf_runtime
#include "tfrt/host_context/concurrent_work_queue.h"  // from @tf_runtime
#include "tfrt/host_context/device.h"  // from @tf_runtime
#include "tfrt/host_context/diagnostic.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/function.h"  // from @tf_runtime
#include "tfrt/host_context/host_allocator.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/host_context/location.h"  // from @tf_runtime
#include "tfrt/host_context/resource_context.h"  // from @tf_runtime
#include "tfrt/metrics/common_metrics.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime
#include "tfrt/support/string_util.h"  // from @tf_runtime
#include "tfrt/tensor/conversion_registry.h"  // from @tf_runtime
#include "tfrt/tensor/dense_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/dense_host_tensor_view.h"  // from @tf_runtime
#include "tfrt/tensor/scalar_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/string_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_metadata.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_serialize_utils.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_type_registration.h"  // from @tf_runtime

namespace tfrt {
namespace tf {

namespace {

using tensorflow::down_cast;

constexpr char kGpuDeviceName[] = "GPU";
constexpr char kEnableNativeOpsAttr[] = "TFRT_TEST_enable_native_ops";
constexpr char kEnableGrapplerAttr[] = "TFRT_TEST_enable_grappler";

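// Builds TFRT TensorMetadata for the given dtype and dimension sizes.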
TensorMetadata CreateMetadata(DType dtype, absl::Span<const Index> dim_sizes) {
  return TensorMetadata(
      DType(dtype),
      TensorShape(llvm::ArrayRef<Index>(
          reinterpret_cast<const Index*>(dim_sizes.data()), dim_sizes.size())));
}

tensorflow::DataType ConvertDType(DType kind) {
  switch (kind) {
    case DType::UI8:
      return tensorflow::DT_UINT8;
    case DType::UI16:
      return tensorflow::DT_UINT16;
    case DType::UI32:
      return tensorflow::DT_UINT32;
    case DType::UI64:
      return tensorflow::DT_UINT64;
    case DType::I8:
      return tensorflow::DT_INT8;
    case DType::I16:
      return tensorflow::DT_INT16;
    case DType::I32:
      return tensorflow::DT_INT32;
    case DType::I64:
      return tensorflow::DT_INT64;
    case DType::BF16:
      return tensorflow::DT_BFLOAT16;
    case DType::F16:
      return tensorflow::DT_HALF;
    case DType::F32:
      return tensorflow::DT_FLOAT;
    case DType::F64:
      return tensorflow::DT_DOUBLE;
    case DType::I1:
      return tensorflow::DT_BOOL;
    case DType::Complex64:
      return tensorflow::DT_COMPLEX64;
    case DType::Complex128:
      return tensorflow::DT_COMPLEX128;
    case DType::String:
      return tensorflow::DT_STRING;
    case DType::Resource:
      return tensorflow::DT_RESOURCE;
    case DType::Variant:
      return tensorflow::DT_VARIANT;
    case DType::QUI8:
      return tensorflow::DT_QUINT8;
    case DType::QUI16:
      return tensorflow::DT_QUINT16;
    case DType::QI8:
      return tensorflow::DT_QINT8;
    case DType::QI16:
      return tensorflow::DT_QINT16;
    case DType::QI32:
      return tensorflow::DT_QINT32;
    default:
      LOG(ERROR) << "Unsupported kind " << kind;
      return tensorflow::DT_INVALID;
  }
}

DType ConvertDType(tensorflow::DataType dtype) {
  switch (dtype) {
    case tensorflow::DT_UINT8:
      return static_cast<DType>(DType::UI8);
    case tensorflow::DT_UINT16:
      return static_cast<DType>(DType::UI16);
    case tensorflow::DT_UINT32:
      return static_cast<DType>(DType::UI32);
    case tensorflow::DT_UINT64:
      return static_cast<DType>(DType::UI64);
    case tensorflow::DT_INT8:
      return static_cast<DType>(DType::I8);
    case tensorflow::DT_INT16:
      return static_cast<DType>(DType::I16);
    case tensorflow::DT_INT32:
      return static_cast<DType>(DType::I32);
    case tensorflow::DT_INT64:
      return static_cast<DType>(DType::I64);
    case tensorflow::DT_BFLOAT16:
      return static_cast<DType>(DType::BF16);
    case tensorflow::DT_HALF:
      return static_cast<DType>(DType::F16);
    case tensorflow::DT_FLOAT:
      return static_cast<DType>(DType::F32);
    case tensorflow::DT_DOUBLE:
      return static_cast<DType>(DType::F64);
    case tensorflow::DT_BOOL:
      return static_cast<DType>(DType::I1);
    case tensorflow::DT_STRING:
      return static_cast<DType>(DType::String);
    case tensorflow::DT_COMPLEX64:
      return static_cast<DType>(DType::Complex64);
    case tensorflow::DT_COMPLEX128:
      return static_cast<DType>(DType::Complex128);
    case tensorflow::DT_RESOURCE:
      return static_cast<DType>(DType::Resource);
    case tensorflow::DT_VARIANT:
      return static_cast<DType>(DType::Variant);
    case tensorflow::DT_QUINT8:
      return static_cast<DType>(DType::QUI8);
    case tensorflow::DT_QUINT16:
      return static_cast<DType>(DType::QUI16);
    case tensorflow::DT_QINT8:
      return static_cast<DType>(DType::QI8);
    case tensorflow::DT_QINT16:
      return static_cast<DType>(DType::QI16);
    case tensorflow::DT_QINT32:
      return static_cast<DType>(DType::QI32);
    default:
      LOG(FATAL) << "Unsupported dtype " << dtype;
  }
}

OpAttrType ConvertDTypeToOpAttrType(tensorflow::DataType dtype) {
  switch (dtype) {
    case tensorflow::DT_UINT8:
      return OpAttrType::UI8;
    case tensorflow::DT_UINT16:
      return OpAttrType::UI16;
    case tensorflow::DT_UINT32:
      return OpAttrType::UI32;
    case tensorflow::DT_UINT64:
      return OpAttrType::UI64;
    case tensorflow::DT_INT8:
      return OpAttrType::I8;
    case tensorflow::DT_INT16:
      return OpAttrType::I16;
    case tensorflow::DT_INT32:
      return OpAttrType::I32;
    case tensorflow::DT_INT64:
      return OpAttrType::I64;
    case tensorflow::DT_BFLOAT16:
      return OpAttrType::BF16;
    case tensorflow::DT_HALF:
      return OpAttrType::F16;
    case tensorflow::DT_FLOAT:
      return OpAttrType::F32;
    case tensorflow::DT_DOUBLE:
      return OpAttrType::F64;
    case tensorflow::DT_BOOL:
      return OpAttrType::BOOL;
    case tensorflow::DT_COMPLEX64:
      return OpAttrType::COMPLEX64;
    case tensorflow::DT_COMPLEX128:
      return OpAttrType::COMPLEX128;
    default:
      LOG(FATAL) << "Unsupported dtype " << dtype;
  }
}

// This method will first look at the calling op attrs and then look at the
// function def attrs to find the attribute value.
void GetFuncAttr(const OpAttrs& op_attrs, const std::string& op_name,
                 const tensorflow::FunctionLibraryDefinition& func_lib_def,
                 string_view attr_name, bool* value) {
  bool success = op_attrs.Get(attr_name, value);
  if (success) {
    DVLOG(2) << "Caller explicitly specifies " << attr_name.str()
             << (*value ? "=true " : "=false, ");
    return;
  }

  const tensorflow::FunctionDef* function_def = func_lib_def.Find(op_name);
  if (function_def == nullptr) {
    return;
  }

  tensorflow::Status status =
      GetNodeAttr(tensorflow::AttrSlice(&function_def->attr()),
                  {attr_name.data(), attr_name.size()}, value);
  if (status.ok()) {
    DVLOG(2) << "Function definition explicitly specifies " << attr_name.str()
             << (*value ? "=true" : "=false");
    return;
  }
}

int64_t GetNextLocationId() {
  static std::atomic<int64_t> id(0);
  return id.fetch_add(1, std::memory_order_relaxed);
}
}  // namespace

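// Returns the TF dtype of the wrapped tensor. If the TFRT dtype is
// Unsupported (e.g. for fallback tensors), the dtype is read from the
// underlying TF TensorHandle instead.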
tensorflow::DataType TensorInterface::Type() const {
  auto kind = tensor_.get().metadata().dtype;
  if (kind == DType::Unsupported) {
    assert(llvm::isa<tensorflow::tfd::RuntimeFallbackTensor>(tensor_.get()));
    return tensor_.get<tensorflow::tfd::RuntimeFallbackTensor>()
        .GetTensorHandle()
        ->DataType();
  }
  return ConvertDType(kind);
}

int TensorInterface::NumDims() const { return tensor_.get().shape().GetRank(); }

int64_t TensorInterface::Dim(int dim_index) const {
  return tensor_.get().shape().GetDimensionSize(dim_index);
}

int64_t TensorInterface::NumElements() const {
  if (!tensor_) {
    return static_cast<int64_t>(tf_tensor_.NumElements());
  }
  return tensor_.get().shape().GetNumElements();
}

size_t TensorInterface::ByteSize() const {
  return tensor_.get().metadata().GetHostSizeInBytes();
}

void* TensorInterface::Data() const {
  if (!tensor_) {
    return tensorflow::TensorCApi::Buffer(tf_tensor_)->data();
  } else {
    auto& tensor = tensor_.get<DenseHostTensor>();
    return tensor.data();
  }
}

// TFRT DenseHostTensor is always aligned
bool TensorInterface::IsAligned() const { return true; }

bool TensorInterface::CanMove() const {
  // It is safe to move the Tensor if and only if we own the unique reference to
  // the tensor buffer.
  auto& dht = tensor_.get<DenseHostTensor>();
  return tensor_.IsUnique() && dht.buffer()->IsUnique();
}

std::string TensorInterface::SummarizeValue() const {
  if (!tensor_) {
    return tf_tensor_.SummarizeValue(/*max_entries=*/3, /*print_v2=*/true);
  } else {
    std::string result;
    llvm::raw_string_ostream result_ostream(result);
    tensor_->Print(result_ostream);
    return result;
  }
}

AsyncValueRef<Tensor> TensorInterface::TensorRef() const {
  return tensor_.CopyRef();
}

TensorHandleInterface::TensorHandleInterface(Value&& v, TfrtContext* context)
    : ImmediateExecutionTensorHandle(kTfrt),
      context_(*context),
      value_(std::move(v)) {}

TensorHandleInterface::TensorHandleInterface(tensorflow::DataType dtype,
                                             Value&& v, TfrtContext* context)
    : ImmediateExecutionTensorHandle(kTfrt),
      dtype_(dtype),
      context_(*context),
      value_(std::move(v)) {}

tensorflow::DataType TensorHandleInterface::DataType() const {
  // If dtype_ field is set, use it instead of waiting for the underlying
  // TensorHandle's metadata to be available.
  if (dtype_) {
    return dtype_.getValue();
  }
  auto metadata = Metadata();
  if (!metadata.hasValue()) {
    LOG(ERROR)
        << "Failed to get DataType due to error metadata: "
        << value_.get<TensorHandle>().GetAsyncMetadata().GetError().message;
    return tensorflow::DT_INVALID;
  }
  auto kind = metadata.getValue()->dtype;
  if (kind == DType::Unsupported) {
    AsyncValue* async_tensor = value_.get<TensorHandle>().GetAsyncTensor();
    if (!async_tensor->IsAvailable()) {
      context_.GetHostContext()->Await(FormRef(async_tensor));
    }

    if (async_tensor->IsError()) {
      LOG(ERROR) << "Failed to get DataType from an error tensor "
                 << async_tensor->GetError().message;
      return tensorflow::DT_INVALID;
    }
    assert(async_tensor->IsType<tensorflow::tfd::RuntimeFallbackTensor>());
    return async_tensor->get<tensorflow::tfd::RuntimeFallbackTensor>()
        .GetTensorHandle()
        ->DataType();
  }
  return ConvertDType(kind);
}

tensorflow::Status TensorHandleInterface::TensorHandleStatus() const {
  if (context_.IsAsync()) {
    return ::tensorflow::OkStatus();
  } else {
    auto metadata = Metadata();
    if (!metadata.hasValue()) {
      LOG(ERROR)
          << "Metadata in the tensor handle is an error metadata: "
          << value_.get<TensorHandle>().GetAsyncMetadata().GetError().message;
      return tensorflow::errors::Internal(
          value_.get<TensorHandle>().GetAsyncMetadata().GetError().message);
    }

    AsyncValue* async_tensor = value_.get<TensorHandle>().GetAsyncTensor();
    if (!async_tensor->IsAvailable()) {
      context_.GetHostContext()->Await(FormRef(async_tensor));
    }

    if (async_tensor->IsError()) {
      LOG(ERROR) << "Async tensor in the tensor handle is an error tensor: "
                 << async_tensor->GetError().message;
      return tensorflow::errors::Internal(async_tensor->GetError().message);
    }

    return ::tensorflow::OkStatus();
  }
}

tensorflow::Status TensorHandleInterface::Shape(
    tensorflow::PartialTensorShape* shape) const {
  auto metadata = Metadata();
  if (!metadata.hasValue()) {
    return CreateTfErrorStatus(
        value_.get<TensorHandle>().GetAsyncMetadata().GetError());
  }
  int num_dims = metadata.getValue()->shape.GetRank();
  if (num_dims == -1) {
    return ::tensorflow::OkStatus();
  }
  llvm::SmallVector<Index, 8> dims;
  metadata.getValue()->shape.GetDimensions(&dims);
  TF_RETURN_IF_ERROR(tensorflow::TensorShapeUtils::MakeShape(dims, shape));
  return ::tensorflow::OkStatus();
}

tensorflow::Status TensorHandleInterface::NumDims(int* num_dims) const {
  auto metadata = Metadata();
  if (!metadata.hasValue()) {
    return CreateTfErrorStatus(
        value_.get<TensorHandle>().GetAsyncMetadata().GetError());
  }
  *num_dims = metadata.getValue()->shape.GetRank();

  return ::tensorflow::OkStatus();
}

tensorflow::Status TensorHandleInterface::NumElements(
    int64_t* num_elements) const {
  auto metadata = Metadata();
  if (!metadata.hasValue()) {
    return CreateTfErrorStatus(
        value_.get<TensorHandle>().GetAsyncMetadata().GetError());
  }
  *num_elements = metadata.getValue()->shape.GetNumElements();

  return ::tensorflow::OkStatus();
}

tensorflow::Status TensorHandleInterface::Dim(int dim_index,
                                              int64_t* dim) const {
  auto metadata = Metadata();
  if (!metadata.hasValue()) {
    return CreateTfErrorStatus(
        value_.get<TensorHandle>().GetAsyncMetadata().GetError());
  }
  *dim = metadata.getValue()->shape.GetDimensionSize(dim_index);

  return ::tensorflow::OkStatus();
}

const char* TensorHandleInterface::DeviceName(
    tensorflow::Status* status) const {
  auto& th = value_.get<TensorHandle>();
  if (!th.IsDeviceAvailable()) {
    context_.GetHostContext()->Await(th.GetAsyncDevice().CopyRCRef());
  }
  if (th.IsDeviceError()) {
    *status = CreateTfErrorStatus(th.GetAsyncDevice().GetError());
    return nullptr;
  }
  return th.GetAvailableDevice()->name().data();
}

const char* TensorHandleInterface::BackingDeviceName(
    tensorflow::Status* status) const {
  return DeviceName(status);
}

const char* TensorHandleInterface::DeviceType(
    tensorflow::Status* status) const {
  auto& th = value_.get<TensorHandle>();
  if (!th.IsDeviceAvailable()) {
    context_.GetHostContext()->Await(th.GetAsyncDevice().CopyRCRef());
  }
  if (th.IsDeviceError()) {
    *status = CreateTfErrorStatus(th.GetAsyncDevice().GetError());
    return nullptr;
  }
  return th.GetAvailableDevice()->type().name().data();
}

tensorflow::AbstractTensorInterface* TensorHandleInterface::Resolve(
    tensorflow::Status* status) {
  auto* host_ctx = context_.GetHostContext();
  auto host_device_ref = host_ctx->GetHostDeviceRef();
  auto& th = value_.get<TensorHandle>();

  auto tensor_av = th.GetAsyncTensor();
  if (!tensor_av->IsAvailable()) {
    host_ctx->Await(FormRef(tensor_av));
  }
  if (auto* error = tensor_av->GetErrorIfPresent()) {
    *status = CreateTfErrorStatus(*error);
    return nullptr;
  }
  assert(th.IsMetadataAvailable());

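  // String tensors have no DenseHostTensor representation, so they are copied
  // into a TF tensor directly.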
  if (th.GetAsyncTensor()->get<Tensor>().tensor_type() ==
      StringHostTensor::kTensorType) {
    tensorflow::Tensor tf_tensor =
        tensorflow::tfd::CopyShtToTfTensor(tensor_av->get<StringHostTensor>());
    return new tensorflow::TensorInterface(tf_tensor);
  }

  // Convert the tensor to DenseHostTensor.
  auto req_ctx =
      tfrt::RequestContextBuilder(host_ctx, context_.GetResourceContext())
          .build();
  if (!req_ctx) {
    *status = tensorflow::Status(
        tensorflow::error::Code::UNKNOWN,
        StrCat("Failed to build a RequestContext: ", req_ctx.takeError()));
    return nullptr;
  }
  tfrt::ExecutionContext exec_ctx{std::move(*req_ctx)};
  auto target_th = th.TransferTo(exec_ctx, std::move(host_device_ref),
                                 DenseHostTensor::kTensorType);

  auto target_av = target_th.GetAsyncTensor();
  if (!target_av->IsAvailable()) {
    host_ctx->Await(FormRef(target_av));
  }
  if (target_av->IsError()) {
    *status = tensorflow::Status(
        tensorflow::error::Code::UNKNOWN,
        StrCat("Cannot resolve tensor: ", target_av->GetError().message));
    return nullptr;
  }
  auto host_tensor_ref = target_th.ReleaseTensorRef();
  return new TensorInterface(std::move(host_tensor_ref));
}

llvm::Optional<const TensorMetadata*> TensorHandleInterface::Metadata() const {
  auto& th = value_.get<TensorHandle>();
  if (!th.IsMetadataAvailable()) {
    context_.GetHostContext()->Await(th.GetAsyncMetadata().CopyRCRef());
  }
  if (th.IsMetadataError()) {
    return llvm::None;
  }
  return &th.GetAvailableMetadata();
}

ContextInterface::ContextInterface(
    const tensorflow::SessionOptions& opts,
    tensorflow::ContextDevicePlacementPolicy default_device_placement_policy,
    bool is_async, bool use_tfrt_distributed_runtime)
    : ImmediateExecutionContext(kTfrt),
      context_(opts, default_device_placement_policy, is_async),
      use_tfrt_distributed_runtime_(use_tfrt_distributed_runtime) {
  LOG(INFO) << "TFRT Enabled";
  metrics::AddTFRTVersionMetric();

  op_handler_selector_ = std::make_unique<EagerOpHandlerSelector>(
      GetCoreRuntime(), GetEagerContext(), GetFallbackOpHandler(),
      GetEagerContext()->PinSmallOpsToCPU());

  run_metadata_ = std::make_unique<tensorflow::RunMetadata>();
}

ContextInterface::~ContextInterface() {}

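// Returns this thread's chain, used to sequence side-effecting ops issued from
// the thread. Takes a shared lock on the fast path; on first use from a given
// thread, takes an exclusive lock and inserts a new ready chain.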
AsyncValueRef<Chain>* ContextInterface::GetChain() {
  auto thread_id = std::this_thread::get_id();
  {
    tensorflow::tf_shared_lock l(chain_map_mu_);
    auto it = thread_local_chain_.find(thread_id);
    if (it != thread_local_chain_.end()) {
      return &it->second;
    }
  }
  {
    tensorflow::mutex_lock l(chain_map_mu_);
    if (thread_local_chain_.find(thread_id) == thread_local_chain_.end()) {
      auto chain = GetReadyChain();
      thread_local_chain_[thread_id] = std::move(chain);
    }
    return &thread_local_chain_[thread_id];
  }
}

template <typename T>
static TensorInterface* MakeScalarTensor(T value, HostContext* host) {
  // The TensorInterface implementation assumes the tensor is a DenseHostTensor,
  // so we need to use a DenseHostTensor to represent a scalar tensor.
  TensorMetadata md(GetDType<T>(), {});
  auto t = DenseHostTensor::CreateUninitialized(md, host);
  if (!t) {
    LOG(ERROR) << "Failed to create DenseHostTensor";
    return nullptr;
  }
  auto& dht = t.getValue();
  MutableDHTArrayView<T> view{&dht};
  view.Elements()[0] = value;

  return new TensorInterface(
      MakeAvailableAsyncValueRef<DenseHostTensor>(host, std::move(dht)));
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateInt64Scalar(
    int64_t value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateUint64Scalar(
    uint64_t value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateInt32Scalar(
    int32_t value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateFloatScalar(
    float value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateDoubleScalar(
    double value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateHalfScalar(
    Eigen::half value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateStringScalar(
    tensorflow::tstring value) {
  auto* host = GetHostContext();
  TensorMetadata md(DType(DType::String), {});
  auto t = StringHostTensor::MakeConstructedAsyncValueRef(md, host);
  if (t.IsError()) {
    LOG(ERROR) << "Failed to create StringHostTensor";
    return nullptr;
  }
  t->strings()[0] = value;

  t.SetStateConcrete();
  return new TensorInterface(std::move(t));
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateComplex128Scalar(
    tensorflow::complex128 value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateBoolScalar(
    bool value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateTensor(
    tensorflow::DataType dtype, absl::Span<const int64_t> dim_sizes) {
  std::vector<Index> dimvec(dim_sizes.size());
  for (int i = 0; i < dim_sizes.size(); ++i) {
    dimvec[i] = static_cast<int64_t>(dim_sizes[i]);
  }

  TensorMetadata md;
  switch (dtype) {
    case tensorflow::DT_UINT8:
      md = CreateMetadata(DType::UI8, dimvec);
      break;
    case tensorflow::DT_INT8:
      md = CreateMetadata(DType::I8, dimvec);
      break;
    case tensorflow::DT_INT16:
      md = CreateMetadata(DType::I16, dimvec);
      break;
    case tensorflow::DT_INT32:
      md = CreateMetadata(DType::I32, dimvec);
      break;
    case tensorflow::DT_INT64:
      md = CreateMetadata(DType::I64, dimvec);
      break;
    case tensorflow::DT_HALF:
      md = CreateMetadata(DType::F16, dimvec);
      break;
    case tensorflow::DT_FLOAT:
      md = CreateMetadata(DType::F32, dimvec);
      break;
    case tensorflow::DT_DOUBLE:
      md = CreateMetadata(DType::F64, dimvec);
      break;
    case tensorflow::DT_BOOL:
      md = CreateMetadata(DType::I1, dimvec);
      break;
    case tensorflow::DT_COMPLEX64:
      md = CreateMetadata(DType::Complex64, dimvec);
      break;
    case tensorflow::DT_COMPLEX128:
      md = CreateMetadata(DType::Complex128, dimvec);
      break;
    case tensorflow::DT_VARIANT:
      // Note: the TF Python API can create variant tensors for ragged tensors.
      md = CreateMetadata(DType::Variant, dimvec);
      break;
    case tensorflow::DT_STRING:
      // No TFRT metadata is needed for non-scalar string tensors.
      break;
    default:
      LOG(ERROR) << "Cannot create tensor with dtype: " << dtype;
      return nullptr;
  }

  if (dtype == tensorflow::DT_STRING) {
    // Create a TensorFlow Tensor as a buffer for tstrings.
    return new TensorInterface(
        tensorflow::Tensor(dtype, tensorflow::TensorShape(dim_sizes)));
  } else {
    auto t = DenseHostTensor::CreateUninitialized(md, GetHostContext());
    return new TensorInterface(MakeAvailableAsyncValueRef<DenseHostTensor>(
        GetHostContext(), std::move(t.getValue())));
  }
}

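// Wraps externally owned memory in a DenseHostTensor without copying; the
// provided memory_releaser runs when the last reference to the buffer is
// dropped.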
tensorflow::AbstractTensorInterface* ContextInterface::CreateTensor(
    tensorflow::DataType dtype, const int64_t* dims, int num_dims, void* data,
    size_t len, MemoryReleaser memory_releaser, void* memory_releaser_arg) {
  TensorMetadata metadata(ConvertDType(dtype),
                          {dims, static_cast<size_t>(num_dims)});
  RCReference<HostBuffer> buffer = HostBuffer::CreateFromExternal(
      data, len,
      [memory_releaser, memory_releaser_arg](void* data, size_t len) {
        memory_releaser(data, len, memory_releaser_arg);
      });
  AsyncValueRef<DenseHostTensor> dht =
      MakeConstructedAsyncValueRef<DenseHostTensor>(GetHostContext(), metadata,
                                                    std::move(buffer));

  dht.SetStateConcrete();
  return new TensorInterface(std::move(dht));
}

bool ContextInterface::UsesTFRT() { return true; }

tensorflow::ImmediateExecutionTensorHandle* ContextInterface::CreateLocalHandle(
    tensorflow::AbstractTensorInterface* t) {
  auto* tensor_interface = down_cast<TensorInterface*>(t);
  auto* host = GetHostContext();

  // Create a RuntimeFallbackTensor from a TF Tensor, and then create
  // the corresponding TensorHandleInterface.
  if (tensor_interface->IsTfTensor()) {
    tensorflow::tfd::OwnedTensorHandle tf_tensor_handle{
        tensorflow::TensorHandle::CreateLocalHandle(
            tensor_interface->TfTensor())};

    auto expected_result_tensor =
        tensorflow::tfd::CreateRuntimeFallbackTensorFromTfTensorHandle(
            std::move(tf_tensor_handle), GetHostContext());

    if (expected_result_tensor) {
      return new TensorHandleInterface(
          Value(TensorHandle(
              host->GetHostDeviceRef(), expected_result_tensor.get().metadata(),
              MakeAvailableAsyncValueRef<
                  tensorflow::tfd::RuntimeFallbackTensor>(
                  host, std::move(expected_result_tensor.get())))),
          GetTfrtContext());
    } else {
      return new TensorHandleInterface(
          Value(TensorHandle::CreateError(MakeErrorAsyncValueRef(
              GetHostContext(), StrCat(expected_result_tensor.takeError())))),
          GetTfrtContext());
    }
  }

  auto tensor_av = tensor_interface->TensorRef();
  const TensorMetadata& md = tensor_av.get<Tensor>().metadata();

  // NOTE(fishx): The following logic is needed to let TF-TFRT fully reach
  // performance parity with current TF. This API is used by tf.constant
  // to convert a Python object to a **CPU** Tensor. tf.constant in current TF
  // heavily depends on the Tensor Mirroring feature for good performance.
  // However, TFRT does not have a Tensor Mirroring feature. In order to use
  // Tensor Mirroring from the current TF runtime, we convert the result of
  // tf.constant to a Fallback Tensor.

  if (tensor_av.IsAvailable()) {
    if (auto* dht = llvm::dyn_cast<DenseHostTensor>(&tensor_av.get<Tensor>())) {
      return new TensorHandleInterface(
          Value(TensorHandle(
              host->GetHostDeviceRef(), md,
              MakeAvailableAsyncValueRef<
                  tensorflow::tfd::RuntimeFallbackTensor>(
                  host, tensorflow::tfd::CopyRefDHTToRuntimeFallbackTensor(
                            *dht, host)))),
          GetTfrtContext());
    }
  } else {
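    // The tensor is not yet available: return an indirect async value and
    // forward it to a fallback tensor (or to the original tensor) once the
    // input resolves.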
    auto result_tensor = MakeIndirectAsyncValue(host);
    tensor_av.AndThen([host, result_tensor = result_tensor,
                       tensor_av = tensor_av.CopyRef()]() {
      if (auto* dht =
              llvm::dyn_cast<DenseHostTensor>(&tensor_av.get<Tensor>())) {
        result_tensor->ForwardTo(
            MakeAvailableAsyncValueRef<tensorflow::tfd::RuntimeFallbackTensor>(
                host, tensorflow::tfd::CopyRefDHTToRuntimeFallbackTensor(
                          *dht, host)));
      } else {
        result_tensor->ForwardTo(tensor_av.CopyRef());
      }
    });
    return new TensorHandleInterface(
        Value(TensorHandle(host->GetHostDeviceRef(), md,
                           AsyncValueRef<Tensor>(std::move(result_tensor)))),
        GetTfrtContext());
  }
  return new TensorHandleInterface(
      Value(TensorHandle(host->GetHostDeviceRef(), md, std::move(tensor_av))),
      GetTfrtContext());
}

tensorflow::ImmediateExecutionTensorHandle*
ContextInterface::CreateLocalHandleFromTFTensor(tensorflow::Tensor& t,
                                                const char* d_name) {
  auto* host = GetHostContext();
  // Create a RuntimeFallbackTensor from a TF Tensor, and then create
  // the corresponding TensorHandleInterface.
  tensorflow::tfd::OwnedTensorHandle tf_tensor_handle{
      tensorflow::TensorHandle::CreateLocalHandle(std::move(t))};

  tfrt::Expected<tensorflow::tfd::RuntimeFallbackTensor>
      expected_result_tensor =
          tensorflow::tfd::CreateRuntimeFallbackTensorFromTfTensorHandle(
              std::move(tf_tensor_handle), GetHostContext());

  if (expected_result_tensor) {
    return new TensorHandleInterface(
        Value(TensorHandle(
            host->GetHostDeviceRef(), expected_result_tensor.get().metadata(),
            MakeAvailableAsyncValueRef<tensorflow::tfd::RuntimeFallbackTensor>(
                host, std::move(expected_result_tensor.get())))),
        GetTfrtContext());
  } else {
    return new TensorHandleInterface(
        Value(TensorHandle::CreateError(MakeErrorAsyncValueRef(
            GetHostContext(), StrCat(expected_result_tensor.takeError())))),
        GetTfrtContext());
  }
}

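// Converts a TFRT tensor handle back into a TF TensorHandle. Awaits the
// underlying async tensor, then dispatches on its concrete type: fallback
// tensors already wrap a TF handle, while dense and string host tensors are
// moved or copied into a newly created TF tensor.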
tensorflow::ImmediateExecutionTensorHandle*
ContextInterface::TFTensorHandleFromInterface(
    tensorflow::ImmediateExecutionTensorHandle* handle) {
  TensorHandle th = tfrt::tf::TensorHandleFromInterface(handle)->Handle();
  AsyncValue* tensor_av = th.GetAsyncTensor();
  if (tensor_av->IsUnavailable()) GetHostContext()->Await(FormRef(tensor_av));

  auto& tensor = th.GetAsyncTensor()->get<Tensor>();

  if (auto* rtfbt =
          llvm::dyn_cast<tensorflow::tfd::RuntimeFallbackTensor>(&tensor))
    return rtfbt->GetTensorHandle();

  if (auto* dht = llvm::dyn_cast<tfrt::DenseHostTensor>(&tensor)) {
    return tensorflow::TensorHandle::CreateLocalHandle(
        tensorflow::tfd::MoveHostBufferToTfTensor(dht->buffer(), dht->dtype(),
                                                  dht->shape()));
  }

  if (auto* sht = llvm::dyn_cast<tfrt::StringHostTensor>(&tensor)) {
    return tensorflow::TensorHandle::CreateLocalHandle(
        tensorflow::tfd::CopyShtToTfTensor(*sht));
  }

  LOG(ERROR) << "Unsupported tensor type";
  return nullptr;
}

tensorflow::ImmediateExecutionOperation* ContextInterface::CreateOperation() {
  return new OperationInterface(this);
}

// TODO(srbs): Change this to directly fetch the MLIR function once that is
// supported.
tensorflow::Status ContextInterface::RegisterFunction(
    tensorflow::AbstractFunction* f) {
  tensorflow::FunctionDef* fdef;
  TF_RETURN_IF_ERROR(f->GetFunctionDef(&fdef));
  if (!fdef) {
    return tensorflow::errors::InvalidArgument(
        "GetFunctionDef returned nullptr.");
  }
  return AddFunctionDef(*fdef);
}

void ContextInterface::ListDevices(
    std::vector<tensorflow::DeviceAttributes>* devices) {
  context_.GetEagerContext()->ListDevices(devices);
}

tensorflow::Status ContextInterface::AddDevices(
    std::vector<std::unique_ptr<tensorflow::Device>> devices) {
  if (!devices.empty() && devices[0]->device_type() != "CPU")
    return tensorflow::errors::InvalidArgument(
        "Device: ", devices[0]->device_type(), " is not allowed to be added ",
        "after the context is initialized. Currently allowed device: CPU. ",
        "May update this API to allow adding more types of devices.");

  for (const auto& d : devices) {
    GetHostContext()->GetDeviceManager()->MaybeAddDevice(
        TakeRef(new CpuDevice(d->name())));
  }
  TF_RETURN_IF_ERROR(GetEagerContext()->AddDevices(std::move(devices)));

  return ::tensorflow::OkStatus();
}

void ContextInterface::ClearCachesAndThreadExecutors() {
  GetEagerContext()->ClearCachesAndThreadExecutors();
  GetHostContext()->Quiesce();
}

void ContextInterface::StartStep() { GetEagerContext()->StartStep(); }

void ContextInterface::EndStep() { GetEagerContext()->EndStep(); }

tensorflow::Status ContextInterface::EnableCollectiveOps(
    const tensorflow::ServerDef& server_def) {
  if (use_tfrt_distributed_runtime_) {
    return distributed_manager_->EnableCollectiveOps(server_def);
  }
  // Preserve the local virtual device names, since local virtual devices are
  // added by TFRT and we need to add them back after the worker server is
  // initialized. Currently one such use case is the TPU_SYSTEM device, which
  // is a virtual device specifically used to initialize TPUs.
  std::vector<std::string> virtual_device_names;

  for (const auto& d :
       GetHostContext()->GetDeviceManager()->ListDevices<Device>()) {
    if (d->IsDeviceType(tfrt::VirtualDevice::kDeviceType)) {
      tensorflow::DeviceNameUtils::ParsedName p;
      if (!tensorflow::DeviceNameUtils::ParseFullName(d->name().str(), &p)) {
        return tensorflow::errors::InvalidArgument(
            "Invalid local virtual device name: ", d->name().str());
      }

      virtual_device_names.push_back(tensorflow::DeviceNameUtils::FullName(
          server_def.job_name(), /*replica=*/0, server_def.task_index(), p.type,
          p.id));
    }
  }

  TF_RETURN_IF_ERROR(GetEagerContext()->EnableCollectiveOps(server_def));

  // Create new devices with updated device names.
  std::vector<std::unique_ptr<tensorflow::Device>> dummy_tf_devices;
  CreateDummyTfDevices(virtual_device_names, &dummy_tf_devices);

  std::string name_prefix =
      absl::StrCat("/job:", server_def.job_name(),
                   "/replica:0/task:", server_def.task_index());

  // Update the host device in the TFRT HostContext.
  GetHostContext()->ResetHostDevice(
      GetHostContext()
          ->GetDeviceManager()
          ->MaybeAddDevice(TakeRef(
              new CpuDevice(absl::StrCat(name_prefix, "/device:CPU:0"))))
          .release());

  // Update the virtual devices in the TFRT HostContext.
  AddDummyTfrtDevices(virtual_device_names, GetHostContext());

  // Update the eager context's device manager.
  auto* local_device_mgr = dynamic_cast<tensorflow::DynamicDeviceMgr*>(
      GetEagerContext()->local_device_mgr());
  TF_RETURN_IF_ERROR(local_device_mgr->AddDevices(std::move(dummy_tf_devices)));

  return ::tensorflow::OkStatus();
}

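// Builds a RequestContext for executing a function (or a single op, when
// runner_table is null): wires in the eager step id, the kernel-fallback
// compat state, and, if present, the distributed runtime context.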
tensorflow::Status ContextInterface::BuildFunctionRequestContext(
    tensorflow::tfrt_stub::OpKernelRunnerTable* runner_table,
    RCReference<tfrt::RequestContext>* request_context) {
  auto* step_container = GetEagerContext()->StepContainer();
  RequestContextBuilder request_context_builder(
      GetHostContext(), GetResourceContext(), step_container->StepId());

  TF_RETURN_IF_ERROR(tensorflow::tfd::SetUpKernelFallbackCompatRequestContext(
      &request_context_builder, runner_table, GetEagerContext()));
  if (distributed_manager_ != nullptr) {
    down_cast<DistributedManagerContextInterface*>(distributed_manager_.get())
        ->UpdateRequestContextBuilder(&request_context_builder);
  }
  auto expected_request_context = std::move(request_context_builder).build();
  if (!expected_request_context) {
    return tensorflow::errors::Internal(
        StrCat(expected_request_context.takeError()));
  }
  *request_context = std::move(expected_request_context.get());
  return ::tensorflow::OkStatus();
}

tensorflow::Status ContextInterface::BuildOpRequestContext(
    RCReference<tfrt::RequestContext>* request_context) {
  return BuildFunctionRequestContext(/*runner_table=*/nullptr, request_context);
}

tensorflow::ImmediateExecutionTensorHandle*
ContextInterface::CopyTensorHandleToDevice(
    tensorflow::ImmediateExecutionTensorHandle* handle, const char* device_name,
    tensorflow::Status* status) {
  auto* host_ctx = GetHostContext();

  TensorHandle src_th = tfrt::tf::TensorHandleFromInterface(handle)->Handle();

  auto tfrt_device_name =
      ConvertTfDeviceNameToTfrt(device_name, GetEagerContext());
  if (!tfrt_device_name) {
    *status = tensorflow::errors::InvalidArgument(
        StrCat(tfrt_device_name.takeError()));
    RCReference<AsyncValue> error_av =
        MakeErrorAsyncValueRef(host_ctx, status->error_message());
    return new TensorHandleInterface(
        Value(TensorHandle::CreateError(std::move(error_av))),
        GetTfrtContext());
  }
  auto dst_device_ref = host_ctx->GetDeviceManager()->GetDeviceRef<Device>(
      tfrt_device_name.get());
  if (!dst_device_ref) {
    std::string error_message =
        tfrt::StrCat("Failed to find destination device with name: ",
                     tfrt_device_name.get());
    *status = tensorflow::errors::Internal(error_message);
    RCReference<AsyncValue> error_av =
        MakeErrorAsyncValueRef(host_ctx, error_message);
    return new TensorHandleInterface(
        Value(TensorHandle::CreateError(std::move(error_av))),
        GetTfrtContext());
  }

  RCReference<RequestContext> request_ctx;
  *status = BuildOpRequestContext(&request_ctx);
  if (!status->ok()) return nullptr;

  ExecutionContext exec_ctx{std::move(request_ctx)};

  auto target_th =
      src_th.TransferToInferredType(exec_ctx, std::move(dst_device_ref));

  auto target_av = target_th.GetAsyncTensor();
  if (target_av->IsError()) {
    *status = tensorflow::errors::Internal(
        tfrt::StrCat("Copying to device <", tfrt_device_name.get(),
                     "> failed: ", target_av->GetError().message));
    return nullptr;
  }
  return new TensorHandleInterface(Value(target_th.CopyRef()),
                                   GetTfrtContext());
}

tensorflow::Status ContextInterface::AddFunctionDef(
    const tensorflow::FunctionDef& fdef) {
  return GetEagerContext()->AddFunctionDef(fdef);
}

tensorflow::Status ContextInterface::AddFunctionDefWithStackTraces(
    const tensorflow::FunctionDef& fdef,
    const tensorflow::StackTracesMap& stack_traces) {
  return GetEagerContext()->AddFunctionDefWithStackTraces(fdef, stack_traces);
}

std::vector<std::string> ContextInterface::ListFunctionNames() {
  return GetEagerContext()->ListFunctionNames();
}

tensorflow::Status ContextInterface::RemoveFunction(const std::string& func) {
  // TODO(tfrt-devs): We need to ensure all invocations of this function are
  // finished before removing it.
  function_cache_.RemoveFunction(func);
  return GetEagerContext()->RemoveFunction(func);
}

const tensorflow::FunctionDef* ContextInterface::FindFunctionDef(
    const std::string& name) const {
  return GetEagerContext()->FindFunctionDef(name);
}

const tensorflow::DeviceNameUtils::ParsedName&
ContextInterface::HostCPUParsedName() const {
  return context_.HostCPUParsedName();
}

const std::string& ContextInterface::HostCPUName() const {
  return context_.GetEagerContext()->HostCPUName();
}

tensorflow::CustomDeviceOpHandler&
ContextInterface::GetCustomDeviceOpHandler() {
  return context_.GetEagerContext()->GetCustomDeviceOpHandler();
}

tensorflow::Status ContextInterface::RegisterCustomDevice(
    const std::string& name, std::unique_ptr<tensorflow::CustomDevice> device) {
  return context_.GetEagerContext()->RegisterCustomDevice(name,
                                                          std::move(device));
}

tensorflow::FunctionLibraryDefinition* ContextInterface::FuncLibDef() {
  return context_.GetEagerContext()->FuncLibDef();
}

void ContextInterface::SetReuseRendezvousForFunctions(
    bool reuse_rendezvous_for_functions) {
  // TODO(fishx): This feature doesn't work properly in TFRT yet. Fix it.
  context_.GetEagerContext()->SetReuseRendezvousForFunctions(
      reuse_rendezvous_for_functions);
}

void ContextInterface::ResetGlobalRendezvousForFunction() {
  context_.GetEagerContext()->ResetGlobalRendezvousForFunction();
}

std::vector<std::string> ContextInterface::GetLoggedOpsTestonly() {
  const auto& ret = GetHostContext()
                        ->GetOrCreateSharedContext<tensorflow::tfd::OpLogger>()
                        .GetLoggedOps();
  return std::vector<std::string>(ret.begin(), ret.end());
}

HostContext* ContextInterface::GetHostContext() {
  return GetCoreRuntime()->GetHostContext();
}

tensorflow::EagerContext* ContextInterface::GetEagerContext() {
  return context_.GetEagerContext();
}

const tensorflow::EagerContext* ContextInterface::GetEagerContext() const {
  return context_.GetEagerContext();
}

CoreRuntime* ContextInterface::GetCoreRuntime() {
  return context_.GetCoreRuntime();
}

TfrtContext* ContextInterface::GetTfrtContext() { return &context_; }

OpHandler* ContextInterface::GetFallbackOpHandler() {
  return context_.GetFallbackOpHandler();
}

ResourceContext* ContextInterface::GetResourceContext() {
  return context_.GetResourceContext();
}

tensorflow::Status ContextInterface::SelectOpHandlerFromArguments(
    const tensorflow::ImmediateExecutionOperation& op, OpHandler** op_handler) {
  return op_handler_selector_->SelectFromArguments(op, op_handler);
}

tensorflow::Status ContextInterface::SelectOpHandlerFromNodeDef(
    const tensorflow::ImmediateExecutionOperation& op, const NodeDef* node_def,
    OpHandler** op_handler) {
  return op_handler_selector_->SelectFromNodeDef(op, node_def, op_handler);
}

std::unique_ptr<tensorflow::RunMetadata> ContextInterface::ExportRunMetadata() {
  mutex_lock l(run_metadata_mu_);

  // NOTE(fishx): We need to merge run_metadata from the TF Eager Context
  // because right now we still use the current TF runtime to execute graphs
  // (e.g. tf.data via fallback).
  auto result = GetEagerContext()->ExportRunMetadata();
  result->MergeFrom(*run_metadata_);
  run_metadata_ = std::make_unique<tensorflow::RunMetadata>();

  return result;
}

tensorflow::Status ContextInterface::RunMetadataRecordFunction(
    const std::string& func_name) {
  const tensorflow::FunctionDef* fdef =
      GetEagerContext()->FindFunctionDef(func_name);
  if (fdef == nullptr) {
    return tensorflow::errors::InvalidArgument(
        "Failed to find function \"", func_name, "\" in function library");
  }
  std::unique_ptr<tensorflow::FunctionBody> fbody;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *fdef, tensorflow::AttrSlice(), GetEagerContext()->FuncLibDef(), &fbody));
  tensorflow::GraphDef def;
  fbody->graph->ToGraphDef(&def);
  *def.mutable_library() =
      GetEagerContext()->FuncLibDef()->ReachableDefinitions(def).ToProto();

  mutex_lock l(run_metadata_mu_);
  auto* function_graphs = run_metadata_->add_function_graphs();
  *function_graphs->mutable_pre_optimization_graph() = def;
  // TODO(b/171600738): Figure out a way to record the right post-optimization
  // graph and partition graph.
  *function_graphs->mutable_post_optimization_graph() = def;
  *function_graphs->add_partition_graphs() = def;
  *run_metadata_->add_partition_graphs() = def;
  return ::tensorflow::OkStatus();
}

void ContextInterface::SetExecutorForThread(
    tensorflow::EagerExecutor* executor) {
  GetEagerContext()->SetExecutorForThread(executor);
}

tfrt::Location AbortLocationHandler::GetCurrentLocation() {
  return tfrt::Location(this, GetNextLocationId());
}

void OpAttrsInterface::GetNameAttrList(
    tensorflow::NameAttrList* name_and_attrs) const {
  fallback_attrs_->FillAttrValueMap(name_and_attrs->mutable_attr());
  name_and_attrs->set_name(fallback_attrs_->op_name());
}

Status OpAttrsInterface::GetTypeList(
    absl::string_view attr_name,
    absl::InlinedVector<tensorflow::DataType, 4>* type_list) const {
  return tensorflow::errors::Unimplemented("OpAttrsInterface::GetTypeList");
}

bool OpAttrsInterface::GetInt(absl::string_view attr_name,
                              int64_t* result) const {
  return attrs_->Get<int64_t>({attr_name.data(), attr_name.size()}, result);
}

bool OpAttrsInterface::GetFloat(absl::string_view attr_name,
                                float* result) const {
  return attrs_->Get<float>({attr_name.data(), attr_name.size()}, result);
}

bool OpAttrsInterface::GetBool(absl::string_view attr_name,
                               bool* result) const {
  return attrs_->Get<bool>({attr_name.data(), attr_name.size()}, result);
}

bool OpAttrsInterface::GetType(absl::string_view attr_name,
                               tensorflow::DataType* result) const {
  auto optional_type =
      attrs_->GetOptional<OpAttrType>({attr_name.data(), attr_name.size()});
  if (!optional_type.hasValue()) return false;
  *result = tensorflow::tfd::ConvertToTfDataType(optional_type.getValue());
  return true;
}

OperationInterface::OperationInterface(ContextInterface* context)
    : ImmediateExecutionOperation(kTfrt),
      op_attrs_(&attrs_, &fallback_attrs_),
      context_(context) {}

tensorflow::Status OperationInterface::Reset(const char* op,
                                             const char* raw_device_name) {
  op_name_ = op;
  args_.clear();
  attrs_.Reset();
  custom_device_tensor_handle_count_ = 0;
  op_def_ = nullptr;
  fallback_attrs_.Reset(op);
  stack_trace_.reset();
  op_ = nullptr;
  function_state_.reset();
1296   tensorflow::Status s = tensorflow::OpDefForOp(op_name_, &op_def_);
1297   is_function_ = !s.ok();
1298   return SetDeviceName(raw_device_name);
1299 }
1300 
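// Runs the initialized op or compiled function through TFRT. In synchronous
// mode this blocks until the side-effect chain and every result tensor are
// available, converting any recorded error into the returned
// tensorflow::Status; in async mode errors are surfaced later (see the note
// in the result loop below).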
tensorflow::Status OperationInterface::Execute(
    absl::Span<tensorflow::AbstractTensorHandle*> retvals, int* num_retvals) {
  tensorflow::profiler::TraceMe trace(
      [&] {
        return absl::StrCat("TFRT_Execute:", Name(), " device:", DeviceName());
      },
      tensorflow::profiler::TraceMeLevel::kInfo);
  if (custom_device_tensor_handle_count_ > 0) {
    return tensorflow::errors::InvalidArgument(
        "Cannot execute ops that contain unsupported args in TFRT.");
  }

  TF_RETURN_IF_ERROR(Initialize());
  assert(op_ != nullptr || function_state_);
  auto* corert = context_->GetCoreRuntime();
  auto* chain = context_->GetChain();
  auto* host = corert->GetHostContext();
  llvm::SmallVector<TensorHandle, 8> th_args;
  th_args.reserve(args_.size());

  llvm::SmallVector<TensorHandle, 8> result_ths;
  result_ths.resize(*num_retvals);

  if (function_state_) {
    // Set up arguments. Check argument dtypes synchronously if available.
    auto arg_types = function_state_->GetArgTypes();
    if (args_.size() != arg_types.size()) {
      return tensorflow::errors::InvalidArgument("Expects ", arg_types.size(),
                                                 " arguments, but ",
                                                 args_.size(), " are provided");
    }
    auto args_size = args_.size();
    for (size_t i = 0; i < args_size; ++i) {
      th_args.push_back(down_cast<TensorHandleInterface*>(args_[i].get())
                            ->Handle()
                            .CopyRef());
      // TODO(b/173556766): This dtype check is only needed for corert lowering.
      // In native lowering, the compiler should obtain the argument dtype
      // information from FunctionBody directly and lower the op to the native
      // kernel that accepts the specified dtype.
      if (th_args[i].IsMetadataAvailable()) {
        auto arg_dtype = th_args[i].GetAvailableMetadata().dtype;
        if (arg_dtype != arg_types[i]) {
          return tensorflow::errors::InvalidArgument(
              "Expects arg[", i, "] to be ", arg_types[i], " but ", arg_dtype,
              " is provided");
        }
      }
    }

    RCReference<RequestContext> request_ctx;
    TF_RETURN_IF_ERROR(context_->BuildFunctionRequestContext(
        function_state_->GetRunnerTable(), &request_ctx));

    ExecutionContext exec_ctx{std::move(request_ctx),
                              abort_location_handler_.GetCurrentLocation()};

    // Make the BEF executor use TfThreadPoolWorkQueue to dispatch kernels.
    exec_ctx.set_work_queue(
        context_->GetTfrtContext()->GetTfThreadPoolWorkQueue());

    // Execute the function.
    function_state_->GetFunc()(exec_ctx, th_args, OpAttrsRef(attrs_),
                               result_ths, chain);
  } else {
    RCReference<RequestContext> request_ctx;
    TF_RETURN_IF_ERROR(context_->BuildOpRequestContext(&request_ctx));

    ExecutionContext exec_ctx{std::move(request_ctx),
                              abort_location_handler_.GetCurrentLocation()};
    for (auto& arg : args_) {
      th_args.push_back(
          down_cast<TensorHandleInterface*>(arg.get())->Handle().CopyRef());
    }
    // If the CoreRuntimeOp is a native TFRT op, transfer arguments to the
    // target device if necessary.
    if (!op_->IsFallback()) {
      // Get the target device of the arguments that we want to implicitly copy
      // to.
      auto dst_device_ref = op_->GetDeviceRef();

      for (auto& th_arg : th_args) {
        th_arg =
            th_arg.TransferTo(exec_ctx, dst_device_ref, op_->GetTensorType());
      }
    }

    (*op_)(exec_ctx, th_args, OpAttrsRef(attrs_), result_ths, chain);
  }

  tensorflow::Status s = ::tensorflow::OkStatus();

  if (TF_PREDICT_FALSE(!this->context_->IsAsync() && !chain->IsAvailable()))
    host->Await({chain->CopyRCRef()});

  if (TF_PREDICT_FALSE(chain->IsError())) {
    s = CreateTfErrorStatus(chain->GetError());
    // TODO(tfrt-devs): Assess whether we need an explicit API to clear errors.
    *chain = GetReadyChain();
  }

  for (size_t i = 0, e = result_ths.size(); i != e; ++i) {
    auto& th_ref = result_ths[i];
    if (TF_PREDICT_FALSE(!this->context_->IsAsync() &&
                         !th_ref.GetAsyncTensor()->IsAvailable()))
      host->Await(FormRef(th_ref.GetAsyncTensor()));

    // NOTE(fishx): In async mode, we won't report errors synchronously even
    // though it is possible in TFRT. This is intended to match the behavior
    // of current TF. However, in the future, we may want to update this
    // behavior, since synchronous errors may improve the user experience in
    // async mode.
    if (TF_PREDICT_FALSE(!this->context_->IsAsync() &&
                         th_ref.GetAsyncTensor()->IsError() && s.ok()))
      s = CreateTfErrorStatus(th_ref.GetAsyncTensor()->GetError());

    if (function_state_ && context_->IsAsync()) {
      retvals[i] = new TensorHandleInterface(function_state_->GetRetTypes()[i],
                                             Value(std::move(result_ths[i])),
                                             context_->GetTfrtContext());
    } else {
      retvals[i] = new TensorHandleInterface(Value(std::move(result_ths[i])),
                                             context_->GetTfrtContext());
    }
  }

  return s;
}

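// Resolves what Execute() will run. Primitive ops are fetched from the op
// cache (keyed by op name, op handler, device, and dtype attrs). Functions
// marked with kXlaMustCompileAttr are routed to the fallback XLA op path, and
// all remaining functions are compiled from FunctionDef to BEF and memoized
// in the function cache.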
tensorflow::Status OperationInterface::Initialize() {
  CoreRuntime* corert = context_->GetCoreRuntime();
  if (!is_function_) {
    // Obtain input arguments' dtype attrs as part of the cache key.
    llvm::SmallVector<string_view, 4> dtypes;
    attrs_.IterateEntries([&](const OpAttrsRawEntry& entry) {
      if (entry.type == OpAttrType::DTYPE && !entry.IsArray())
        dtypes.push_back(
            GetNameString(*static_cast<const OpAttrType*>(entry.GetData())));
    });

    OpHandler* op_handler = nullptr;
    TF_RETURN_IF_ERROR(
        context_->SelectOpHandlerFromArguments(*this, &op_handler));
    Expected<CoreRuntimeOp*> expected_op = context_->GetOpCache().GetOrAddOp(
        op_name_, op_handler, device_name_, dtypes, this);
    if (!expected_op) {
      return tensorflow::errors::InvalidArgument(
          StrCat("Cannot obtain CoreRuntimeOp: ", op_name_,
                 " on device: ", device_name_, expected_op.takeError()));
    }
    op_ = expected_op.get();
    // Update the device name, since the op handler selector may choose an
    // op_handler that's different from what the user specified.
    device_name_ = op_->DeviceName().str();
    return ::tensorflow::OkStatus();
  }

  bool compile_with_xla = false;
  GetFuncAttr(attrs_, op_name_, *context_->GetEagerContext()->FuncLibDef(),
              tensorflow::kXlaMustCompileAttr, &compile_with_xla);
  // If the function has compile_with_xla==true, we will use RuntimeFallback
  // to execute it, since TFRT does not support XLA yet.
  // TODO(tfrt-devs): Native support for compile_with_xla.
  if (compile_with_xla) {
    Expected<CoreRuntimeOp*> expected_op =
        context_->GetOpCache().GetOrAddXlaOp(op_name_, context_);
    if (!expected_op) {
      return tensorflow::errors::NotFound(
          StrCat("Cannot initialize xla function ", op_name_,
                 " on fallback op handler.", expected_op.takeError()));
    }
    op_ = expected_op.get();
    return ::tensorflow::OkStatus();
  }

  // Note(fishx): We need the eager context for now because we need
  // FunctionLibraryDefinition to convert FunctionDef to MLIR TF dialect. In
  // the future, when we can generate MLIR from TF Python, we should get rid
  // of this.
  // FunctionDef -> BEF.
  // Look up the cache; compile to BEF and insert into the cache on a miss.
  tensorflow::DeviceSet dev_set;
  const DeviceMgr* device_mgr = context_->GetEagerContext()->local_device_mgr();
  if (device_mgr == nullptr)
    return tensorflow::errors::NotFound("Cannot find device manager");
  // TODO(tfrt-devs): Support remote devices in TFRT.
  for (auto d : device_mgr->ListDevices()) dev_set.AddDevice(d);
  if (context_->GetDistributedManager() != nullptr &&
      context_->UseTfrtDistributedRuntime()) {
    down_cast<DistributedManagerContextInterface*>(
        context_->GetDistributedManager())
        ->PopulateRemoteDevices(&dev_set);
  }
  FunctionCache::FunctionCacheResult result;

  tensorflow::TfrtFunctionCompileOptions compile_options;

  // Use the host device if the user does not place the function on a specific
  // device.
  compile_options.default_device =
      device_name_.empty() ? context_->GetEagerContext()->HostCPUName()
                           : device_name_;

  // TODO(b/172659131): Do not use TFRT native ops for TF integration for now.
  // Re-enable once we have a concrete plan to implement feature-complete
  // TFRT native ops (kernels).
  compile_options.enable_native_ops = false;

  if (fallback_attrs_.NumAttributes() > 0) {
    const auto& ndef = NodeDef();
    // TODO(tfrt-devs): If we are to create more attributes, consider packing
    // them into a proto.
    {
      const auto& it = ndef.attr().find(kEnableNativeOpsAttr);
      if (it != ndef.attr().end()) {
        compile_options.enable_native_ops = it->second.b();
      }
    }

    {
      const auto& it = ndef.attr().find(kEnableGrapplerAttr);
      if (it != ndef.attr().end()) {
        compile_options.enable_grappler = it->second.b();
      }
    }
  }

  llvm::SmallVector<const tfrt::Device*, 4> input_devices;
  input_devices.reserve(args_.size());
  for (auto& arg : args_) {
    auto arg_th = down_cast<TensorHandleInterface*>(arg.get())->Handle();
    if (!arg_th.IsDeviceAvailable()) {
      corert->GetHostContext()->Await(arg_th.GetAsyncDevice().CopyRCRef());
    }
    input_devices.push_back(down_cast<TensorHandleInterface*>(arg.get())
                                ->Handle()
                                .GetAvailableDevice()
                                .get());
  }
  TF_RETURN_IF_ERROR(context_->GetFunctionCache().GetOrAddFunction(
      op_name_, device_name_, dev_set, context_->GetEagerContext(), corert,
      /*request_ctx_fn=*/
      [this](tensorflow::tfrt_stub::OpKernelRunnerTable* runner_table,
             RCReference<RequestContext>* request_ctx) {
        return context_->BuildFunctionRequestContext(runner_table, request_ctx);
      },
      abort_location_handler_.GetCurrentLocation(), compile_options,
      input_devices, &result));
  // TODO(tfrt-devs): Avoid calling EagerContext::ShouldStoreGraphs().
  if (result.is_cache_miss &&
      context_->GetEagerContext()->ShouldStoreGraphs()) {
    TF_RETURN_IF_ERROR(context_->RunMetadataRecordFunction(op_name_));
  }
  function_state_ = std::move(result.function_state);
  return ::tensorflow::OkStatus();
}

tensorflow::Status OperationInterface::SetDeviceName(const char* name) {
  // Treat a null name as empty so the comparison below is well-defined.
  const std::string new_name = name ? name : "";
  if (op_ && new_name != device_name_) {
    return tensorflow::errors::Internal(
        "Failed to update device name. Right now TFRT cannot update the "
        "device name of a fallback op once it is initialized.");
  }
  device_name_ = new_name;
  return ::tensorflow::OkStatus();
}

tensorflow::Status OperationInterface::AddInput(
    tensorflow::AbstractTensorHandle* input) {
  tensorflow::ImmediateExecutionTensorHandle* h =
      down_cast<tensorflow::ImmediateExecutionTensorHandle*>(input);
  // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
  if (tensorflow::CustomDeviceTensorHandle::classof(h)) {
    custom_device_tensor_handle_count_++;
  }
  h->Ref();
  args_.push_back(
      tensorflow::core::RefCountPtr<tensorflow::ImmediateExecutionTensorHandle>(
          h));
  return ::tensorflow::OkStatus();
}

tensorflow::Status OperationInterface::SetInput(
    size_t index, tensorflow::ImmediateExecutionTensorHandle* input) {
  if (index >= args_.size()) {
    return tensorflow::errors::InvalidArgument("Index >= inputs.size: ", index,
                                               " >= ", args_.size());
  }
  // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
  if (tensorflow::CustomDeviceTensorHandle::classof(args_[index].get())) {
    custom_device_tensor_handle_count_--;
  }
  if (tensorflow::CustomDeviceTensorHandle::classof(input)) {
    custom_device_tensor_handle_count_++;
  }
  input->Ref();
  args_[index] =
      tensorflow::core::RefCountPtr<tensorflow::ImmediateExecutionTensorHandle>(
          input);
  return ::tensorflow::OkStatus();
}

tensorflow::Status OperationInterface::AddInputList(
    absl::Span<tensorflow::AbstractTensorHandle* const> inputs) {
  return tensorflow::errors::Unimplemented(
      "Unimplemented OperationInterface::AddInputList");
}

absl::Span<tensorflow::ImmediateExecutionTensorHandle* const>
OperationInterface::GetInputs() const {
  return absl::MakeSpan(
      reinterpret_cast<tensorflow::ImmediateExecutionTensorHandle* const*>(
          args_.data()),
      args_.size());
}

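// Each SetAttr* method below records the attribute twice: once in
// fallback_attrs_ (a tensorflow::AttrBuilder) for ops that run through the
// runtime fallback, and once in attrs_ (tfrt::OpAttrs) for native TFRT
// execution.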
tensorflow::Status OperationInterface::SetAttrString(const char* attr_name,
                                                     const char* data,
                                                     size_t length) {
  fallback_attrs_.Set(attr_name, tensorflow::StringPiece(data, length));
  if (attrs_.SetString(attr_name, string_view(data, length)))
    return ::tensorflow::OkStatus();
  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrString failed");
}

tensorflow::Status OperationInterface::SetAttrInt(const char* attr_name,
                                                  int64_t value) {
  fallback_attrs_.Set(attr_name, static_cast<int64_t>(value));
  if (attrs_.Set(attr_name, value)) return ::tensorflow::OkStatus();
  return tensorflow::errors::Internal("OperationInterface::SetAttrInt failed");
}

tensorflow::Status OperationInterface::SetAttrFloat(const char* attr_name,
                                                    float value) {
  fallback_attrs_.Set(attr_name, value);
  if (attrs_.Set(attr_name, value)) return ::tensorflow::OkStatus();
  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrFloat failed");
}

tensorflow::Status OperationInterface::SetAttrBool(const char* attr_name,
                                                   bool value) {
  fallback_attrs_.Set(attr_name, value);
  if (attrs_.Set(attr_name, value)) return ::tensorflow::OkStatus();
  return tensorflow::errors::Internal("OperationInterface::SetAttrBool failed");
}

tensorflow::Status OperationInterface::SetAttrType(const char* attr_name,
                                                   tensorflow::DataType value) {
  fallback_attrs_.Set(attr_name, value);
  if (value == tensorflow::DT_INVALID) {
    return tensorflow::errors::InvalidArgument(
        "OperationInterface::SetAttrType failed to set DT_INVALID");
  }
  if (attrs_.Set(attr_name,
                 tfrt::GetOpAttrTypeFromDType(
                     tensorflow::tfd::ConvertTfDataTypeToBefAttrType(value))))
    return ::tensorflow::OkStatus();
  // TODO(fishx): Remove this workaround once we support all dtypes in TF.
  // This is fine for now, since the attributes "T", "U", and "Tidx" are not
  // used by TFRT native ops.
  if (std::strcmp(attr_name, "T") == 0 || std::strcmp(attr_name, "U") == 0 ||
      std::strcmp(attr_name, "Tidx") == 0) {
    return ::tensorflow::OkStatus();
  }
  return tensorflow::errors::Internal("OperationInterface::SetAttrType failed");
}

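// Shapes are likewise recorded twice: as a TensorShapeProto for the fallback
// attrs and as a BEF ShapeAttr for TFRT. A negative num_dims denotes an
// unknown rank; for example, SetAttrShape("shape", nullptr, -1) stores an
// unranked shape, while SetAttrShape("shape", dims, 2) stores a ranked
// two-dimensional shape (dims being a hypothetical int64_t[2]).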
tensorflow::Status OperationInterface::SetAttrShape(const char* attr_name,
                                                    const int64_t* dims,
                                                    const int num_dims) {
  // NOTE: This is copied from EagerOperation::SetAttrShape.
  // TODO(b/154554118): Remove the duplication.
  if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
    return tensorflow::errors::InvalidArgument(
        "Value specified for `", attr_name, "` has ", num_dims,
        " dimensions which is over the limit of ",
        tensorflow::TensorShape::MaxDimensions(), ".");
  }

  tensorflow::TensorShapeProto proto;
  size_t offset;
  if (num_dims < 0) {
    proto.set_unknown_rank(true);

    // Set an unranked ShapeAttr.
    offset = bef_attr_encoder_.EncodeUnrankedShapeAttr();
  } else {
    for (int d = 0; d < num_dims; ++d) {
      proto.add_dim()->set_size(dims[d]);
    }

    // Set a RankedShapeAttr.
    offset = bef_attr_encoder_.EncodeRankedShapeAttr(
        llvm::makeArrayRef(dims, num_dims));
  }
  fallback_attrs_.Set(attr_name, proto);

  auto buf = bef_attr_encoder_.TakeResult();
  tfrt::ShapeAttr shape_attr(buf.data() + offset);
  // TODO(tfrt-devs): Avoid the copy.
  if (attrs_.Set(attr_name, shape_attr)) return ::tensorflow::OkStatus();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrShape failed");
}

tensorflow::Status OperationInterface::SetAttrFunction(
    const char* attr_name, const tensorflow::AbstractOperation* value) {
  auto* value_operation = down_cast<const OperationInterface*>(value);
  // TODO(b/165412867): Set fallback_attrs_ for eager device placement.
  // Consider removing this and relying on TFRT OpAttrs.
  tensorflow::AttrValue attr_value;
  tensorflow::NameAttrList* func = attr_value.mutable_func();
  func->set_name(value->Name());
  fallback_attrs_.Set(attr_name, attr_value);

  if (attrs_.SetFunc(attr_name, {string_view(value_operation->Name())}))
    return ::tensorflow::OkStatus();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrFunction failed");
}

tensorflow::Status OperationInterface::SetAttrFunctionName(
    const char* attr_name, const char* data, size_t length) {
  // TODO(b/165412867): Set fallback_attrs_ for eager device placement.
  // Consider removing this and relying on TFRT OpAttrs.
  tensorflow::AttrValue attr_value;
  tensorflow::NameAttrList* func = attr_value.mutable_func();
  func->set_name(data);
  fallback_attrs_.Set(attr_name, attr_value);

  if (attrs_.SetFunc(attr_name, {data})) return ::tensorflow::OkStatus();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrFunctionName failed");
}

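// Serializes a TFE tensor into a BEF DenseAttr (dtype, shape, and raw bytes)
// using the given encoder, and returns the offset of the encoded attribute
// within the encoder's result buffer.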
static size_t SerializeTFETensorToDenseAttr(
    tensorflow::AbstractTensorInterface* tensor,
    tfrt::BefAttrEncoder* encoder) {
  const auto element_type =
      tensorflow::tfd::ConvertTfDataTypeToBefAttrType(tensor->Type());
  llvm::SmallVector<int64_t, 4> shape;
  for (int i = 0; i < tensor->NumDims(); ++i) {
    shape.push_back(tensor->Dim(i));
  }
  auto elements = llvm::makeArrayRef(
      reinterpret_cast<const uint8_t*>(tensor->Data()), tensor->ByteSize());
  return encoder->EncodeDenseAttr(static_cast<DType>(element_type), shape,
                                  elements);
}

tensorflow::Status OperationInterface::SetAttrTensor(
    const char* attr_name, tensorflow::AbstractTensorInterface* tensor) {
  tfrt::BefAttrEncoder encoder;
  const size_t offset = SerializeTFETensorToDenseAttr(tensor, &encoder);
  auto buffer = encoder.TakeResult();
  DenseAttr dense_attr(buffer.data() + offset);
  if (attrs_.Set(attr_name, dense_attr)) return ::tensorflow::OkStatus();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrTensor failed");
}

tensorflow::Status OperationInterface::SetAttrStringList(
    const char* attr_name, const void* const* values, const size_t* lengths,
    int num_values) {
  std::vector<tensorflow::StringPiece> v(num_values);
  for (int i = 0; i < num_values; ++i) {
    v[i] = tensorflow::StringPiece(static_cast<const char*>(values[i]),
                                   lengths[i]);
  }
  fallback_attrs_.Set(attr_name, v);

  tfrt::BefAttrEncoder encoder;
  const size_t offset =
      encoder.EncodeStringListAttr(values, lengths, num_values);
  auto buf = encoder.TakeResult();
  tfrt::AggregateAttr aggr_attr(buf.data() + offset);
  // TODO(tfrt-devs): Avoid the copy.
  if (attrs_.Set(attr_name, aggr_attr)) return ::tensorflow::OkStatus();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrStringList failed");
}

tensorflow::Status OperationInterface::SetAttrFloatList(const char* attr_name,
                                                        const float* values,
                                                        int num_values) {
  fallback_attrs_.Set(
      attr_name, tensorflow::gtl::ArraySlice<const float>(values, num_values));

  if (attrs_.SetArray(attr_name, tfrt::ArrayRef<float>(values, num_values)))
    return ::tensorflow::OkStatus();
  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrFloatList failed");
}

tensorflow::Status OperationInterface::SetAttrIntList(const char* attr_name,
                                                      const int64_t* values,
                                                      int num_values) {
  fallback_attrs_.Set(attr_name, tensorflow::gtl::ArraySlice<const int64_t>(
                                     values, num_values));

  if (attrs_.SetArray(attr_name, tfrt::ArrayRef<int64_t>(values, num_values)))
    return ::tensorflow::OkStatus();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrIntList failed");
}

tensorflow::Status OperationInterface::SetAttrTypeList(
    const char* attr_name, const tensorflow::DataType* values, int num_values) {
  fallback_attrs_.Set(attr_name,
                      tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
                          values, num_values));
  // Convert to tfrt::DType first.
  llvm::SmallVector<tfrt::DType, 4> tfrt_dtypes;
  tfrt_dtypes.reserve(num_values);
  for (int i = 0; i < num_values; ++i) {
    tfrt_dtypes.push_back(
        tensorflow::tfd::ConvertTfDataTypeToBefAttrType(values[i]));
  }

  if (attrs_.SetRaw(attr_name, tfrt_dtypes.data(), tfrt::OpAttrType::DTYPE,
                    num_values, OpAttrsRawEntryType::kArray))
    return ::tensorflow::OkStatus();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrTypeList failed");
}

tensorflow::Status OperationInterface::SetAttrBoolList(
    const char* attr_name, const unsigned char* values, int num_values) {
  std::unique_ptr<bool[]> b(new bool[num_values]);
  for (int i = 0; i < num_values; ++i) {
    b[i] = values[i];
  }
  fallback_attrs_.Set(
      attr_name, tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));

  // Convert to bool first.
  llvm::SmallVector<bool, 4> bool_array;
  bool_array.reserve(num_values);
  for (int i = 0; i < num_values; ++i) {
    bool_array.push_back(static_cast<bool>(values[i]));
  }
  if (attrs_.SetArray(attr_name,
                      tfrt::ArrayRef<bool>(bool_array.data(), num_values)))
    return ::tensorflow::OkStatus();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrBoolList failed");
}

tensorflow::Status OperationInterface::SetAttrShapeList(const char* attr_name,
                                                        const int64_t** dims,
                                                        const int* num_dims,
                                                        int num_values) {
  std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
      new tensorflow::TensorShapeProto[num_values]);
  for (int i = 0; i < num_values; ++i) {
    const auto num_dims_i = num_dims[i];

    if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
      return tensorflow::errors::InvalidArgument(
          StrCat("Value specified for `", attr_name, "` has ", num_dims_i,
                 " dimensions which is over the limit of ",
                 tensorflow::TensorShape::MaxDimensions(), "."));
    }
    if (num_dims_i < 0) {
      proto[i].set_unknown_rank(true);
    } else {
      const int64_t* dims_i = dims[i];
      auto proto_i = &proto[i];
      for (int d = 0; d < num_dims_i; ++d) {
        proto_i->add_dim()->set_size(dims_i[d]);
      }
    }
  }
  fallback_attrs_.Set(attr_name,
                      tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
                          proto.get(), num_values));

  BefAttrEncoder encoder;
  const size_t offset = encoder.EncodeShapeListAttr(dims, num_dims, num_values);
  auto buf = encoder.TakeResult();
  tfrt::AggregateAttr aggr_attr(buf.data() + offset);
  if (attrs_.Set(attr_name, aggr_attr)) return ::tensorflow::OkStatus();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrShapeList failed");
}

tensorflow::Status OperationInterface::SetAttrFunctionList(
    const char* attr_name, absl::Span<const AbstractOperation*> values) {
  size_t num_values = values.size();
  std::vector<const void*> func_attrs(num_values);
  std::vector<size_t> lengths(num_values);

  for (size_t i = 0; i < num_values; ++i) {
    auto* value_operation = down_cast<const OperationInterface*>(values[i]);
    lengths[i] = value_operation->Name().length();
    func_attrs[i] = value_operation->Name().c_str();
  }

  // Encode the array of function attributes with the BEF typed attribute
  // encoder into an aggregate attribute.
  BefAttrEncoder encoder;
  const size_t offset =
      encoder.EncodeFuncListAttr(func_attrs.data(), lengths.data(), num_values);
  auto buf = encoder.TakeResult();
  tfrt::AggregateAttr aggr_attr(buf.data() + offset);
  if (attrs_.Set(attr_name, aggr_attr)) return ::tensorflow::OkStatus();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrFunctionList failed");
}

tensorflow::Status OperationInterface::InputLength(const char* input_name,
                                                   int* length) {
  return tensorflow::errors::Unimplemented(
      "Unimplemented OperationInterface::InputLength");
}

tensorflow::Status OperationInterface::OutputLength(const char* output_name,
                                                    int* length) {
  return tensorflow::errors::Unimplemented(
      "Unimplemented OperationInterface::OutputLength");
}

const tensorflow::AbstractOpAttrs* OperationInterface::GetOpAttrs() const {
  return &op_attrs_;
}

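// Merges attributes from another OpAttrsInterface into this operation,
// copying both the raw TFRT OpAttrs entries and the fallback attributes.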
void OperationInterface::AddAttrs(const tensorflow::AbstractOpAttrs* op_attrs) {
  auto* tfrt_op_attrs = down_cast<const OpAttrsInterface*>(op_attrs);
  tfrt_op_attrs->GetAttrs()->IterateEntries(
      [this](const OpAttrsRawEntry& entry) {
        attrs_.SetRaw(entry.name, entry.GetData(), entry.type,
                      entry.element_count, entry.entry_type);
      });
  fallback_attrs_.CopyAttributes(*tfrt_op_attrs->GetFallbackAttrs());
}

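// Best-effort inference of dtype attributes (e.g. "T") from inputs that have
// already been added, for clients that do not set them explicitly. Inference
// is skipped for ops that take input lists, for the reason explained below.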
void OperationInterface::MaybeInferInputAttrs() {
  if (!op_def_) return;
  for (size_t i = 0; i < args_.size(); i++) {
    auto* handle = args_[i].get();
    const auto& input_def = op_def_->input_arg(i);
    if (!input_def.number_attr().empty() ||
        !input_def.type_list_attr().empty()) {
      // Some clients that still set their input attributes manually add an
      // input list to their op by calling `TFE_OpAddInput` for each of its
      // elements instead of calling `TFE_OpAddInputList`. When this happens,
      // we cannot detect the end of such a list, and thus lose track of the
      // input arguments in the op definition. To guarantee backward
      // compatibility with those clients, disable automatic inference in this
      // case.
      return;
    }
    const std::string& type_attr = input_def.type_attr();
    if (!type_attr.empty()) {
      bool success = attrs_.Set(
          type_attr, tfrt::GetOpAttrTypeFromDType(
                         tensorflow::tfd::ConvertTfDataTypeToBefAttrType(
                             handle->DataType())));
      if (success) {
        fallback_attrs_.Set(type_attr, handle->DataType());
      }
    }
  }
}

}  // namespace tf
}  // namespace tfrt