/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"

#include <atomic>
#include <cstddef>
#include <memory>
#include <string>
#include <thread>
#include <vector>

#include "tensorflow/c/eager/abstract_function.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/mlir/tfrt/function/function.h"
#include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h"
#include "tensorflow/core/runtime_fallback/runtime/op_logger.h"
#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_op_handler.h"
#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.h"
#include "tensorflow/core/runtime_fallback/util/attr_util.h"
#include "tensorflow/core/runtime_fallback/util/tensor_util.h"
#include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed_interface.h"
#include "tensorflow/core/tfrt/eager/core_runtime/op_handler_registry.h"
#include "tensorflow/core/tfrt/eager/core_runtime/op_handler_selector.h"
#include "tensorflow/core/tfrt/eager/virtual_device.h"
#include "tensorflow/core/tfrt/utils/error_util.h"
#include "tensorflow/core/tfrt/utils/utils.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tfrt/common/compat/eigen/eigen_dtype.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime_op.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attr_type.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_handler.h"  // from @tf_runtime
#include "tfrt/core_runtime/tensor_handle.h"  // from @tf_runtime
#include "tfrt/dtype/dtype.h"  // from @tf_runtime
#include "tfrt/host_context/async_value.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/attribute_utils.h"  // from @tf_runtime
#include "tfrt/host_context/chain.h"  // from @tf_runtime
#include "tfrt/host_context/concurrent_work_queue.h"  // from @tf_runtime
#include "tfrt/host_context/device.h"  // from @tf_runtime
#include "tfrt/host_context/diagnostic.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/function.h"  // from @tf_runtime
#include "tfrt/host_context/host_allocator.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/host_context/location.h"  // from @tf_runtime
#include "tfrt/host_context/resource_context.h"  // from @tf_runtime
#include "tfrt/metrics/common_metrics.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime
#include "tfrt/support/string_util.h"  // from @tf_runtime
#include "tfrt/tensor/conversion_registry.h"  // from @tf_runtime
#include "tfrt/tensor/dense_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/dense_host_tensor_view.h"  // from @tf_runtime
#include "tfrt/tensor/scalar_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/string_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_metadata.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_serialize_utils.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_type_registration.h"  // from @tf_runtime

namespace tfrt {
namespace tf {

namespace {

using tensorflow::down_cast;

constexpr char kGpuDeviceName[] = "GPU";
constexpr char kEnableNativeOpsAttr[] = "TFRT_TEST_enable_native_ops";
constexpr char kEnableGrapplerAttr[] = "TFRT_TEST_enable_grappler";

TensorMetadata CreateMetadata(DType dtype,
                              absl::Span<const ssize_t> dim_sizes) {
  return TensorMetadata(DType(dtype),
                        TensorShape(llvm::ArrayRef<ssize_t>(
                            reinterpret_cast<const ssize_t*>(dim_sizes.data()),
                            dim_sizes.size())));
}

tensorflow::DataType ConvertDType(DType kind) {
  switch (kind) {
    case DType::UI8:
      return tensorflow::DT_UINT8;
    case DType::UI16:
      return tensorflow::DT_UINT16;
    case DType::UI32:
      return tensorflow::DT_UINT32;
    case DType::UI64:
      return tensorflow::DT_UINT64;
    case DType::I8:
      return tensorflow::DT_INT8;
    case DType::I16:
      return tensorflow::DT_INT16;
    case DType::I32:
      return tensorflow::DT_INT32;
    case DType::I64:
      return tensorflow::DT_INT64;
    case DType::BF16:
      return tensorflow::DT_BFLOAT16;
    case DType::F16:
      return tensorflow::DT_HALF;
    case DType::F32:
      return tensorflow::DT_FLOAT;
    case DType::F64:
      return tensorflow::DT_DOUBLE;
    case DType::I1:
      return tensorflow::DT_BOOL;
    case DType::Complex64:
      return tensorflow::DT_COMPLEX64;
    case DType::Complex128:
      return tensorflow::DT_COMPLEX128;
    case DType::String:
      return tensorflow::DT_STRING;
    case DType::Resource:
      return tensorflow::DT_RESOURCE;
    case DType::Variant:
      return tensorflow::DT_VARIANT;
    case DType::QUI8:
      return tensorflow::DT_QUINT8;
    case DType::QUI16:
      return tensorflow::DT_QUINT16;
    case DType::QI8:
      return tensorflow::DT_QINT8;
    case DType::QI16:
      return tensorflow::DT_QINT16;
    case DType::QI32:
      return tensorflow::DT_QINT32;
    default:
      LOG(ERROR) << "Unsupported kind " << kind;
      return tensorflow::DT_INVALID;
  }
}

DType ConvertDType(tensorflow::DataType dtype) {
  switch (dtype) {
    case tensorflow::DT_UINT8:
      return static_cast<DType>(DType::UI8);
    case tensorflow::DT_UINT16:
      return static_cast<DType>(DType::UI16);
    case tensorflow::DT_UINT32:
      return static_cast<DType>(DType::UI32);
    case tensorflow::DT_UINT64:
      return static_cast<DType>(DType::UI64);
    case tensorflow::DT_INT8:
      return static_cast<DType>(DType::I8);
    case tensorflow::DT_INT16:
      return static_cast<DType>(DType::I16);
    case tensorflow::DT_INT32:
      return static_cast<DType>(DType::I32);
    case tensorflow::DT_INT64:
      return static_cast<DType>(DType::I64);
    case tensorflow::DT_BFLOAT16:
      return static_cast<DType>(DType::BF16);
    case tensorflow::DT_HALF:
      return static_cast<DType>(DType::F16);
    case tensorflow::DT_FLOAT:
      return static_cast<DType>(DType::F32);
    case tensorflow::DT_DOUBLE:
      return static_cast<DType>(DType::F64);
    case tensorflow::DT_BOOL:
      return static_cast<DType>(DType::I1);
    case tensorflow::DT_STRING:
      return static_cast<DType>(DType::String);
    case tensorflow::DT_COMPLEX64:
      return static_cast<DType>(DType::Complex64);
    case tensorflow::DT_COMPLEX128:
      return static_cast<DType>(DType::Complex128);
    case tensorflow::DT_RESOURCE:
      return static_cast<DType>(DType::Resource);
    case tensorflow::DT_VARIANT:
      return static_cast<DType>(DType::Variant);
    case tensorflow::DT_QUINT8:
      return static_cast<DType>(DType::QUI8);
    case tensorflow::DT_QUINT16:
      return static_cast<DType>(DType::QUI16);
    case tensorflow::DT_QINT8:
      return static_cast<DType>(DType::QI8);
    case tensorflow::DT_QINT16:
      return static_cast<DType>(DType::QI16);
    case tensorflow::DT_QINT32:
      return static_cast<DType>(DType::QI32);
    default:
      LOG(FATAL) << "Unsupported dtype " << dtype;
  }
}

OpAttrType ConvertDTypeToOpAttrType(tensorflow::DataType dtype) {
  switch (dtype) {
    case tensorflow::DT_UINT8:
      return OpAttrType::UI8;
    case tensorflow::DT_UINT16:
      return OpAttrType::UI16;
    case tensorflow::DT_UINT32:
      return OpAttrType::UI32;
    case tensorflow::DT_UINT64:
      return OpAttrType::UI64;
    case tensorflow::DT_INT8:
      return OpAttrType::I8;
    case tensorflow::DT_INT16:
      return OpAttrType::I16;
    case tensorflow::DT_INT32:
      return OpAttrType::I32;
    case tensorflow::DT_INT64:
      return OpAttrType::I64;
    case tensorflow::DT_BFLOAT16:
      return OpAttrType::BF16;
    case tensorflow::DT_HALF:
      return OpAttrType::F16;
    case tensorflow::DT_FLOAT:
      return OpAttrType::F32;
    case tensorflow::DT_DOUBLE:
      return OpAttrType::F64;
    case tensorflow::DT_BOOL:
      return OpAttrType::BOOL;
    case tensorflow::DT_COMPLEX64:
      return OpAttrType::COMPLEX64;
    case tensorflow::DT_COMPLEX128:
      return OpAttrType::COMPLEX128;
    default:
      LOG(FATAL) << "Unsupported dtype " << dtype;
  }
}

// This method first looks at the calling op's attrs and then falls back to
// the function def's attrs to find the attribute value.
void GetFuncAttr(const OpAttrs& op_attrs, const std::string& op_name,
                 const tensorflow::FunctionLibraryDefinition& func_lib_def,
                 string_view attr_name, bool* value) {
  bool success = op_attrs.Get(attr_name, value);
  if (success) {
    DVLOG(2) << "Caller explicitly specifies " << attr_name.str()
             << (*value ? "=true " : "=false, ");
    return;
  }

  const tensorflow::FunctionDef* function_def = func_lib_def.Find(op_name);
  if (function_def == nullptr) {
    return;
  }

  tensorflow::Status status =
      GetNodeAttr(tensorflow::AttrSlice(&function_def->attr()),
                  {attr_name.data(), attr_name.size()}, value);
  if (status.ok()) {
    DVLOG(2) << "Function definition explicitly specifies " << attr_name.str()
             << (*value ? "=true" : "=false");
    return;
  }
}
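
// Illustrative usage (a sketch, not part of the original file): a caller can
// check whether XLA compilation was requested on either the op or the
// function definition. tensorflow::kXlaMustCompileAttr comes from
// tensorflow/compiler/jit/defs.h, which is included above.
//
//   bool must_compile = false;
//   GetFuncAttr(op_attrs, op_name, func_lib_def,
//               tensorflow::kXlaMustCompileAttr, &must_compile);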

int64_t GetNextLocationId() {
  static std::atomic<int64_t> id(0);
  return id.fetch_add(1, std::memory_order_relaxed);
}
}  // namespace

tensorflow::DataType TensorInterface::Type() const {
  auto kind = tensor_.get().metadata().dtype;
  if (kind == DType::Unsupported) {
    assert(llvm::isa<tensorflow::tfd::RuntimeFallbackTensor>(tensor_.get()));
    return tensor_.get<tensorflow::tfd::RuntimeFallbackTensor>()
        .GetTensorHandle()
        ->DataType();
  }
  return ConvertDType(kind);
}

int TensorInterface::NumDims() const { return tensor_.get().shape().GetRank(); }

int64_t TensorInterface::Dim(int dim_index) const {
  return tensor_.get().shape().GetDimensionSize(dim_index);
}

int64_t TensorInterface::NumElements() const {
  if (!tensor_) {
    return static_cast<int64_t>(tf_tensor_.NumElements());
  }
  return tensor_.get().shape().GetNumElements();
}

size_t TensorInterface::ByteSize() const {
  return tensor_.get().metadata().GetHostSizeInBytes();
}

void* TensorInterface::Data() const {
  if (!tensor_) {
    return tensorflow::TensorCApi::Buffer(tf_tensor_)->data();
  } else {
    auto& tensor = tensor_.get<DenseHostTensor>();
    return tensor.data();
  }
}

// TFRT DenseHostTensor is always aligned.
bool TensorInterface::IsAligned() const { return true; }

bool TensorInterface::CanMove() const {
  // It is safe to move the Tensor if and only if we own the unique reference to
  // the tensor buffer.
  auto& dht = tensor_.get<DenseHostTensor>();
  return tensor_.IsUnique() && dht.buffer()->IsUnique();
}

std::string TensorInterface::SummarizeValue() const {
  if (!tensor_) {
    return tf_tensor_.SummarizeValue(/*max_entries=*/3, /*print_v2=*/true);
  } else {
    std::string result;
    llvm::raw_string_ostream result_ostream(result);
    tensor_->Print(result_ostream);
    return result;
  }
}

AsyncValueRef<Tensor> TensorInterface::TensorRef() const {
  return tensor_.CopyRef();
}

TensorHandleInterface::TensorHandleInterface(Value&& v, CoreRuntime* corert)
    : ImmediateExecutionTensorHandle(kTfrt),
      corert_(*corert),
      value_(std::move(v)) {}

tensorflow::DataType TensorHandleInterface::DataType() const {
  auto metadata = Metadata();
  if (!metadata.hasValue()) {
    LOG(ERROR)
        << "Failed to get DataType due to error metadata: "
        << value_.get<TensorHandle>().GetAsyncMetadata().GetError().message;
    return tensorflow::DT_INVALID;
  }
  auto kind = metadata.getValue()->dtype;
  if (kind == DType::Unsupported) {
    AsyncValue* async_tensor = value_.get<TensorHandle>().GetAsyncTensor();
    if (!async_tensor->IsAvailable()) {
      corert_.GetHostContext()->Await(FormRef(async_tensor));
    }

    if (async_tensor->IsError()) {
      LOG(ERROR) << "Failed to get DataType from an error tensor "
                 << async_tensor->GetError().message;
      return tensorflow::DT_INVALID;
    }
    assert(async_tensor->IsType<tensorflow::tfd::RuntimeFallbackTensor>());
    return async_tensor->get<tensorflow::tfd::RuntimeFallbackTensor>()
        .GetTensorHandle()
        ->DataType();
  }
  return ConvertDType(kind);
}

tensorflow::Status TensorHandleInterface::Shape(
    tensorflow::PartialTensorShape* shape) const {
  auto metadata = Metadata();
  if (!metadata.hasValue()) {
    return CreateTfErrorStatus(
        value_.get<TensorHandle>().GetAsyncMetadata().GetError());
  }
  int num_dims = metadata.getValue()->shape.GetRank();
  if (num_dims == -1) {
    return tensorflow::Status::OK();
  }
  SmallVector<ssize_t, 8> dims;
  metadata.getValue()->shape.GetDimensions(&dims);
  TF_RETURN_IF_ERROR(tensorflow::TensorShapeUtils::MakeShape(dims, shape));
  return tensorflow::Status::OK();
}

tensorflow::Status TensorHandleInterface::NumDims(int* num_dims) const {
  auto metadata = Metadata();
  if (!metadata.hasValue()) {
    return CreateTfErrorStatus(
        value_.get<TensorHandle>().GetAsyncMetadata().GetError());
  }
  *num_dims = metadata.getValue()->shape.GetRank();

  return tensorflow::Status::OK();
}

tensorflow::Status TensorHandleInterface::NumElements(
    int64_t* num_elements) const {
  auto metadata = Metadata();
  if (!metadata.hasValue()) {
    return CreateTfErrorStatus(
        value_.get<TensorHandle>().GetAsyncMetadata().GetError());
  }
  *num_elements = metadata.getValue()->shape.GetNumElements();

  return tensorflow::Status::OK();
}

tensorflow::Status TensorHandleInterface::Dim(int dim_index,
                                              int64_t* dim) const {
  auto metadata = Metadata();
  if (!metadata.hasValue()) {
    return CreateTfErrorStatus(
        value_.get<TensorHandle>().GetAsyncMetadata().GetError());
  }
  *dim = metadata.getValue()->shape.GetDimensionSize(dim_index);

  return tensorflow::Status::OK();
}

const char* TensorHandleInterface::DeviceName(
    tensorflow::Status* status) const {
  auto& th = value_.get<TensorHandle>();
  if (!th.IsDeviceAvailable()) {
    corert_.GetHostContext()->Await(th.GetAsyncDevice().CopyRCRef());
  }
  if (th.IsDeviceError()) {
    *status = CreateTfErrorStatus(th.GetAsyncDevice().GetError());
    return nullptr;
  }
  return th.GetAvailableDevice()->name().data();
}

const char* TensorHandleInterface::BackingDeviceName(
    tensorflow::Status* status) const {
  return DeviceName(status);
}

const char* TensorHandleInterface::DeviceType(
    tensorflow::Status* status) const {
  auto& th = value_.get<TensorHandle>();
  if (!th.IsDeviceAvailable()) {
    corert_.GetHostContext()->Await(th.GetAsyncDevice().CopyRCRef());
  }
  if (th.IsDeviceError()) {
    *status = CreateTfErrorStatus(th.GetAsyncDevice().GetError());
    return nullptr;
  }
  return th.GetAvailableDevice()->type().name().data();
}

tensorflow::AbstractTensorInterface* TensorHandleInterface::Resolve(
    tensorflow::Status* status) {
  auto* host_ctx = corert_.GetHostContext();
  auto host_device_ref = host_ctx->GetHostDeviceRef();
  auto& th = value_.get<TensorHandle>();

  auto tensor_av = th.GetAsyncTensor();
  if (!tensor_av->IsAvailable()) {
    host_ctx->Await(FormRef(tensor_av));
  }
  if (auto* error = tensor_av->GetErrorIfPresent()) {
    *status = CreateTfErrorStatus(*error);
    return nullptr;
  }
  assert(th.IsMetadataAvailable());

  if (th.GetAsyncTensor()->get<Tensor>().tensor_type() ==
      StringHostTensor::kTensorType) {
    tensorflow::Tensor tf_tensor =
        tensorflow::tfd::CopyShtToTfTensor(tensor_av->get<StringHostTensor>());
    return new tensorflow::TensorInterface(tf_tensor);
  }

  // Convert the tensor to DenseHostTensor.
  auto req_ctx =
      tfrt::RequestContextBuilder(host_ctx, /*resource_context=*/nullptr)
          .build();
  if (!req_ctx) {
    *status = tensorflow::Status(
        tensorflow::error::Code::UNKNOWN,
        StrCat("Failed to build a RequestContext: ", req_ctx.takeError()));
    return nullptr;
  }
  tfrt::ExecutionContext exec_ctx{std::move(*req_ctx)};
  auto target_th = th.TransferTo(exec_ctx, std::move(host_device_ref),
                                 DenseHostTensor::kTensorType);

  auto target_av = target_th.GetAsyncTensor();
  if (!target_av->IsAvailable()) {
    host_ctx->Await(FormRef(target_av));
  }
  if (target_av->IsError()) {
    *status = tensorflow::Status(
        tensorflow::error::Code::UNKNOWN,
        StrCat("Cannot resolve tensor: ", target_av->GetError().message));
    return nullptr;
  }
  auto host_tensor_ref = target_th.ReleaseTensorRef();
  return new TensorInterface(std::move(host_tensor_ref));
}
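
// Metadata() blocks until the handle's metadata is available and returns a
// pointer to it, or llvm::None if metadata resolution produced an error.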
llvm::Optional<const TensorMetadata*> TensorHandleInterface::Metadata() const {
  auto& th = value_.get<TensorHandle>();
  if (!th.IsMetadataAvailable()) {
    corert_.GetHostContext()->Await(th.GetAsyncMetadata().CopyRCRef());
  }
  if (th.IsMetadataError()) {
    return llvm::None;
  }
  return &th.GetAvailableMetadata();
}

ContextInterface::ContextInterface(
    const tensorflow::SessionOptions& opts,
    tensorflow::ContextDevicePlacementPolicy default_device_placement_policy,
    bool is_async, bool use_tfrt_distributed_runtime)
    : ImmediateExecutionContext(kTfrt),
      context_(opts, default_device_placement_policy, is_async),
      use_tfrt_distributed_runtime_(use_tfrt_distributed_runtime) {
  LOG(INFO) << "TFRT Enabled";
  metrics::AddTFRTVersionMetric();

  op_handler_selector_ = std::make_unique<EagerOpHandlerSelector>(
      GetCoreRuntime(), GetEagerContext(), GetFallbackOpHandler(),
      GetEagerContext()->PinSmallOpsToCPU());

  run_metadata_ = std::make_unique<tensorflow::RunMetadata>();
}

ContextInterface::~ContextInterface() {}
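
// GetChain hands out a per-thread Chain that Execute() uses to order op side
// effects issued from the same thread. The lookup is double-checked: a shared
// (reader) lock covers the common case where the chain already exists, and an
// exclusive lock is taken only to create the chain for a new thread.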
AsyncValueRef<Chain>* ContextInterface::GetChain() {
  auto thread_id = std::this_thread::get_id();
  {
    tensorflow::tf_shared_lock l(chain_map_mu_);
    auto it = thread_local_chain_.find(thread_id);
    if (it != thread_local_chain_.end()) {
      return &it->second;
    }
  }
  {
    tensorflow::mutex_lock l(chain_map_mu_);
    if (thread_local_chain_.find(thread_id) == thread_local_chain_.end()) {
      auto chain = GetReadyChain(GetHostContext());
      thread_local_chain_[thread_id] = std::move(chain);
    }
    return &thread_local_chain_[thread_id];
  }
}

template <typename T>
static TensorInterface* MakeScalarTensor(T value, HostContext* host) {
  // The TensorInterface implementation assumes the tensor is a DenseHostTensor,
  // so we need to use a DenseHostTensor to represent a scalar tensor.
  TensorMetadata md(GetDType<T>(), {});
  auto t = DenseHostTensor::CreateUninitialized(md, host);
  if (!t) {
    LOG(ERROR) << "Failed to create DenseHostTensor";
    return nullptr;
  }
  auto& dht = t.getValue();
  MutableDHTArrayView<T> view{&dht};
  view.Elements()[0] = value;

  return new TensorInterface(
      MakeAvailableAsyncValueRef<DenseHostTensor>(host, std::move(dht)));
}
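
// Illustrative usage (a sketch, not part of the original file): the typed
// Create*Scalar methods below all funnel into MakeScalarTensor, e.g.
//
//   tensorflow::AbstractTensorInterface* t = ctx->CreateFloatScalar(1.0f);
//
// which materializes a rank-0 DenseHostTensor holding a single float.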

tensorflow::AbstractTensorInterface* ContextInterface::CreateInt64Scalar(
    int64_t value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateUint64Scalar(
    uint64_t value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateInt32Scalar(
    int32_t value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateFloatScalar(
    float value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateDoubleScalar(
    double value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateHalfScalar(
    Eigen::half value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateStringScalar(
    tensorflow::tstring value) {
  auto* host = GetHostContext();
  TensorMetadata md(DType(DType::String), {});
  auto t = StringHostTensor::MakeConstructedAsyncValueRef(md, host);
  if (t.IsError()) {
    LOG(ERROR) << "Failed to create StringHostTensor";
    return nullptr;
  }
  t->strings()[0] = value;

  t.SetStateConcrete();
  return new TensorInterface(std::move(t));
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateComplex128Scalar(
    tensorflow::complex128 value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateBoolScalar(
    bool value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateTensor(
    tensorflow::DataType dtype, absl::Span<const int64_t> dim_sizes) {
  std::vector<ssize_t> dimvec(dim_sizes.size());
  for (int i = 0; i < dim_sizes.size(); ++i) {
    dimvec[i] = static_cast<ssize_t>(dim_sizes[i]);
  }

  TensorMetadata md;
  switch (dtype) {
    case tensorflow::DT_UINT8:
      md = CreateMetadata(DType::UI8, dimvec);
      break;
    case tensorflow::DT_INT8:
      md = CreateMetadata(DType::I8, dimvec);
      break;
    case tensorflow::DT_INT16:
      md = CreateMetadata(DType::I16, dimvec);
      break;
    case tensorflow::DT_INT32:
      md = CreateMetadata(DType::I32, dimvec);
      break;
    case tensorflow::DT_INT64:
      md = CreateMetadata(DType::I64, dimvec);
      break;
    case tensorflow::DT_HALF:
      md = CreateMetadata(DType::F16, dimvec);
      break;
    case tensorflow::DT_FLOAT:
      md = CreateMetadata(DType::F32, dimvec);
      break;
    case tensorflow::DT_DOUBLE:
      md = CreateMetadata(DType::F64, dimvec);
      break;
    case tensorflow::DT_BOOL:
      md = CreateMetadata(DType::I1, dimvec);
      break;
    case tensorflow::DT_COMPLEX64:
      md = CreateMetadata(DType::Complex64, dimvec);
      break;
    case tensorflow::DT_COMPLEX128:
      md = CreateMetadata(DType::Complex128, dimvec);
      break;
    case tensorflow::DT_VARIANT:
      // Note: the TF Python API can create a variant tensor for a ragged
      // tensor.
      md = CreateMetadata(DType::Variant, dimvec);
      break;
    case tensorflow::DT_STRING:
      // No TFRT metadata is needed for non-scalar string tensors.
      break;
    default:
      LOG(ERROR) << "Cannot create tensor with dtype: " << dtype;
      return nullptr;
  }

  if (dtype == tensorflow::DT_STRING) {
    // Create a TensorFlow Tensor as a buffer for tstrings.
    return new TensorInterface(
        tensorflow::Tensor(dtype, tensorflow::TensorShape(dim_sizes)));
  } else {
    auto t = DenseHostTensor::CreateUninitialized(md, GetHostContext());
    return new TensorInterface(MakeAvailableAsyncValueRef<DenseHostTensor>(
        GetHostContext(), std::move(t.getValue())));
  }
}
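
// Illustrative usage (a sketch, not part of the original file): creating an
// uninitialized 2x3 float tensor; callers then fill it through Data().
//
//   int64_t dims[] = {2, 3};
//   auto* t = ctx->CreateTensor(tensorflow::DT_FLOAT,
//                               absl::MakeConstSpan(dims, 2));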

tensorflow::AbstractTensorInterface* ContextInterface::CreateTensor(
    tensorflow::DataType dtype, const int64_t* dims, int num_dims, void* data,
    size_t len, MemoryReleaser memory_releaser, void* memory_releaser_arg) {
  TensorMetadata metadata(ConvertDType(dtype),
                          {dims, static_cast<size_t>(num_dims)});
  RCReference<HostBuffer> buffer = HostBuffer::CreateFromExternal(
      data, len,
      [memory_releaser, memory_releaser_arg](void* data, size_t len) {
        memory_releaser(data, len, memory_releaser_arg);
      });
  AsyncValueRef<DenseHostTensor> dht =
      MakeConstructedAsyncValueRef<DenseHostTensor>(GetHostContext(), metadata,
                                                    std::move(buffer));

  dht.SetStateConcrete();
  return new TensorInterface(std::move(dht));
}
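
// Illustrative usage (a sketch, not part of the original file): wrapping an
// externally owned buffer without copying. The releaser runs when the last
// reference to the underlying HostBuffer is dropped.
//
//   auto releaser = [](void* data, size_t len, void* arg) { free(data); };
//   auto* t = ctx->CreateTensor(tensorflow::DT_FLOAT, dims, /*num_dims=*/2,
//                               buffer, byte_len, releaser, /*arg=*/nullptr);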

bool ContextInterface::UsesTFRT() { return true; }

tensorflow::ImmediateExecutionTensorHandle* ContextInterface::CreateLocalHandle(
    tensorflow::AbstractTensorInterface* t) {
  auto* tensor_interface = down_cast<TensorInterface*>(t);
  auto* host = GetHostContext();

  // Create a RuntimeFallbackTensor from a TF Tensor, and then create the
  // corresponding TensorHandleInterface.
  if (tensor_interface->IsTfTensor()) {
    tensorflow::tfd::OwnedTensorHandle tf_tensor_handle{
        tensorflow::TensorHandle::CreateLocalHandle(
            tensor_interface->TfTensor())};

    auto expected_result_tensor =
        tensorflow::tfd::CreateRuntimeFallbackTensorFromTfTensorHandle(
            std::move(tf_tensor_handle), GetHostContext());

    if (expected_result_tensor) {
      return new TensorHandleInterface(
          Value(TensorHandle(
              host->GetHostDeviceRef(), expected_result_tensor.get().metadata(),
              MakeAvailableAsyncValueRef<
                  tensorflow::tfd::RuntimeFallbackTensor>(
                  host, std::move(expected_result_tensor.get())))),
          GetCoreRuntime());
    } else {
      return new TensorHandleInterface(
          Value(TensorHandle::CreateError(MakeErrorAsyncValueRef(
              GetHostContext(), StrCat(expected_result_tensor.takeError())))),
          GetCoreRuntime());
    }
  }

  auto tensor_av = tensor_interface->TensorRef();
  const TensorMetadata& md = tensor_av.get<Tensor>().metadata();

  // NOTE(fishx): The following logic is needed to let TF-TFRT fully reach
  // performance parity with current TF. This API is used by tf.constant
  // to convert a Python object to a **CPU** Tensor. tf.constant in current TF
  // heavily depends on the Tensor Mirroring feature for good performance.
  // However, TFRT does not have a Tensor Mirroring feature. In order to use
  // Tensor Mirroring from the current TF runtime, we convert the result of
  // tf.constant to a Fallback Tensor.

  if (tensor_av.IsAvailable()) {
    if (auto* dht = llvm::dyn_cast<DenseHostTensor>(&tensor_av.get<Tensor>())) {
      return new TensorHandleInterface(
          Value(TensorHandle(
              host->GetHostDeviceRef(), md,
              MakeAvailableAsyncValueRef<
                  tensorflow::tfd::RuntimeFallbackTensor>(
                  host, tensorflow::tfd::CopyRefDHTToRuntimeFallbackTensor(
                            *dht, host)))),
          GetCoreRuntime());
    }
  } else {
    auto result_tensor = MakeIndirectAsyncValue(host);
    tensor_av.AndThen([host, result_tensor = result_tensor.CopyRef(),
                       tensor_av = tensor_av.CopyRef()]() {
      if (auto* dht =
              llvm::dyn_cast<DenseHostTensor>(&tensor_av.get<Tensor>())) {
        result_tensor->ForwardTo(
            MakeAvailableAsyncValueRef<tensorflow::tfd::RuntimeFallbackTensor>(
                host, tensorflow::tfd::CopyRefDHTToRuntimeFallbackTensor(
                          *dht, host)));
      } else {
        result_tensor->ForwardTo(tensor_av.CopyRef());
      }
    });
    return new TensorHandleInterface(
        Value(TensorHandle(host->GetHostDeviceRef(), md,
                           AsyncValueRef<Tensor>(std::move(result_tensor)))),
        GetCoreRuntime());
  }
  return new TensorHandleInterface(
      Value(TensorHandle(host->GetHostDeviceRef(), md, std::move(tensor_av))),
      GetCoreRuntime());
}

tensorflow::ImmediateExecutionTensorHandle*
ContextInterface::CreateLocalHandleFromTFTensor(tensorflow::Tensor& t,
                                                const char* d_name) {
  auto* host = GetHostContext();
  // Create a RuntimeFallbackTensor from a TF Tensor, and then create the
  // corresponding TensorHandleInterface.
  tensorflow::tfd::OwnedTensorHandle tf_tensor_handle{
      tensorflow::TensorHandle::CreateLocalHandle(std::move(t))};

  tfrt::Expected<tensorflow::tfd::RuntimeFallbackTensor>
      expected_result_tensor =
          tensorflow::tfd::CreateRuntimeFallbackTensorFromTfTensorHandle(
              std::move(tf_tensor_handle), GetHostContext());

  if (expected_result_tensor) {
    return new TensorHandleInterface(
        Value(TensorHandle(
            host->GetHostDeviceRef(), expected_result_tensor.get().metadata(),
            MakeAvailableAsyncValueRef<tensorflow::tfd::RuntimeFallbackTensor>(
                host, std::move(expected_result_tensor.get())))),
        GetCoreRuntime());
  } else {
    return new TensorHandleInterface(
        Value(TensorHandle::CreateError(MakeErrorAsyncValueRef(
            GetHostContext(), StrCat(expected_result_tensor.takeError())))),
        GetCoreRuntime());
  }
}

tensorflow::ImmediateExecutionTensorHandle*
ContextInterface::TFTensorHandleFromInterface(
    tensorflow::ImmediateExecutionTensorHandle* handle) {
  TensorHandle th = tfrt::tf::TensorHandleFromInterface(handle)->Handle();
  AsyncValue* tensor_av = th.GetAsyncTensor();
  if (tensor_av->IsUnavailable()) GetHostContext()->Await(FormRef(tensor_av));

  auto& tensor = th.GetAsyncTensor()->get<Tensor>();

  if (auto* rtfbt =
          llvm::dyn_cast<tensorflow::tfd::RuntimeFallbackTensor>(&tensor))
    return rtfbt->GetTensorHandle();

  if (auto* dht = llvm::dyn_cast<tfrt::DenseHostTensor>(&tensor)) {
    return tensorflow::TensorHandle::CreateLocalHandle(
        tensorflow::tfd::MoveHostBufferToTfTensor(dht->buffer().CopyRef(),
                                                  dht->dtype(), dht->shape()));
  }

  if (auto* sht = llvm::dyn_cast<tfrt::StringHostTensor>(&tensor)) {
    return tensorflow::TensorHandle::CreateLocalHandle(
        tensorflow::tfd::CopyShtToTfTensor(*sht));
  }

  LOG(ERROR) << "Unsupported tensor type";
  return nullptr;
}

tensorflow::ImmediateExecutionOperation* ContextInterface::CreateOperation() {
  return new OperationInterface(this);
}

// TODO(srbs): Change this to directly fetch the MLIR function once that is
// supported.
tensorflow::Status ContextInterface::RegisterFunction(
    tensorflow::AbstractFunction* f) {
  tensorflow::FunctionDef* fdef;
  TF_RETURN_IF_ERROR(f->GetFunctionDef(&fdef));
  if (!fdef) {
    return tensorflow::errors::InvalidArgument(
        "GetFunctionDef returned nullptr.");
  }
  return AddFunctionDef(*fdef);
}

void ContextInterface::ListDevices(
    std::vector<tensorflow::DeviceAttributes>* devices) {
  context_.GetEagerContext()->ListDevices(devices);
}

tensorflow::Status ContextInterface::AddDevices(
    std::vector<std::unique_ptr<tensorflow::Device>> devices) {
  if (!devices.empty() && devices[0]->device_type() != "CPU")
    return tensorflow::errors::InvalidArgument(
        "Device: ", devices[0]->device_type(), " is not allowed to be added ",
        "after the context is initialized. Currently allowed device: CPU. ",
        "May update this API to allow adding more types of devices.");

  for (const auto& d : devices) {
    GetHostContext()->GetDeviceManager()->MaybeAddDevice(
        TakeRef(new CpuDevice(d->name())));
  }
  TF_RETURN_IF_ERROR(GetEagerContext()->AddDevices(std::move(devices)));

  return tensorflow::Status::OK();
}

void ContextInterface::ClearCachesAndThreadExecutors() {
  GetEagerContext()->ClearCachesAndThreadExecutors();
  GetHostContext()->Quiesce();
}

void ContextInterface::StartStep() { GetEagerContext()->StartStep(); }

void ContextInterface::EndStep() { GetEagerContext()->EndStep(); }

tensorflow::Status ContextInterface::EnableCollectiveOps(
    const tensorflow::ServerDef& server_def) {
  if (use_tfrt_distributed_runtime_) {
    return distributed_manager_->EnableCollectiveOps(server_def);
  }
  // Preserve the local virtual device names, since local virtual devices are
  // added by TFRT and we need to add them back after the worker server is
  // initialized. Currently one such use case is the TPU_SYSTEM device, which
  // is a virtual device specifically used to initialize TPUs.
  std::vector<std::string> virtual_device_names;

  for (const auto& d :
       GetHostContext()->GetDeviceManager()->ListDevices<Device>()) {
    if (d->IsDeviceType(tfrt::VirtualDevice::kDeviceType)) {
      tensorflow::DeviceNameUtils::ParsedName p;
      if (!tensorflow::DeviceNameUtils::ParseFullName(d->name().str(), &p)) {
        return tensorflow::errors::InvalidArgument(
            "Invalid local virtual device name: ", d->name().str());
      }

      virtual_device_names.push_back(tensorflow::DeviceNameUtils::FullName(
          server_def.job_name(), /*replica=*/0, server_def.task_index(), p.type,
          p.id));
    }
  }

  TF_RETURN_IF_ERROR(GetEagerContext()->EnableCollectiveOps(server_def));

  // Create new devices with updated device names.
  std::vector<std::unique_ptr<tensorflow::Device>> dummy_tf_devices;
  CreateDummyTfDevices(virtual_device_names, &dummy_tf_devices);

  std::string name_prefix =
      absl::StrCat("/job:", server_def.job_name(),
                   "/replica:0/task:", server_def.task_index());

  // Update the host device in the TFRT HostContext.
  GetHostContext()->ResetHostDevice(
      GetHostContext()
          ->GetDeviceManager()
          ->MaybeAddDevice(TakeRef(
              new CpuDevice(absl::StrCat(name_prefix, "/device:CPU:0"))))
          .release());

  // Update the virtual devices in the TFRT HostContext.
  AddDummyTfrtDevices(virtual_device_names, GetHostContext());

  // Update the eager context's device manager.
  auto* local_device_mgr = dynamic_cast<tensorflow::DynamicDeviceMgr*>(
      GetEagerContext()->local_device_mgr());
  TF_RETURN_IF_ERROR(local_device_mgr->AddDevices(std::move(dummy_tf_devices)));

  return tensorflow::Status::OK();
}

tensorflow::Status ContextInterface::BuildFunctionRequestContext(
    tensorflow::tfd::OpKernelRunnerTable* runner_table,
    RCReference<tfrt::RequestContext>* request_context) {
  auto* step_container = GetEagerContext()->StepContainer();
  RequestContextBuilder request_context_builder(
      GetHostContext(), GetResourceContext(), step_container->StepId());

  TF_RETURN_IF_ERROR(tensorflow::tfd::SetUpKernelFallbackCompatRequestContext(
      &request_context_builder, runner_table, GetEagerContext()));
  if (distributed_manager_ != nullptr) {
    down_cast<DistributedManagerContextInterface*>(distributed_manager_.get())
        ->UpdateRequestContextBuilder(&request_context_builder);
  }
  auto expected_request_context = std::move(request_context_builder).build();
  if (!expected_request_context) {
    return tensorflow::errors::Internal(
        StrCat(expected_request_context.takeError()));
  }
  *request_context = std::move(expected_request_context.get());
  return tensorflow::Status::OK();
}

tensorflow::Status ContextInterface::BuildOpRequestContext(
    RCReference<tfrt::RequestContext>* request_context) {
  return BuildFunctionRequestContext(/*runner_table=*/nullptr, request_context);
}

tensorflow::ImmediateExecutionTensorHandle*
ContextInterface::CopyTensorHandleToDevice(
    tensorflow::ImmediateExecutionTensorHandle* handle, const char* device_name,
    tensorflow::Status* status) {
  auto* host_ctx = GetHostContext();

  TensorHandle src_th = tfrt::tf::TensorHandleFromInterface(handle)->Handle();

  auto tfrt_device_name =
      ConvertTfDeviceNameToTfrt(device_name, GetEagerContext());
  if (!tfrt_device_name) {
    *status = tensorflow::errors::InvalidArgument(
        StrCat(tfrt_device_name.takeError()));
    RCReference<AsyncValue> error_av =
        MakeErrorAsyncValueRef(host_ctx, status->error_message());
    return new TensorHandleInterface(
        Value(TensorHandle::CreateError(std::move(error_av))),
        GetCoreRuntime());
  }
  auto dst_device_ref = host_ctx->GetDeviceManager()->GetDeviceRef<Device>(
      tfrt_device_name.get());
  if (!dst_device_ref) {
    std::string error_message =
        tfrt::StrCat("Failed to find destination device with name: ",
                     tfrt_device_name.get());
    *status = tensorflow::errors::Internal(error_message);
    RCReference<AsyncValue> error_av =
        MakeErrorAsyncValueRef(host_ctx, error_message);
    return new TensorHandleInterface(
        Value(TensorHandle::CreateError(std::move(error_av))),
        GetCoreRuntime());
  }

  RCReference<RequestContext> request_ctx;
  *status = BuildOpRequestContext(&request_ctx);
  if (!status->ok()) return nullptr;

  ExecutionContext exec_ctx{std::move(request_ctx)};

  auto target_th =
      src_th.TransferToInferredType(exec_ctx, std::move(dst_device_ref));

  auto target_av = target_th.GetAsyncTensor();
  if (target_av->IsError()) {
    *status = tensorflow::errors::Internal(
        tfrt::StrCat("Copying to device <", tfrt_device_name.get(),
                     "> failed: ", target_av->GetError().message));
    return nullptr;
  }
  return new TensorHandleInterface(Value(target_th.CopyRef()),
                                   GetCoreRuntime());
}
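
// Illustrative usage (a sketch, not part of the original file): copying an
// eager tensor handle to the first GPU. The device string follows the usual
// TF full device name format.
//
//   tensorflow::Status s;
//   auto* gpu_handle = ctx->CopyTensorHandleToDevice(
//       handle, "/job:localhost/replica:0/task:0/device:GPU:0", &s);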

tensorflow::Status ContextInterface::AddFunctionDef(
    const tensorflow::FunctionDef& fdef) {
  return GetEagerContext()->AddFunctionDef(fdef);
}

tensorflow::Status ContextInterface::AddFunctionDefWithStackTraces(
    const tensorflow::FunctionDef& fdef,
    const tensorflow::StackTracesMap& stack_traces) {
  return GetEagerContext()->AddFunctionDefWithStackTraces(fdef, stack_traces);
}

std::vector<std::string> ContextInterface::ListFunctionNames() {
  return GetEagerContext()->ListFunctionNames();
}

tensorflow::Status ContextInterface::RemoveFunction(const std::string& func) {
  // TODO(tfrt-devs): We need to ensure all invocations of this function are
  // finished before removing it.
  function_cache_.RemoveFunction(func);
  return GetEagerContext()->RemoveFunction(func);
}

const tensorflow::FunctionDef* ContextInterface::FindFunctionDef(
    const std::string& name) const {
  return GetEagerContext()->FindFunctionDef(name);
}

const tensorflow::DeviceNameUtils::ParsedName&
ContextInterface::HostCPUParsedName() const {
  return context_.HostCPUParsedName();
}

const std::string& ContextInterface::HostCPUName() const {
  return context_.GetEagerContext()->HostCPUName();
}

tensorflow::CustomDeviceOpHandler&
ContextInterface::GetCustomDeviceOpHandler() {
  return context_.GetEagerContext()->GetCustomDeviceOpHandler();
}

tensorflow::Status ContextInterface::RegisterCustomDevice(
    const std::string& name, std::unique_ptr<tensorflow::CustomDevice> device) {
  return context_.GetEagerContext()->RegisterCustomDevice(name,
                                                          std::move(device));
}

tensorflow::FunctionLibraryDefinition* ContextInterface::FuncLibDef() {
  return context_.GetEagerContext()->FuncLibDef();
}

void ContextInterface::SetReuseRendezvousForFunctions(
    bool reuse_rendezvous_for_functions) {
  // TODO(fishx): This feature doesn't work properly in TFRT yet. Fix it.
  context_.GetEagerContext()->SetReuseRendezvousForFunctions(
      reuse_rendezvous_for_functions);
}

void ContextInterface::ResetGlobalRendezvousForFunction() {
  context_.GetEagerContext()->ResetGlobalRendezvousForFunction();
}

std::vector<std::string> ContextInterface::GetLoggedOpsTestonly() {
  const auto& ret = GetHostContext()
                        ->GetOrCreateSharedContext<tensorflow::tfd::OpLogger>()
                        .GetLoggedOps();
  return std::vector<std::string>(ret.begin(), ret.end());
}

HostContext* ContextInterface::GetHostContext() {
  return GetCoreRuntime()->GetHostContext();
}

tensorflow::EagerContext* ContextInterface::GetEagerContext() {
  return context_.GetEagerContext();
}

const tensorflow::EagerContext* ContextInterface::GetEagerContext() const {
  return context_.GetEagerContext();
}

CoreRuntime* ContextInterface::GetCoreRuntime() {
  return context_.GetCoreRuntime();
}

OpHandler* ContextInterface::GetFallbackOpHandler() {
  return context_.GetFallbackOpHandler();
}

ResourceContext* ContextInterface::GetResourceContext() {
  return context_.GetResourceContext();
}

tensorflow::Status ContextInterface::SelectOpHandlerFromArguments(
    const tensorflow::ImmediateExecutionOperation& op, OpHandler** op_handler) {
  return op_handler_selector_->SelectFromArguments(op, op_handler);
}

tensorflow::Status ContextInterface::SelectOpHandlerFromNodeDef(
    const tensorflow::ImmediateExecutionOperation& op, const NodeDef* node_def,
    OpHandler** op_handler) {
  return op_handler_selector_->SelectFromNodeDef(op, node_def, op_handler);
}

std::unique_ptr<tensorflow::RunMetadata> ContextInterface::ExportRunMetadata() {
  mutex_lock l(run_metadata_mu_);

  // NOTE(fishx): We need to merge run_metadata from the TF Eager Context
  // because right now we still use the current TF runtime to execute graphs
  // (e.g. tf.data via fallback).
  auto result = GetEagerContext()->ExportRunMetadata();
  result->MergeFrom(*run_metadata_);
  run_metadata_ = std::make_unique<tensorflow::RunMetadata>();

  return result;
}

tensorflow::Status ContextInterface::RunMetadataRecordFunction(
    const std::string& func_name) {
  const tensorflow::FunctionDef* fdef =
      GetEagerContext()->FindFunctionDef(func_name);
  if (fdef == nullptr) {
    return tensorflow::errors::InvalidArgument(
        "Failed to find function \"", func_name, "\" in function library");
  }
  std::unique_ptr<tensorflow::FunctionBody> fbody;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *fdef, tensorflow::AttrSlice(), GetEagerContext()->FuncLibDef(), &fbody));
  tensorflow::GraphDef def;
  fbody->graph->ToGraphDef(&def);
  *def.mutable_library() =
      GetEagerContext()->FuncLibDef()->ReachableDefinitions(def).ToProto();

  mutex_lock l(run_metadata_mu_);
  auto* function_graphs = run_metadata_->add_function_graphs();
  *function_graphs->mutable_pre_optimization_graph() = def;
  // TODO(b/171600738): Figure out a way to record the right post-optimization
  // graph and partition graph.
  *function_graphs->mutable_post_optimization_graph() = def;
  *function_graphs->add_partition_graphs() = def;
  *run_metadata_->add_partition_graphs() = def;
  return tensorflow::Status::OK();
}

void ContextInterface::SetExecutorForThread(
    tensorflow::EagerExecutor* executor) {
  GetEagerContext()->SetExecutorForThread(executor);
}

tfrt::Location AbortLocationHandler::GetCurrentLocation() {
  return tfrt::Location(this, GetNextLocationId());
}

void OpAttrsInterface::GetNameAttrList(
    tensorflow::NameAttrList* name_and_attrs) const {
  fallback_attrs_->FillAttrValueMap(name_and_attrs->mutable_attr());
  name_and_attrs->set_name(fallback_attrs_->op_name());
}

Status OpAttrsInterface::GetTypeList(
    absl::string_view attr_name,
    absl::InlinedVector<tensorflow::DataType, 4>* type_list) const {
  return tensorflow::errors::Unimplemented("OpAttrsInterface::GetTypeList");
}

bool OpAttrsInterface::GetInt(absl::string_view attr_name,
                              int64_t* result) const {
  return attrs_->Get<int64_t>({attr_name.data(), attr_name.size()}, result);
}

bool OpAttrsInterface::GetFloat(absl::string_view attr_name,
                                float* result) const {
  return attrs_->Get<float>({attr_name.data(), attr_name.size()}, result);
}

bool OpAttrsInterface::GetBool(absl::string_view attr_name,
                               bool* result) const {
  return attrs_->Get<bool>({attr_name.data(), attr_name.size()}, result);
}

bool OpAttrsInterface::GetType(absl::string_view attr_name,
                               tensorflow::DataType* result) const {
  auto optional_type =
      attrs_->GetOptional<OpAttrType>({attr_name.data(), attr_name.size()});
  if (!optional_type.hasValue()) return false;
  *result = tensorflow::tfd::ConvertToTfDataType(optional_type.getValue());
  return true;
}

OperationInterface::OperationInterface(ContextInterface* context)
    : ImmediateExecutionOperation(kTfrt),
      op_attrs_(&attrs_, &fallback_attrs_),
      context_(context) {}

tensorflow::Status OperationInterface::Reset(const char* op,
                                             const char* raw_device_name) {
  op_name_ = op;
  args_.clear();
  attrs_.Reset();
  custom_device_tensor_handle_count_ = 0;
  op_def_ = nullptr;
  fallback_attrs_.Reset(op);
  stack_trace_.reset();
  op_ = nullptr;
  function_state_.reset();
  tensorflow::Status s = tensorflow::OpDefForOp(op_name_, &op_def_);
  is_function_ = !s.ok();
  return SetDeviceName(raw_device_name);
}
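
// Illustrative flow (a sketch, not part of the original file): an eager op
// goes through Reset to configure it and Execute to run it.
//
//   auto* op = ctx->CreateOperation();
//   TF_RETURN_IF_ERROR(op->Reset("MatMul", /*raw_device_name=*/nullptr));
//   TF_RETURN_IF_ERROR(op->AddInput(a));
//   TF_RETURN_IF_ERROR(op->AddInput(b));
//   int num_retvals = 1;
//   TF_RETURN_IF_ERROR(op->Execute(absl::MakeSpan(retvals), &num_retvals));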
1257 
Execute(absl::Span<tensorflow::AbstractTensorHandle * > retvals,int * num_retvals)1258 tensorflow::Status OperationInterface::Execute(
1259     absl::Span<tensorflow::AbstractTensorHandle*> retvals, int* num_retvals) {
1260   if (custom_device_tensor_handle_count_ > 0) {
1261     return tensorflow::errors::InvalidArgument(
1262         "Cannot execute ops that conntains unsupported arg in TFRT.");
1263   }
1264 
1265   TF_RETURN_IF_ERROR(Initialize());
1266   assert(op_ != nullptr || function_state_);
1267   auto* corert = context_->GetCoreRuntime();
1268   auto* chain = context_->GetChain();
1269   auto* host = corert->GetHostContext();
1270   SmallVector<TensorHandle, 8> th_args;
1271   th_args.reserve(args_.size());
1272 
1273   SmallVector<TensorHandle, 8> result_ths;
1274   result_ths.resize(*num_retvals);
1275 
1276   if (function_state_) {
1277     // Set up arguments. Check argument dtype synchronously if available.
1278     auto arg_types = function_state_->GetArgTypes();
1279     if (args_.size() != arg_types.size()) {
1280       return tensorflow::errors::InvalidArgument("Expects ", arg_types.size(),
1281                                                  " arguments, but ",
1282                                                  args_.size(), " is provided");
1283     }
1284     auto args_size = args_.size();
1285     for (auto i = 0; i < args_size; ++i) {
1286       th_args.push_back(down_cast<TensorHandleInterface*>(args_[i].get())
1287                             ->Handle()
1288                             .CopyRef());
1289       // TODO(b/173556766): This dtype check is only needed for corert lowering.
1290       // In native lowering, the compiler should obtain the argument dtype
1291       // information from FunctionBody directly and lower the op to the native
1292       // kernel that accepts the specified dtype.
1293       if (th_args[i].IsMetadataAvailable()) {
1294         auto arg_dtype = th_args[i].GetAvailableMetadata().dtype;
1295         if (arg_dtype != arg_types[i]) {
1296           return tensorflow::errors::InvalidArgument(
1297               "Expects arg[", i, "] to be ", arg_types[i], " but ", arg_dtype,
1298               " is provided");
1299         }
1300       }
1301     }
1302 
1303     RCReference<RequestContext> request_ctx;
1304     TF_RETURN_IF_ERROR(context_->BuildFunctionRequestContext(
1305         function_state_->GetRunnerTable(), &request_ctx));
1306 
1307     ExecutionContext exec_ctx{std::move(request_ctx),
1308                               abort_location_handler_.GetCurrentLocation()};
1309     // Execute the function.
1310     function_state_->GetFunc()(exec_ctx, th_args, OpAttrsRef(attrs_),
1311                                result_ths, chain);
1312   } else {
1313     RCReference<RequestContext> request_ctx;
1314     TF_RETURN_IF_ERROR(context_->BuildOpRequestContext(&request_ctx));
1315 
1316     ExecutionContext exec_ctx{std::move(request_ctx),
1317                               abort_location_handler_.GetCurrentLocation()};
1318     for (auto& arg : args_) {
1319       th_args.push_back(
1320           down_cast<TensorHandleInterface*>(arg.get())->Handle().CopyRef());
1321     }
1322     // If the CoreRuntimeOp is a native TFRT op, transfer arguments to target
1323     // device if necessary.
1324     if (!op_->IsFallback()) {
1325       // Get the target device to which we want to implicitly copy the
1326       // arguments.
1327       auto dst_device_ref = op_->GetDeviceRef();
1328 
1329       for (auto& th_arg : th_args) {
1330         th_arg = th_arg.TransferTo(exec_ctx, dst_device_ref.CopyRef(),
1331                                    op_->GetTensorType());
1332       }
1333     }
1334 
1335     (*op_)(exec_ctx, th_args, OpAttrsRef(attrs_), result_ths, chain);
1336   }
1337 
1338   tensorflow::Status s = tensorflow::Status::OK();
1339 
1340   if (TF_PREDICT_FALSE(!this->context_->is_async() && !chain->IsAvailable()))
1341     host->Await({chain->CopyRCRef()});
1342 
1343   if (TF_PREDICT_FALSE(chain->IsError())) {
1344     s = CreateTfErrorStatus(chain->GetError());
1345     // TODO(tfrt-devs): Assess if we need an explicit API to clear the error.
1346     *chain = GetReadyChain(host);
1347   }
1348 
1349   for (size_t i = 0, e = result_ths.size(); i != e; ++i) {
1350     auto& th_ref = result_ths[i];
1351     if (TF_PREDICT_FALSE(!this->context_->is_async() &&
1352                          !th_ref.GetAsyncTensor()->IsAvailable()))
1353       host->Await(FormRef(th_ref.GetAsyncTensor()));
1354 
1355     // NOTE(fishx): In async mode, we won't report errors synchronously even
1356     // though it is possible in TFRT. This is intended to match the behavior
1357     // of current TF. However, in the future, we may want to update this
1358     // behavior, since synchronous errors may improve the user experience in
1359     // async mode.
1360     if (TF_PREDICT_FALSE(!this->context_->is_async() &&
1361                          th_ref.GetAsyncTensor()->IsError() && s.ok()))
1362       s = CreateTfErrorStatus(th_ref.GetAsyncTensor()->GetError());
1363 
1364     retvals[i] =
1365         new TensorHandleInterface(Value(std::move(result_ths[i])), corert);
1366   }
1367 
1368   return s;
1369 }
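
// Note on the code above: in synchronous mode, Execute blocks on the
// side-effect chain and on each result's AsyncTensor before returning, so
// errors surface in the returned Status; in async mode, errors surface later
// through the returned TensorHandleInterface objects instead.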
1370 
1371 tensorflow::Status OperationInterface::Initialize() {
1372   CoreRuntime* corert = context_->GetCoreRuntime();
1373   if (!is_function_) {
1374     // Obtain input arguments' dtype attrs as part of the cache key.
1375     llvm::SmallVector<string_view, 4> dtypes;
1376     attrs_.IterateEntries([&](const OpAttrsRawEntry& entry) {
1377       if (entry.type == OpAttrType::DTYPE && !entry.IsArray())
1378         dtypes.push_back(
1379             GetNameString(*static_cast<const OpAttrType*>(entry.GetData())));
1380     });
1381 
1382     OpHandler* op_handler = nullptr;
1383     TF_RETURN_IF_ERROR(
1384         context_->SelectOpHandlerFromArguments(*this, &op_handler));
1385     Expected<CoreRuntimeOp*> expected_op = context_->GetOpCache().GetOrAddOp(
1386         op_name_, op_handler, device_name_, dtypes, this);
1387     if (!expected_op) {
1388       return tensorflow::errors::InvalidArgument(
1389           StrCat("Cannot obtain CoreRuntimeOp: ", op_name_,
1390                  " on device: ", device_name_, expected_op.takeError()));
1391     }
1392     op_ = expected_op.get();
1393     // Update device name since the op handler selector may choose an op_handler
1394     // that's different from what the user specifies.
1395     device_name_ = op_->DeviceName().str();
1396     return tensorflow::Status::OK();
1397   }
1398 
1399   bool compile_with_xla = false;
1400   GetFuncAttr(attrs_, op_name_, *context_->GetEagerContext()->FuncLibDef(),
1401               tensorflow::kXlaMustCompileAttr, &compile_with_xla);
1402   // If the function has compile_with_xla==true, we will use RuntimeFallback
1403   // to execute it, since TFRT does not support XLA yet.
1404   // TODO(tfrt-devs): Native support of compile_with_xla.
1405   if (compile_with_xla) {
1406     Expected<CoreRuntimeOp*> expected_op =
1407         context_->GetOpCache().GetOrAddXlaOp(op_name_, context_);
1408     if (!expected_op) {
1409       return tensorflow::errors::NotFound(
1410           StrCat("Cannot initialize XLA function ", op_name_,
1411                  " on the fallback op handler. ", expected_op.takeError()));
1412     }
1413     op_ = expected_op.get();
1414     return tensorflow::Status::OK();
1415   }
1416 
1417   // Note(fishx): We need the eager context for now because we need the
1418   // FunctionLibraryDefinition to convert a FunctionDef to the MLIR TF
1419   // dialect. In the future, when we can generate MLIR directly from TF
1420   // Python, we should get rid of this dependency.
1421   // FunctionDef -> BEF.
1422   // Look up the cache; compile to BEF and insert into the cache on a miss.
1423   tensorflow::DeviceSet dev_set;
1424   const DeviceMgr* device_mgr = context_->GetEagerContext()->local_device_mgr();
1425   if (device_mgr == nullptr)
1426     return tensorflow::errors::NotFound("Cannot find device manager");
1427   // TODO(tfrt-devs): support remote devices in TFRT.
1428   for (auto d : device_mgr->ListDevices()) dev_set.AddDevice(d);
1429   if (context_->GetDistributedManager() != nullptr &&
1430       context_->UseTfrtDistributedRuntime()) {
1431     down_cast<DistributedManagerContextInterface*>(
1432         context_->GetDistributedManager())
1433         ->PopulateRemoteDevices(&dev_set);
1434   }
1435   FunctionCache::FunctionCacheResult result;
1436 
1437   tensorflow::TfrtFunctionCompileOptions compile_options;
1438   // TODO(tfrt-devs): Formalize the convention of converting TF's device name
1439   // to MLIR compilation option's device name.
1440   std::string tf_device_name = device_name_;
1441   compile_options.default_device = std::move(tf_device_name);
1442   // TODO(b/172659131): Do not use TFRT native ops for TF integration for now.
1443   // Re-enable once we have a concrete plan to implement feature complete
1444   // TFRT native ops (kernels).
1445   compile_options.enable_native_ops = false;
1446 
1447   if (fallback_attrs_.NumAttributes() > 0) {
1448     const auto& ndef = NodeDef();
1449     // TODO(tfrt-devs): If we are to create more attributes, consider packing
1450     // them into a proto.
1451     {
1452       const auto& it = ndef.attr().find(kEnableNativeOpsAttr);
1453       if (it != ndef.attr().end()) {
1454         compile_options.enable_native_ops = it->second.b();
1455       }
1456     }
1457 
1458     {
1459       const auto& it = ndef.attr().find(kEnableGrapplerAttr);
1460       if (it != ndef.attr().end()) {
1461         compile_options.enable_grappler = it->second.b();
1462       }
1463     }
1464   }
1465 
1466   tfrt::SmallVector<const tfrt::Device*, 4> input_devices;
1467   input_devices.reserve(args_.size());
1468   for (auto& arg : args_) {
1469     input_devices.push_back(down_cast<TensorHandleInterface*>(arg.get())
1470                                 ->Handle()
1471                                 .GetAvailableDevice()
1472                                 .get());
1473   }
1474   TF_RETURN_IF_ERROR(context_->GetFunctionCache().GetOrAddFunction(
1475       op_name_, device_name_, dev_set, context_->GetEagerContext(), corert,
1476       /*request_ctx_fn=*/
1477       [this](tensorflow::tfd::OpKernelRunnerTable* runner_table,
1478              RCReference<RequestContext>* request_ctx) {
1479         return context_->BuildFunctionRequestContext(runner_table, request_ctx);
1480       },
1481       abort_location_handler_.GetCurrentLocation(), compile_options,
1482       input_devices, &result));
1483   // TODO(tfrt-devs): Avoid calling EagerContext::ShouldStoreGraphs().
1484   if (result.is_cache_miss &&
1485       context_->GetEagerContext()->ShouldStoreGraphs()) {
1486     TF_RETURN_IF_ERROR(context_->RunMetadataRecordFunction(op_name_));
1487   }
1488   function_state_ = std::move(result.function_state);
1489   return tensorflow::Status::OK();
1490 }
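
// Summary of the three initialization paths above:
//   1. Primitive ops:     fetch or build a CoreRuntimeOp through the op cache.
//   2. compile_with_xla:  route to the fallback XLA op, since TFRT has no
//                         native XLA support yet.
//   3. Other functions:   compile FunctionDef -> MLIR -> BEF through the
//                         function cache, optionally recording graphs on a
//                         cache miss.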
1491 
1492 tensorflow::Status OperationInterface::SetDeviceName(const char* name) {
1493   if (op_ && name != device_name_) {
1494     return tensorflow::errors::Internal(
1495         "Failed to update device name. Right now TFRT cannot update device "
1496         "name of a fallback op if it is initialized.");
1497   }
1498   device_name_ = name ? name : "";
1499   return tensorflow::Status::OK();
1500 }
1501 
1502 tensorflow::Status OperationInterface::AddInput(
1503     tensorflow::AbstractTensorHandle* input) {
1504   tensorflow::ImmediateExecutionTensorHandle* h =
1505       down_cast<tensorflow::ImmediateExecutionTensorHandle*>(input);
1506   // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
1507   if (tensorflow::CustomDeviceTensorHandle::classof(h)) {
1508     custom_device_tensor_handle_count_++;
1509   }
1510   h->Ref();
1511   args_.push_back(
1512       tensorflow::core::RefCountPtr<tensorflow::ImmediateExecutionTensorHandle>(
1513           h));
1514   return tensorflow::Status::OK();
1515 }
1516 
1517 tensorflow::Status OperationInterface::SetInput(
1518     size_t index, tensorflow::ImmediateExecutionTensorHandle* input) {
1519   if (index >= args_.size()) {
1520     return tensorflow::errors::InvalidArgument("Index >= inputs.size: ", index,
1521                                                " >= ", args_.size());
1522   }
1523   // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
1524   if (tensorflow::CustomDeviceTensorHandle::classof(args_[index].get())) {
1525     custom_device_tensor_handle_count_--;
1526   }
1527   if (tensorflow::CustomDeviceTensorHandle::classof(input)) {
1528     custom_device_tensor_handle_count_++;
1529   }
1530   input->Ref();
1531   args_[index] =
1532       tensorflow::core::RefCountPtr<tensorflow::ImmediateExecutionTensorHandle>(
1533           input);
1534   return tensorflow::Status::OK();
1535 }
1536 
1537 tensorflow::Status OperationInterface::AddInputList(
1538     absl::Span<tensorflow::AbstractTensorHandle* const> inputs) {
1539   return tensorflow::errors::Unimplemented(
1540       "Unimplemented OperationInterface::AddInputList");
1541 }
1542 
1543 absl::Span<tensorflow::ImmediateExecutionTensorHandle* const>
1544 OperationInterface::GetInputs() const {
1545   return absl::MakeSpan(
1546       reinterpret_cast<tensorflow::ImmediateExecutionTensorHandle* const*>(
1547           args_.data()),
1548       args_.size());
1549 }
1550 
1551 tensorflow::Status OperationInterface::SetAttrString(const char* attr_name,
1552                                                      const char* data,
1553                                                      size_t length) {
1554   fallback_attrs_.Set(attr_name, tensorflow::StringPiece(data, length));
1555   if (attrs_.SetString(attr_name, string_view(data, length)))
1556     return tensorflow::Status::OK();
1557   return tensorflow::errors::Internal(
1558       "OperationInterface::SetAttrString failed");
1559 }
1560 
1561 tensorflow::Status OperationInterface::SetAttrInt(const char* attr_name,
1562                                                   int64_t value) {
1563   fallback_attrs_.Set(attr_name, static_cast<int64_t>(value));
1564   if (attrs_.Set(attr_name, value)) return tensorflow::Status::OK();
1565   return tensorflow::errors::Internal("OperationInterface::SetAttrInt failed");
1566 }
1567 
1568 tensorflow::Status OperationInterface::SetAttrFloat(const char* attr_name,
1569                                                     float value) {
1570   fallback_attrs_.Set(attr_name, value);
1571   if (attrs_.Set(attr_name, value)) return tensorflow::Status::OK();
1572   return tensorflow::errors::Internal(
1573       "OperationInterface::SetAttrFloat failed");
1574 }
1575 
1576 tensorflow::Status OperationInterface::SetAttrBool(const char* attr_name,
1577                                                    bool value) {
1578   fallback_attrs_.Set(attr_name, value);
1579   if (attrs_.Set(attr_name, value)) return tensorflow::Status::OK();
1580   return tensorflow::errors::Internal("OperationInterface::SetAttrBool failed");
1581 }
1582 
1583 tensorflow::Status OperationInterface::SetAttrType(const char* attr_name,
1584                                                    tensorflow::DataType value) {
1585   fallback_attrs_.Set(attr_name, value);
1586   if (value == tensorflow::DT_INVALID) {
1587     return tensorflow::errors::InvalidArgument(
1588         "OperationInterface::SetAttrType failed to set DT_INVALID");
1589   }
1590   if (attrs_.Set(attr_name,
1591                  tfrt::GetOpAttrTypeFromDType(
1592                      tensorflow::tfd::ConvertTfDataTypeToBefAttrType(value))))
1593     return tensorflow::Status::OK();
1594   // TODO(fishx): Remove this workaround once we support all dtypes in TF.
1595   // This is fine for now since attributes "T", "U", and "Tidx" are not used
1596   // by TFRT native ops.
1597   if (std::strcmp(attr_name, "T") == 0 || std::strcmp(attr_name, "U") == 0 ||
1598       std::strcmp(attr_name, "Tidx") == 0) {
1599     return tensorflow::Status::OK();
1600   }
1601   return tensorflow::errors::Internal("OperationInterface::SetAttrType failed");
1602 }
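
// Note: the early return above deliberately tolerates attrs_.Set failures for
// "T", "U", and "Tidx". Per the TODO, TFRT native ops do not read these
// attributes; fallback execution reads them from fallback_attrs_ instead.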
1603 
1604 tensorflow::Status OperationInterface::SetAttrShape(const char* attr_name,
1605                                                     const int64_t* dims,
1606                                                     const int num_dims) {
1607   // NOTE: This is copied from EagerOperation::SetAttrShape.
1608   // TODO(b/154554118): Remove the duplication.
1609   if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
1610     return tensorflow::errors::InvalidArgument(
1611         "Value specified for `", attr_name, "` has ", num_dims,
1612         " dimensions which is over the limit of ",
1613         tensorflow::TensorShape::MaxDimensions(), ".");
1614   }
1615 
1616   tensorflow::TensorShapeProto proto;
1617   size_t offset;
1618   if (num_dims < 0) {
1619     proto.set_unknown_rank(true);
1620 
1621     // Set unranked ShapeAttr.
1622     offset = bef_attr_encoder_.EncodeUnrankedShapeAttr();
1623   } else {
1624     for (int d = 0; d < num_dims; ++d) {
1625       proto.add_dim()->set_size(dims[d]);
1626     }
1627 
1628     // Set RankedShapeAttr.
1629     offset = bef_attr_encoder_.EncodeRankedShapeAttr(
1630         llvm::makeArrayRef(dims, num_dims));
1631   }
1632   fallback_attrs_.Set(attr_name, proto);
1633 
1634   auto buf = bef_attr_encoder_.TakeResult();
1635   tfrt::ShapeAttr shape_attr(buf.data() + offset);
1636   // TODO(tfrt-devs): Avoid the copy.
1637   if (attrs_.Set(attr_name, shape_attr)) return tensorflow::Status::OK();
1638 
1639   return tensorflow::errors::Internal(
1640       "OperationInterface::SetAttrShape failed");
1641 }
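
// Illustrative calls (a sketch, not from the original source):
//   int64_t dims[] = {2, 3};
//   op->SetAttrShape("shape", dims, 2);      // ranked shape [2, 3]
//   op->SetAttrShape("shape", nullptr, -1);  // unknown rank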
1642 
1643 tensorflow::Status OperationInterface::SetAttrFunction(
1644     const char* attr_name, const tensorflow::AbstractOperation* value) {
1645   auto* value_operation = down_cast<const OperationInterface*>(value);
1646   // TODO(b/165412867): Set fallback_attrs_ for eager device placement.
1647   // Consider removing this and rely on TFRT OpAttrs.
1648   tensorflow::AttrValue attr_value;
1649   tensorflow::NameAttrList* func = attr_value.mutable_func();
1650   func->set_name(value->Name());
1651   fallback_attrs_.Set(attr_name, attr_value);
1652 
1653   if (attrs_.SetFunc(attr_name, {string_view(value_operation->Name())}))
1654     return tensorflow::Status::OK();
1655 
1656   return tensorflow::errors::Internal(
1657       "OperationInterface::SetAttrFunction failed");
1658 }
1659 
1660 tensorflow::Status OperationInterface::SetAttrFunctionName(
1661     const char* attr_name, const char* data, size_t length) {
1662   // TODO(b/165412867): Set fallback_attrs_ for eager device placement.
1663   // Consider removing this and rely on TFRT OpAttrs.
1664   tensorflow::AttrValue attr_value;
1665   tensorflow::NameAttrList* func = attr_value.mutable_func();
1666   func->set_name(data);
1667   fallback_attrs_.Set(attr_name, attr_value);
1668 
1669   if (attrs_.SetFunc(attr_name, {data})) return tensorflow::Status::OK();
1670 
1671   return tensorflow::errors::Internal(
1672       "OperationInterface::SetAttrFunctionName failed");
1673 }
1674 
1675 static size_t SerializeTFETensorToDenseAttr(
1676     tensorflow::AbstractTensorInterface* tensor,
1677     tfrt::BefAttrEncoder* encoder) {
1678
1679 
1680   const auto element_type =
1681       tensorflow::tfd::ConvertTfDataTypeToBefAttrType(tensor->Type());
1682   SmallVector<int64_t, 4> shape;
1683   for (int i = 0; i < tensor->NumDims(); ++i) {
1684     shape.push_back(tensor->Dim(i));
1685   }
1686   auto elements = llvm::makeArrayRef(
1687       reinterpret_cast<const uint8_t*>(tensor->Data()), tensor->ByteSize());
1688   return encoder->EncodeDenseAttr(static_cast<DType>(element_type), shape,
1689                                   elements);
1690 }
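
// For example (illustrative): a 2x2 DT_FLOAT tensor serializes to a DenseAttr
// with element type f32, shape [2, 2], and a 16-byte payload copied from
// tensor->Data().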
1691 
1692 tensorflow::Status OperationInterface::SetAttrTensor(
1693     const char* attr_name, tensorflow::AbstractTensorInterface* tensor) {
1694   tfrt::BefAttrEncoder encoder;
1695   const size_t offset = SerializeTFETensorToDenseAttr(tensor, &encoder);
1696   auto buffer = encoder.TakeResult();
1697   DenseAttr dense_attr(buffer.data() + offset);
1698   if (attrs_.Set(attr_name, dense_attr)) return tensorflow::Status::OK();
1699 
1700   return tensorflow::errors::Internal(
1701       "OperationInterface::SetAttrTensor failed");
1702 }
1703 
1704 tensorflow::Status OperationInterface::SetAttrStringList(
1705     const char* attr_name, const void* const* values, const size_t* lengths,
1706     int num_values) {
1707   std::vector<tensorflow::StringPiece> v(num_values);
1708   for (int i = 0; i < num_values; ++i) {
1709     v[i] = tensorflow::StringPiece(static_cast<const char*>(values[i]),
1710                                    lengths[i]);
1711   }
1712   fallback_attrs_.Set(attr_name, v);
1713 
1714   tfrt::BefAttrEncoder encoder;
1715   const size_t offset =
1716       encoder.EncodeStringListAttr(values, lengths, num_values);
1717   auto buf = encoder.TakeResult();
1718   tfrt::AggregateAttr aggr_attr(buf.data() + offset);
1719   // TODO(tfrt-devs): Avoid the copy.
1720   if (attrs_.Set(attr_name, aggr_attr)) return tensorflow::Status::OK();
1721 
1722   return tensorflow::errors::Internal(
1723       "OperationInterface::SetAttrStringList failed");
1724 }
1725 
1726 tensorflow::Status OperationInterface::SetAttrFloatList(const char* attr_name,
1727                                                         const float* values,
1728                                                         int num_values) {
1729   fallback_attrs_.Set(
1730       attr_name, tensorflow::gtl::ArraySlice<const float>(values, num_values));
1731 
1732   if (attrs_.SetArray(attr_name, tfrt::ArrayRef<float>(values, num_values)))
1733     return tensorflow::Status::OK();
1734   return tensorflow::errors::Internal(
1735       "OperationInterface::SetAttrFloatList failed");
1736 }
1737 
1738 tensorflow::Status OperationInterface::SetAttrIntList(const char* attr_name,
1739                                                       const int64_t* values,
1740                                                       int num_values) {
1741   fallback_attrs_.Set(
1742       attr_name,
1743       tensorflow::gtl::ArraySlice<const int64_t>(values, num_values));
1744 
1745   if (attrs_.SetArray(attr_name, tfrt::ArrayRef<int64_t>(values, num_values)))
1746     return tensorflow::Status::OK();
1747 
1748   return tensorflow::errors::Internal(
1749       "OperationInterface::SetAttrIntList failed");
1750 }
1751 
1752 tensorflow::Status OperationInterface::SetAttrTypeList(
1753     const char* attr_name, const tensorflow::DataType* values, int num_values) {
1754   fallback_attrs_.Set(attr_name,
1755                       tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
1756                           values, num_values));
1757   // Convert to OpAttrType first.
1758   SmallVector<tfrt::DType, 4> tfrt_dtypes;
1759   tfrt_dtypes.reserve(num_values);
1760   for (int i = 0; i < num_values; ++i) {
1761     tfrt_dtypes.push_back(
1762         tensorflow::tfd::ConvertTfDataTypeToBefAttrType(values[i]));
1763   }
1764 
1765   if (attrs_.SetRaw(attr_name, tfrt_dtypes.data(), tfrt::OpAttrType::DTYPE,
1766                     num_values, OpAttrsRawEntryType::kArray))
1767     return tensorflow::Status::OK();
1768 
1769   return tensorflow::errors::Internal(
1770       "OperationInterface::SetAttrTypeList failed");
1771 }
1772 
1773 tensorflow::Status OperationInterface::SetAttrBoolList(
1774     const char* attr_name, const unsigned char* values, int num_values) {
1775   std::unique_ptr<bool[]> b(new bool[num_values]);
1776   for (int i = 0; i < num_values; ++i) {
1777     b[i] = values[i];
1778   }
1779   fallback_attrs_.Set(
1780       attr_name, tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
1781 
1782   // Convert to bool first.
1783   SmallVector<bool, 4> bool_array;
1784   bool_array.reserve(num_values);
1785   for (int i = 0; i < num_values; ++i) {
1786     bool_array.push_back(static_cast<bool>((values[i])));
1787   }
1788   if (attrs_.SetArray(attr_name,
1789                       tfrt::ArrayRef<bool>(bool_array.data(), num_values)))
1790     return tensorflow::Status::OK();
1791 
1792   return tensorflow::errors::Internal(
1793       "OperationInterface::SetAttrBoolList failed");
1794 }
1795 
1796 tensorflow::Status OperationInterface::SetAttrShapeList(const char* attr_name,
1797                                                         const int64_t** dims,
1798                                                         const int* num_dims,
1799                                                         int num_values) {
1800   std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
1801       new tensorflow::TensorShapeProto[num_values]);
1802   for (int i = 0; i < num_values; ++i) {
1803     const auto num_dims_i = num_dims[i];
1804 
1805     if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
1806       return tensorflow::errors::InvalidArgument(
1807           StrCat("Value specified for `", attr_name, "` has ", num_dims_i,
1808                  " dimensions which is over the limit of ",
1809                  tensorflow::TensorShape::MaxDimensions(), "."));
1810     }
1811     if (num_dims_i < 0) {
1812       proto[i].set_unknown_rank(true);
1813     } else {
1814       const int64_t* dims_i = dims[i];
1815       auto proto_i = &proto[i];
1816       for (int d = 0; d < num_dims_i; ++d) {
1817         proto_i->add_dim()->set_size(dims_i[d]);
1818       }
1819     }
1820   }
1821   fallback_attrs_.Set(attr_name,
1822                       tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
1823                           proto.get(), num_values));
1824 
1825   BefAttrEncoder encoder;
1826   const size_t offset = encoder.EncodeShapeListAttr(dims, num_dims, num_values);
1827   auto buf = encoder.TakeResult();
1828   tfrt::AggregateAttr aggr_attr(buf.data() + offset);
1829   if (attrs_.Set(attr_name, aggr_attr)) return tensorflow::Status::OK();
1830 
1831   return tensorflow::errors::Internal(
1832       "OperationInterface::SetAttrShapeList failed");
1833 }
1834 
1835 tensorflow::Status OperationInterface::SetAttrFunctionList(
1836     const char* attr_name, absl::Span<const AbstractOperation*> values) {
1837   size_t num_values = values.size();
1838   std::vector<const void*> func_attrs(num_values);
1839   std::vector<size_t> lengths(num_values);
1840 
1841   for (int i = 0; i < num_values; ++i) {
1842     auto* value_operation = down_cast<const OperationInterface*>(values[i]);
1843     lengths[i] = value_operation->Name().length();
1844     func_attrs[i] = value_operation->Name().c_str();
1845   }
1846 
1847   // Encode the array of function attributes into a single aggregate
1848   // attribute using the BEF typed attribute encoder.
1849   BefAttrEncoder encoder;
1850   const size_t offset =
1851       encoder.EncodeFuncListAttr(func_attrs.data(), lengths.data(), num_values);
1852   auto buf = encoder.TakeResult();
1853   tfrt::AggregateAttr aggr_attr(buf.data() + offset);
1854   if (attrs_.Set(attr_name, aggr_attr)) return tensorflow::Status::OK();
1855 
1856   return tensorflow::errors::Internal(
1857       "OperationInterface::SetAttrFunctionList failed");
1858 }
1859 
1860 tensorflow::Status OperationInterface::InputLength(const char* input_name,
1861                                                    int* length) {
1862   return tensorflow::errors::Unimplemented(
1863       "Unimplemented OperationInterface::InputLength");
1864 }
1865 
1866 tensorflow::Status OperationInterface::OutputLength(const char* output_name,
1867                                                     int* length) {
1868   return tensorflow::errors::Unimplemented(
1869       "Unimplemented OperationInterface::OutputLength");
1870 }
1871 
1872 const tensorflow::AbstractOpAttrs* OperationInterface::GetOpAttrs() const {
1873   return &op_attrs_;
1874 }
1875 
1876 void OperationInterface::AddAttrs(const tensorflow::AbstractOpAttrs* op_attrs) {
1877   auto* tfrt_op_attrs = down_cast<const OpAttrsInterface*>(op_attrs);
1878   tfrt_op_attrs->GetAttrs()->IterateEntries(
1879       [this](const OpAttrsRawEntry& entry) {
1880         attrs_.SetRaw(entry.name, entry.GetData(), entry.type,
1881                       entry.element_count, entry.entry_type);
1882       });
1883   fallback_attrs_.CopyAttributes(*tfrt_op_attrs->GetFallbackAttrs());
1884 }
1885 
1886 void OperationInterface::MaybeInferInputAttrs() {
1887   if (!op_def_) return;
1888   for (int i = 0; i < args_.size(); i++) {
1889     auto* handle = args_[i].get();
1890     const auto& input_def = op_def_->input_arg(i);
1891     if (!input_def.number_attr().empty() ||
1892         !input_def.type_list_attr().empty()) {
1893       // Some clients that still set their input attributes manually add an
1894       // input list to their op by calling `TFE_OpAddInput` for each of its
1895       // elements instead of calling `TFE_OpAddInputList`. When this happens,
1896       // we cannot detect the end of such a list and thus lose track of the
1897       // input arguments in the op definition. To guarantee backward
1898       // compatibility with those clients, automatic inference is disabled in
1899       // this case.
1900       return;
1901     }
1902     const std::string& type_attr = input_def.type_attr();
1903     if (!type_attr.empty()) {
1904       bool success = attrs_.Set(
1905           type_attr, tfrt::GetOpAttrTypeFromDType(
1906                          tensorflow::tfd::ConvertTfDataTypeToBefAttrType(
1907                              handle->DataType())));
1908       if (success) {
1909         fallback_attrs_.Set(type_attr, handle->DataType());
1910       }
1911     }
1912   }
1913 }
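
// For example (a sketch): for "MatMul", whose inputs declare type_attr "T",
// adding two DT_FLOAT handles lets MaybeInferInputAttrs set attribute "T" to
// DT_FLOAT without the client calling SetAttrType explicitly.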
1914 
1915 }  // namespace tf
1916 }  // namespace tfrt
1917