/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"

#include <cstddef>
#include <memory>
#include <vector>

#include "tensorflow/c/eager/abstract_function.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/mlir/tfrt/function/function.h"
#include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h"
#include "tensorflow/core/runtime_fallback/runtime/op_logger.h"
#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_op_handler.h"
#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.h"
#include "tensorflow/core/runtime_fallback/util/attr_util.h"
#include "tensorflow/core/runtime_fallback/util/tensor_util.h"
#include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed_interface.h"
#include "tensorflow/core/tfrt/eager/core_runtime/op_handler_registry.h"
#include "tensorflow/core/tfrt/eager/core_runtime/op_handler_selector.h"
#include "tensorflow/core/tfrt/eager/virtual_device.h"
#include "tensorflow/core/tfrt/utils/error_util.h"
#include "tensorflow/core/tfrt/utils/utils.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tfrt/common/compat/eigen/eigen_dtype.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime_op.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attr_type.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_handler.h"  // from @tf_runtime
#include "tfrt/core_runtime/tensor_handle.h"  // from @tf_runtime
#include "tfrt/dtype/dtype.h"  // from @tf_runtime
#include "tfrt/host_context/async_value.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/attribute_utils.h"  // from @tf_runtime
#include "tfrt/host_context/chain.h"  // from @tf_runtime
#include "tfrt/host_context/concurrent_work_queue.h"  // from @tf_runtime
#include "tfrt/host_context/device.h"  // from @tf_runtime
#include "tfrt/host_context/diagnostic.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/function.h"  // from @tf_runtime
#include "tfrt/host_context/host_allocator.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/host_context/location.h"  // from @tf_runtime
#include "tfrt/host_context/resource_context.h"  // from @tf_runtime
#include "tfrt/metrics/common_metrics.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime
#include "tfrt/support/string_util.h"  // from @tf_runtime
#include "tfrt/tensor/conversion_registry.h"  // from @tf_runtime
#include "tfrt/tensor/dense_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/dense_host_tensor_view.h"  // from @tf_runtime
#include "tfrt/tensor/scalar_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/string_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_metadata.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_serialize_utils.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_type_registration.h"  // from @tf_runtime

namespace tfrt {
namespace tf {

namespace {

using tensorflow::down_cast;

constexpr char kGpuDeviceName[] = "GPU";
constexpr char kEnableNativeOpsAttr[] = "TFRT_TEST_enable_native_ops";
constexpr char kEnableGrapplerAttr[] = "TFRT_TEST_enable_grappler";

TensorMetadata CreateMetadata(DType dtype, absl::Span<const Index> dim_sizes) {
  return TensorMetadata(
      DType(dtype),
      TensorShape(llvm::ArrayRef<Index>(
          reinterpret_cast<const Index*>(dim_sizes.data()), dim_sizes.size())));
}

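// Maps a TFRT DType to the corresponding TensorFlow DataType; for example,
// ConvertDType(DType::F32) yields tensorflow::DT_FLOAT. Kinds without a
// TensorFlow equivalent are logged and mapped to DT_INVALID.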
tensorflow::DataType ConvertDType(DType kind) {
  switch (kind) {
    case DType::UI8:
      return tensorflow::DT_UINT8;
    case DType::UI16:
      return tensorflow::DT_UINT16;
    case DType::UI32:
      return tensorflow::DT_UINT32;
    case DType::UI64:
      return tensorflow::DT_UINT64;
    case DType::I8:
      return tensorflow::DT_INT8;
    case DType::I16:
      return tensorflow::DT_INT16;
    case DType::I32:
      return tensorflow::DT_INT32;
    case DType::I64:
      return tensorflow::DT_INT64;
    case DType::BF16:
      return tensorflow::DT_BFLOAT16;
    case DType::F16:
      return tensorflow::DT_HALF;
    case DType::F32:
      return tensorflow::DT_FLOAT;
    case DType::F64:
      return tensorflow::DT_DOUBLE;
    case DType::I1:
      return tensorflow::DT_BOOL;
    case DType::Complex64:
      return tensorflow::DT_COMPLEX64;
    case DType::Complex128:
      return tensorflow::DT_COMPLEX128;
    case DType::String:
      return tensorflow::DT_STRING;
    case DType::Resource:
      return tensorflow::DT_RESOURCE;
    case DType::Variant:
      return tensorflow::DT_VARIANT;
    case DType::QUI8:
      return tensorflow::DT_QUINT8;
    case DType::QUI16:
      return tensorflow::DT_QUINT16;
    case DType::QI8:
      return tensorflow::DT_QINT8;
    case DType::QI16:
      return tensorflow::DT_QINT16;
    case DType::QI32:
      return tensorflow::DT_QINT32;
    default:
      LOG(ERROR) << "Unsupported kind " << kind;
      return tensorflow::DT_INVALID;
  }
}

DType ConvertDType(tensorflow::DataType dtype) {
  switch (dtype) {
    case tensorflow::DT_UINT8:
      return static_cast<DType>(DType::UI8);
    case tensorflow::DT_UINT16:
      return static_cast<DType>(DType::UI16);
    case tensorflow::DT_UINT32:
      return static_cast<DType>(DType::UI32);
    case tensorflow::DT_UINT64:
      return static_cast<DType>(DType::UI64);
    case tensorflow::DT_INT8:
      return static_cast<DType>(DType::I8);
    case tensorflow::DT_INT16:
      return static_cast<DType>(DType::I16);
    case tensorflow::DT_INT32:
      return static_cast<DType>(DType::I32);
    case tensorflow::DT_INT64:
      return static_cast<DType>(DType::I64);
    case tensorflow::DT_BFLOAT16:
      return static_cast<DType>(DType::BF16);
    case tensorflow::DT_HALF:
      return static_cast<DType>(DType::F16);
    case tensorflow::DT_FLOAT:
      return static_cast<DType>(DType::F32);
    case tensorflow::DT_DOUBLE:
      return static_cast<DType>(DType::F64);
    case tensorflow::DT_BOOL:
      return static_cast<DType>(DType::I1);
    case tensorflow::DT_STRING:
      return static_cast<DType>(DType::String);
    case tensorflow::DT_COMPLEX64:
      return static_cast<DType>(DType::Complex64);
    case tensorflow::DT_COMPLEX128:
      return static_cast<DType>(DType::Complex128);
    case tensorflow::DT_RESOURCE:
      return static_cast<DType>(DType::Resource);
    case tensorflow::DT_VARIANT:
      return static_cast<DType>(DType::Variant);
    case tensorflow::DT_QUINT8:
      return static_cast<DType>(DType::QUI8);
    case tensorflow::DT_QUINT16:
      return static_cast<DType>(DType::QUI16);
    case tensorflow::DT_QINT8:
      return static_cast<DType>(DType::QI8);
    case tensorflow::DT_QINT16:
      return static_cast<DType>(DType::QI16);
    case tensorflow::DT_QINT32:
      return static_cast<DType>(DType::QI32);
    default:
      LOG(FATAL) << "Unsupported dtype " << dtype;
  }
}

OpAttrType ConvertDTypeToOpAttrType(tensorflow::DataType dtype) {
  switch (dtype) {
    case tensorflow::DT_UINT8:
      return OpAttrType::UI8;
    case tensorflow::DT_UINT16:
      return OpAttrType::UI16;
    case tensorflow::DT_UINT32:
      return OpAttrType::UI32;
    case tensorflow::DT_UINT64:
      return OpAttrType::UI64;
    case tensorflow::DT_INT8:
      return OpAttrType::I8;
    case tensorflow::DT_INT16:
      return OpAttrType::I16;
    case tensorflow::DT_INT32:
      return OpAttrType::I32;
    case tensorflow::DT_INT64:
      return OpAttrType::I64;
    case tensorflow::DT_BFLOAT16:
      return OpAttrType::BF16;
    case tensorflow::DT_HALF:
      return OpAttrType::F16;
    case tensorflow::DT_FLOAT:
      return OpAttrType::F32;
    case tensorflow::DT_DOUBLE:
      return OpAttrType::F64;
    case tensorflow::DT_BOOL:
      return OpAttrType::BOOL;
    case tensorflow::DT_COMPLEX64:
      return OpAttrType::COMPLEX64;
    case tensorflow::DT_COMPLEX128:
      return OpAttrType::COMPLEX128;
    default:
      LOG(FATAL) << "Unsupported dtype " << dtype;
  }
}

// This method first looks at the calling op's attrs and then at the function
// def's attrs to find the attribute value.
void GetFuncAttr(const OpAttrs& op_attrs, const std::string& op_name,
                 const tensorflow::FunctionLibraryDefinition& func_lib_def,
                 string_view attr_name, bool* value) {
  bool success = op_attrs.Get(attr_name, value);
  if (success) {
    DVLOG(2) << "Caller explicitly specifies " << attr_name.str()
             << (*value ? "=true" : "=false");
    return;
  }

  const tensorflow::FunctionDef* function_def = func_lib_def.Find(op_name);
  if (function_def == nullptr) {
    return;
  }

  tensorflow::Status status =
      GetNodeAttr(tensorflow::AttrSlice(&function_def->attr()),
                  {attr_name.data(), attr_name.size()}, value);
  if (status.ok()) {
    DVLOG(2) << "Function definition explicitly specifies " << attr_name.str()
             << (*value ? "=true" : "=false");
    return;
  }
}

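// Returns a process-wide unique id for constructing tfrt::Location objects.
// Relaxed ordering suffices because the counter is only used for uniqueness,
// not for synchronization.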
int64_t GetNextLocationId() {
  static std::atomic<int64_t> id(0);
  return id.fetch_add(1, std::memory_order_relaxed);
}
}  // namespace

tensorflow::DataType TensorInterface::Type() const {
  auto kind = tensor_.get().metadata().dtype;
  if (kind == DType::Unsupported) {
    assert(llvm::isa<tensorflow::tfd::RuntimeFallbackTensor>(tensor_.get()));
    return tensor_.get<tensorflow::tfd::RuntimeFallbackTensor>()
        .GetTensorHandle()
        ->DataType();
  }
  return ConvertDType(kind);
}

int TensorInterface::NumDims() const { return tensor_.get().shape().GetRank(); }

int64_t TensorInterface::Dim(int dim_index) const {
  return tensor_.get().shape().GetDimensionSize(dim_index);
}

int64_t TensorInterface::NumElements() const {
  if (!tensor_) {
    return static_cast<int64_t>(tf_tensor_.NumElements());
  }
  return tensor_.get().shape().GetNumElements();
}

size_t TensorInterface::ByteSize() const {
  return tensor_.get().metadata().GetHostSizeInBytes();
}

void* TensorInterface::Data() const {
  if (!tensor_) {
    return tensorflow::TensorCApi::Buffer(tf_tensor_)->data();
  } else {
    auto& tensor = tensor_.get<DenseHostTensor>();
    return tensor.data();
  }
}

// TFRT DenseHostTensor is always aligned.
bool TensorInterface::IsAligned() const { return true; }

bool TensorInterface::CanMove() const {
  // It is safe to move the Tensor if and only if we own the unique reference
  // to the tensor buffer.
  auto& dht = tensor_.get<DenseHostTensor>();
  return tensor_.IsUnique() && dht.buffer()->IsUnique();
}

std::string TensorInterface::SummarizeValue() const {
  if (!tensor_) {
    return tf_tensor_.SummarizeValue(/*max_entries=*/3, /*print_v2=*/true);
  } else {
    std::string result;
    llvm::raw_string_ostream result_ostream(result);
    tensor_->Print(result_ostream);
    return result;
  }
}

AsyncValueRef<Tensor> TensorInterface::TensorRef() const {
  return tensor_.CopyRef();
}

TensorHandleInterface::TensorHandleInterface(Value&& v, TfrtContext* context)
    : ImmediateExecutionTensorHandle(kTfrt),
      context_(*context),
      value_(std::move(v)) {}

TensorHandleInterface::TensorHandleInterface(tensorflow::DataType dtype,
                                             Value&& v, TfrtContext* context)
    : ImmediateExecutionTensorHandle(kTfrt),
      dtype_(dtype),
      context_(*context),
      value_(std::move(v)) {}

tensorflow::DataType TensorHandleInterface::DataType() const {
  // If dtype_ field is set, use it instead of waiting for the underlying
  // TensorHandle's metadata to be available.
  if (dtype_) {
    return dtype_.getValue();
  }
  auto metadata = Metadata();
  if (!metadata.hasValue()) {
    LOG(ERROR)
        << "Failed to get DataType due to error metadata: "
        << value_.get<TensorHandle>().GetAsyncMetadata().GetError().message;
    return tensorflow::DT_INVALID;
  }
  auto kind = metadata.getValue()->dtype;
  if (kind == DType::Unsupported) {
    AsyncValue* async_tensor = value_.get<TensorHandle>().GetAsyncTensor();
    if (!async_tensor->IsAvailable()) {
      context_.GetHostContext()->Await(FormRef(async_tensor));
    }

    if (async_tensor->IsError()) {
      LOG(ERROR) << "Failed to get DataType from an error tensor "
                 << async_tensor->GetError().message;
      return tensorflow::DT_INVALID;
    }
    assert(async_tensor->IsType<tensorflow::tfd::RuntimeFallbackTensor>());
    return async_tensor->get<tensorflow::tfd::RuntimeFallbackTensor>()
        .GetTensorHandle()
        ->DataType();
  }
  return ConvertDType(kind);
}

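// In async mode, errors are reported when the underlying tensor is consumed,
// so this always returns OK. In sync mode, this blocks until both the
// metadata and the tensor are available and surfaces any error as an
// Internal status.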
tensorflow::Status TensorHandleInterface::TensorHandleStatus() const {
  if (context_.IsAsync()) {
    return ::tensorflow::OkStatus();
  } else {
    auto metadata = Metadata();
    if (!metadata.hasValue()) {
      LOG(ERROR)
          << "Metadata in the tensor handle is an error metadata: "
          << value_.get<TensorHandle>().GetAsyncMetadata().GetError().message;
      return tensorflow::errors::Internal(
          value_.get<TensorHandle>().GetAsyncMetadata().GetError().message);
    }

    AsyncValue* async_tensor = value_.get<TensorHandle>().GetAsyncTensor();
    if (!async_tensor->IsAvailable()) {
      context_.GetHostContext()->Await(FormRef(async_tensor));
    }

    if (async_tensor->IsError()) {
      LOG(ERROR) << "Async tensor in the tensor handle is an error tensor: "
                 << async_tensor->GetError().message;
      return tensorflow::errors::Internal(async_tensor->GetError().message);
    }

    return ::tensorflow::OkStatus();
  }
}

tensorflow::Status TensorHandleInterface::Shape(
    tensorflow::PartialTensorShape* shape) const {
  auto metadata = Metadata();
  if (!metadata.hasValue()) {
    return CreateTfErrorStatus(
        value_.get<TensorHandle>().GetAsyncMetadata().GetError());
  }
  int num_dims = metadata.getValue()->shape.GetRank();
  if (num_dims == -1) {
    return ::tensorflow::OkStatus();
  }
  llvm::SmallVector<Index, 8> dims;
  metadata.getValue()->shape.GetDimensions(&dims);
  TF_RETURN_IF_ERROR(tensorflow::TensorShapeUtils::MakeShape(dims, shape));
  return ::tensorflow::OkStatus();
}

tensorflow::Status TensorHandleInterface::NumDims(int* num_dims) const {
  auto metadata = Metadata();
  if (!metadata.hasValue()) {
    return CreateTfErrorStatus(
        value_.get<TensorHandle>().GetAsyncMetadata().GetError());
  }
  *num_dims = metadata.getValue()->shape.GetRank();

  return ::tensorflow::OkStatus();
}

tensorflow::Status TensorHandleInterface::NumElements(
    int64_t* num_elements) const {
  auto metadata = Metadata();
  if (!metadata.hasValue()) {
    return CreateTfErrorStatus(
        value_.get<TensorHandle>().GetAsyncMetadata().GetError());
  }
  *num_elements = metadata.getValue()->shape.GetNumElements();

  return ::tensorflow::OkStatus();
}

tensorflow::Status TensorHandleInterface::Dim(int dim_index,
                                              int64_t* dim) const {
  auto metadata = Metadata();
  if (!metadata.hasValue()) {
    return CreateTfErrorStatus(
        value_.get<TensorHandle>().GetAsyncMetadata().GetError());
  }
  *dim = metadata.getValue()->shape.GetDimensionSize(dim_index);

  return ::tensorflow::OkStatus();
}

const char* TensorHandleInterface::DeviceName(
    tensorflow::Status* status) const {
  auto& th = value_.get<TensorHandle>();
  if (!th.IsDeviceAvailable()) {
    context_.GetHostContext()->Await(th.GetAsyncDevice().CopyRCRef());
  }
  if (th.IsDeviceError()) {
    *status = CreateTfErrorStatus(th.GetAsyncDevice().GetError());
    return nullptr;
  }
  return th.GetAvailableDevice()->name().data();
}

const char* TensorHandleInterface::BackingDeviceName(
    tensorflow::Status* status) const {
  return DeviceName(status);
}

const char* TensorHandleInterface::DeviceType(
    tensorflow::Status* status) const {
  auto& th = value_.get<TensorHandle>();
  if (!th.IsDeviceAvailable()) {
    context_.GetHostContext()->Await(th.GetAsyncDevice().CopyRCRef());
  }
  if (th.IsDeviceError()) {
    *status = CreateTfErrorStatus(th.GetAsyncDevice().GetError());
    return nullptr;
  }
  return th.GetAvailableDevice()->type().name().data();
}

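// Resolves the handle to a concrete host tensor, blocking until the
// underlying async tensor is available. String host tensors are copied into
// a tensorflow::Tensor; all other tensor types are first transferred to the
// host device as a DenseHostTensor.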
tensorflow::AbstractTensorInterface* TensorHandleInterface::Resolve(
    tensorflow::Status* status) {
  auto* host_ctx = context_.GetHostContext();
  auto host_device_ref = host_ctx->GetHostDeviceRef();
  auto& th = value_.get<TensorHandle>();

  auto tensor_av = th.GetAsyncTensor();
  if (!tensor_av->IsAvailable()) {
    host_ctx->Await(FormRef(tensor_av));
  }
  if (auto* error = tensor_av->GetErrorIfPresent()) {
    *status = CreateTfErrorStatus(*error);
    return nullptr;
  }
  assert(th.IsMetadataAvailable());

  if (th.GetAsyncTensor()->get<Tensor>().tensor_type() ==
      StringHostTensor::kTensorType) {
    tensorflow::Tensor tf_tensor =
        tensorflow::tfd::CopyShtToTfTensor(tensor_av->get<StringHostTensor>());
    return new tensorflow::TensorInterface(tf_tensor);
  }

  // Convert the tensor to DenseHostTensor.
  auto req_ctx =
      tfrt::RequestContextBuilder(host_ctx, context_.GetResourceContext())
          .build();
  if (!req_ctx) {
    *status = tensorflow::Status(
        tensorflow::error::Code::UNKNOWN,
        StrCat("Failed to build a RequestContext: ", req_ctx.takeError()));
    return nullptr;
  }
  tfrt::ExecutionContext exec_ctx{std::move(*req_ctx)};
  auto target_th = th.TransferTo(exec_ctx, std::move(host_device_ref),
                                 DenseHostTensor::kTensorType);

  auto target_av = target_th.GetAsyncTensor();
  if (!target_av->IsAvailable()) {
    host_ctx->Await(FormRef(target_av));
  }
  if (target_av->IsError()) {
    *status = tensorflow::Status(
        tensorflow::error::Code::UNKNOWN,
        StrCat("Cannot resolve tensor: ", target_av->GetError().message));
    return nullptr;
  }
  auto host_tensor_ref = target_th.ReleaseTensorRef();
  return new TensorInterface(std::move(host_tensor_ref));
}

llvm::Optional<const TensorMetadata*> TensorHandleInterface::Metadata() const {
  auto& th = value_.get<TensorHandle>();
  if (!th.IsMetadataAvailable()) {
    context_.GetHostContext()->Await(th.GetAsyncMetadata().CopyRCRef());
  }
  if (th.IsMetadataError()) {
    return llvm::None;
  }
  return &th.GetAvailableMetadata();
}

ContextInterface::ContextInterface(
    const tensorflow::SessionOptions& opts,
    tensorflow::ContextDevicePlacementPolicy default_device_placement_policy,
    bool is_async, bool use_tfrt_distributed_runtime)
    : ImmediateExecutionContext(kTfrt),
      context_(opts, default_device_placement_policy, is_async),
      use_tfrt_distributed_runtime_(use_tfrt_distributed_runtime) {
  LOG(INFO) << "TFRT Enabled";
  metrics::AddTFRTVersionMetric();

  op_handler_selector_ = std::make_unique<EagerOpHandlerSelector>(
      GetCoreRuntime(), GetEagerContext(), GetFallbackOpHandler(),
      GetEagerContext()->PinSmallOpsToCPU());

  run_metadata_ = std::make_unique<tensorflow::RunMetadata>();
}

ContextInterface::~ContextInterface() {}

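// Returns the chain for the calling thread, creating it on first use. This
// per-thread chain is what Execute() threads through op invocations to
// sequence their side effects. The fast path takes only a shared lock; a
// miss re-checks and inserts under the exclusive lock (double-checked
// locking over chain_map_mu_).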
AsyncValueRef<Chain>* ContextInterface::GetChain() {
  auto thread_id = std::this_thread::get_id();
  {
    tensorflow::tf_shared_lock l(chain_map_mu_);
    auto it = thread_local_chain_.find(thread_id);
    if (it != thread_local_chain_.end()) {
      return &it->second;
    }
  }
  {
    tensorflow::mutex_lock l(chain_map_mu_);
    if (thread_local_chain_.find(thread_id) == thread_local_chain_.end()) {
      auto chain = GetReadyChain();
      thread_local_chain_[thread_id] = std::move(chain);
    }
    return &thread_local_chain_[thread_id];
  }
}

template <typename T>
static TensorInterface* MakeScalarTensor(T value, HostContext* host) {
  // The TensorInterface implementation assumes the tensor is a DenseHostTensor,
  // so we need to use a DenseHostTensor to represent a scalar tensor.
  TensorMetadata md(GetDType<T>(), {});
  auto t = DenseHostTensor::CreateUninitialized(md, host);
  if (!t) {
    LOG(ERROR) << "Failed to create DenseHostTensor";
    return nullptr;
  }
  auto& dht = t.getValue();
  MutableDHTArrayView<T> view{&dht};
  view.Elements()[0] = value;

  return new TensorInterface(
      MakeAvailableAsyncValueRef<DenseHostTensor>(host, std::move(dht)));
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateInt64Scalar(
    int64_t value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateUint64Scalar(
    uint64_t value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateInt32Scalar(
    int32_t value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateFloatScalar(
    float value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateDoubleScalar(
    double value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateHalfScalar(
    Eigen::half value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateStringScalar(
    tensorflow::tstring value) {
  auto* host = GetHostContext();
  TensorMetadata md(DType(DType::String), {});
  auto t = StringHostTensor::MakeConstructedAsyncValueRef(md, host);
  if (t.IsError()) {
    LOG(ERROR) << "Failed to create StringHostTensor";
    return nullptr;
  }
  t->strings()[0] = value;

  t.SetStateConcrete();
  return new TensorInterface(std::move(t));
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateComplex128Scalar(
    tensorflow::complex128 value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateBoolScalar(
    bool value) {
  return MakeScalarTensor(value, GetHostContext());
}

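// Creates an uninitialized tensor of the given dtype and shape. Most dtypes
// are backed by a TFRT DenseHostTensor; DT_STRING instead uses a
// tensorflow::Tensor as the backing buffer for tstrings (see below).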
tensorflow::AbstractTensorInterface* ContextInterface::CreateTensor(
    tensorflow::DataType dtype, absl::Span<const int64_t> dim_sizes) {
  std::vector<Index> dimvec(dim_sizes.size());
  for (size_t i = 0; i < dim_sizes.size(); ++i) {
    dimvec[i] = static_cast<Index>(dim_sizes[i]);
  }

  TensorMetadata md;
  switch (dtype) {
    case tensorflow::DT_UINT8:
      md = CreateMetadata(DType::UI8, dimvec);
      break;
    case tensorflow::DT_INT8:
      md = CreateMetadata(DType::I8, dimvec);
      break;
    case tensorflow::DT_INT16:
      md = CreateMetadata(DType::I16, dimvec);
      break;
    case tensorflow::DT_INT32:
      md = CreateMetadata(DType::I32, dimvec);
      break;
    case tensorflow::DT_INT64:
      md = CreateMetadata(DType::I64, dimvec);
      break;
    case tensorflow::DT_HALF:
      md = CreateMetadata(DType::F16, dimvec);
      break;
    case tensorflow::DT_FLOAT:
      md = CreateMetadata(DType::F32, dimvec);
      break;
    case tensorflow::DT_DOUBLE:
      md = CreateMetadata(DType::F64, dimvec);
      break;
    case tensorflow::DT_BOOL:
      md = CreateMetadata(DType::I1, dimvec);
      break;
    case tensorflow::DT_COMPLEX64:
      md = CreateMetadata(DType::Complex64, dimvec);
      break;
    case tensorflow::DT_COMPLEX128:
      md = CreateMetadata(DType::Complex128, dimvec);
      break;
    case tensorflow::DT_VARIANT:
      // Note: the TF Python API can create variant tensors, e.g. for ragged
      // tensors.
      md = CreateMetadata(DType::Variant, dimvec);
      break;
    case tensorflow::DT_STRING:
      // No TFRT metadata is needed for non-scalar string tensors.
      break;
    default:
      LOG(ERROR) << "Cannot create tensor with dtype: " << dtype;
      return nullptr;
  }

  if (dtype == tensorflow::DT_STRING) {
    // Create a TensorFlow Tensor as the buffer for the tstrings.
    return new TensorInterface(
        tensorflow::Tensor(dtype, tensorflow::TensorShape(dim_sizes)));
  } else {
    auto t = DenseHostTensor::CreateUninitialized(md, GetHostContext());
    return new TensorInterface(MakeAvailableAsyncValueRef<DenseHostTensor>(
        GetHostContext(), std::move(t.getValue())));
  }
}

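// Wraps caller-owned memory in a DenseHostTensor without copying.
// memory_releaser is invoked once the last reference to the underlying
// HostBuffer is dropped.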
tensorflow::AbstractTensorInterface* ContextInterface::CreateTensor(
    tensorflow::DataType dtype, const int64_t* dims, int num_dims, void* data,
    size_t len, MemoryReleaser memory_releaser, void* memory_releaser_arg) {
  TensorMetadata metadata(ConvertDType(dtype),
                          {dims, static_cast<size_t>(num_dims)});
  RCReference<HostBuffer> buffer = HostBuffer::CreateFromExternal(
      data, len,
      [memory_releaser, memory_releaser_arg](void* data, size_t len) {
        memory_releaser(data, len, memory_releaser_arg);
      });
  AsyncValueRef<DenseHostTensor> dht =
      MakeConstructedAsyncValueRef<DenseHostTensor>(GetHostContext(), metadata,
                                                    std::move(buffer));

  dht.SetStateConcrete();
  return new TensorInterface(std::move(dht));
}

bool ContextInterface::UsesTFRT() { return true; }

tensorflow::ImmediateExecutionTensorHandle* ContextInterface::CreateLocalHandle(
    tensorflow::AbstractTensorInterface* t) {
  auto* tensor_interface = down_cast<TensorInterface*>(t);
  auto* host = GetHostContext();

  // Create a RuntimeFallbackTensor from a TF Tensor, and then create the
  // corresponding TensorHandleInterface.
  if (tensor_interface->IsTfTensor()) {
    tensorflow::tfd::OwnedTensorHandle tf_tensor_handle{
        tensorflow::TensorHandle::CreateLocalHandle(
            tensor_interface->TfTensor())};

    auto expected_result_tensor =
        tensorflow::tfd::CreateRuntimeFallbackTensorFromTfTensorHandle(
            std::move(tf_tensor_handle), GetHostContext());

    if (expected_result_tensor) {
      return new TensorHandleInterface(
          Value(TensorHandle(
              host->GetHostDeviceRef(), expected_result_tensor.get().metadata(),
              MakeAvailableAsyncValueRef<
                  tensorflow::tfd::RuntimeFallbackTensor>(
                  host, std::move(expected_result_tensor.get())))),
          GetTfrtContext());
    } else {
      return new TensorHandleInterface(
          Value(TensorHandle::CreateError(MakeErrorAsyncValueRef(
              GetHostContext(), StrCat(expected_result_tensor.takeError())))),
          GetTfrtContext());
    }
  }

  auto tensor_av = tensor_interface->TensorRef();
  const TensorMetadata& md = tensor_av.get<Tensor>().metadata();

  // NOTE(fishx): The following logic is needed to let TF-TFRT fully reach
  // performance parity with current TF. This API is used by tf.constant
  // to convert a Python object to a **CPU** Tensor. tf.constant in current TF
  // heavily depends on the Tensor Mirroring feature for good performance.
  // However, TFRT does not have a Tensor Mirroring feature. In order to use
  // Tensor Mirroring from the current TF runtime, we convert the result of
  // tf.constant to a Fallback Tensor.

  if (tensor_av.IsAvailable()) {
    if (auto* dht = llvm::dyn_cast<DenseHostTensor>(&tensor_av.get<Tensor>())) {
      return new TensorHandleInterface(
          Value(TensorHandle(
              host->GetHostDeviceRef(), md,
              MakeAvailableAsyncValueRef<
                  tensorflow::tfd::RuntimeFallbackTensor>(
                  host, tensorflow::tfd::CopyRefDHTToRuntimeFallbackTensor(
                            *dht, host)))),
          GetTfrtContext());
    }
  } else {
    auto result_tensor = MakeIndirectAsyncValue(host);
    tensor_av.AndThen([host, result_tensor = result_tensor,
                       tensor_av = tensor_av.CopyRef()]() {
      if (auto* dht =
              llvm::dyn_cast<DenseHostTensor>(&tensor_av.get<Tensor>())) {
        result_tensor->ForwardTo(
            MakeAvailableAsyncValueRef<tensorflow::tfd::RuntimeFallbackTensor>(
                host, tensorflow::tfd::CopyRefDHTToRuntimeFallbackTensor(
                          *dht, host)));
      } else {
        result_tensor->ForwardTo(tensor_av.CopyRef());
      }
    });
    return new TensorHandleInterface(
        Value(TensorHandle(host->GetHostDeviceRef(), md,
                           AsyncValueRef<Tensor>(std::move(result_tensor)))),
        GetTfrtContext());
  }
  return new TensorHandleInterface(
      Value(TensorHandle(host->GetHostDeviceRef(), md, std::move(tensor_av))),
      GetTfrtContext());
}

tensorflow::ImmediateExecutionTensorHandle*
ContextInterface::CreateLocalHandleFromTFTensor(tensorflow::Tensor& t,
                                                const char* d_name) {
  auto* host = GetHostContext();
  // Create a RuntimeFallbackTensor from a TF Tensor, and then create the
  // corresponding TensorHandleInterface.
  tensorflow::tfd::OwnedTensorHandle tf_tensor_handle{
      tensorflow::TensorHandle::CreateLocalHandle(std::move(t))};

  tfrt::Expected<tensorflow::tfd::RuntimeFallbackTensor>
      expected_result_tensor =
          tensorflow::tfd::CreateRuntimeFallbackTensorFromTfTensorHandle(
              std::move(tf_tensor_handle), GetHostContext());

  if (expected_result_tensor) {
    return new TensorHandleInterface(
        Value(TensorHandle(
            host->GetHostDeviceRef(), expected_result_tensor.get().metadata(),
            MakeAvailableAsyncValueRef<tensorflow::tfd::RuntimeFallbackTensor>(
                host, std::move(expected_result_tensor.get())))),
        GetTfrtContext());
  } else {
    return new TensorHandleInterface(
        Value(TensorHandle::CreateError(MakeErrorAsyncValueRef(
            GetHostContext(), StrCat(expected_result_tensor.takeError())))),
        GetTfrtContext());
  }
}

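// Unwraps a TFRT tensor handle back into a TF TensorHandle, blocking until
// the async tensor is available. Fallback tensors already carry a TF tensor
// handle; dense and string host tensors are moved or copied into newly
// created TF tensors.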
tensorflow::ImmediateExecutionTensorHandle*
ContextInterface::TFTensorHandleFromInterface(
    tensorflow::ImmediateExecutionTensorHandle* handle) {
  TensorHandle th = tfrt::tf::TensorHandleFromInterface(handle)->Handle();
  AsyncValue* tensor_av = th.GetAsyncTensor();
  if (tensor_av->IsUnavailable()) GetHostContext()->Await(FormRef(tensor_av));

  auto& tensor = th.GetAsyncTensor()->get<Tensor>();

  if (auto* rtfbt =
          llvm::dyn_cast<tensorflow::tfd::RuntimeFallbackTensor>(&tensor))
    return rtfbt->GetTensorHandle();

  if (auto* dht = llvm::dyn_cast<tfrt::DenseHostTensor>(&tensor)) {
    return tensorflow::TensorHandle::CreateLocalHandle(
        tensorflow::tfd::MoveHostBufferToTfTensor(dht->buffer(), dht->dtype(),
                                                  dht->shape()));
  }

  if (auto* sht = llvm::dyn_cast<tfrt::StringHostTensor>(&tensor)) {
    return tensorflow::TensorHandle::CreateLocalHandle(
        tensorflow::tfd::CopyShtToTfTensor(*sht));
  }

  LOG(ERROR) << "Unsupported tensor type";
  return nullptr;
}

tensorflow::ImmediateExecutionOperation* ContextInterface::CreateOperation() {
  return new OperationInterface(this);
}

// TODO(srbs): Change this to directly fetch the MLIR function once that is
// supported.
tensorflow::Status ContextInterface::RegisterFunction(
    tensorflow::AbstractFunction* f) {
  tensorflow::FunctionDef* fdef;
  TF_RETURN_IF_ERROR(f->GetFunctionDef(&fdef));
  if (!fdef) {
    return tensorflow::errors::InvalidArgument(
        "GetFunctionDef returned nullptr.");
  }
  return AddFunctionDef(*fdef);
}

void ContextInterface::ListDevices(
    std::vector<tensorflow::DeviceAttributes>* devices) {
  context_.GetEagerContext()->ListDevices(devices);
}

tensorflow::Status ContextInterface::AddDevices(
    std::vector<std::unique_ptr<tensorflow::Device>> devices) {
  if (!devices.empty() && devices[0]->device_type() != "CPU")
    return tensorflow::errors::InvalidArgument(
        "Device: ", devices[0]->device_type(), " is not allowed to be added ",
        "after the context is initialized. Currently allowed device: CPU. ",
        "May update this API to allow adding more types of devices.");

  for (const auto& d : devices) {
    GetHostContext()->GetDeviceManager()->MaybeAddDevice(
        TakeRef(new CpuDevice(d->name())));
  }
  TF_RETURN_IF_ERROR(GetEagerContext()->AddDevices(std::move(devices)));

  return ::tensorflow::OkStatus();
}

void ContextInterface::ClearCachesAndThreadExecutors() {
  GetEagerContext()->ClearCachesAndThreadExecutors();
  GetHostContext()->Quiesce();
}

void ContextInterface::StartStep() { GetEagerContext()->StartStep(); }

void ContextInterface::EndStep() { GetEagerContext()->EndStep(); }

tensorflow::Status ContextInterface::EnableCollectiveOps(
    const tensorflow::ServerDef& server_def) {
  if (use_tfrt_distributed_runtime_) {
    return distributed_manager_->EnableCollectiveOps(server_def);
  }
  // Preserve the local virtual device names, since local virtual devices are
  // added by TFRT and we need to add them back after the worker server is
  // initialized. Currently one such use case is the TPU_SYSTEM device, which
  // is a virtual device specifically used to initialize TPUs.
  std::vector<std::string> virtual_device_names;

  for (const auto& d :
       GetHostContext()->GetDeviceManager()->ListDevices<Device>()) {
    if (d->IsDeviceType(tfrt::VirtualDevice::kDeviceType)) {
      tensorflow::DeviceNameUtils::ParsedName p;
      if (!tensorflow::DeviceNameUtils::ParseFullName(d->name().str(), &p)) {
        return tensorflow::errors::InvalidArgument(
            "Invalid local virtual device name: ", d->name().str());
      }

      virtual_device_names.push_back(tensorflow::DeviceNameUtils::FullName(
          server_def.job_name(), /*replica=*/0, server_def.task_index(), p.type,
          p.id));
    }
  }

  TF_RETURN_IF_ERROR(GetEagerContext()->EnableCollectiveOps(server_def));

  // Create new devices with updated device name.
  std::vector<std::unique_ptr<tensorflow::Device>> dummy_tf_devices;
  CreateDummyTfDevices(virtual_device_names, &dummy_tf_devices);

  std::string name_prefix =
      absl::StrCat("/job:", server_def.job_name(),
                   "/replica:0/task:", server_def.task_index());

  // Update host device in TFRT HostContext.
  GetHostContext()->ResetHostDevice(
      GetHostContext()
          ->GetDeviceManager()
          ->MaybeAddDevice(TakeRef(
              new CpuDevice(absl::StrCat(name_prefix, "/device:CPU:0"))))
          .release());

  // Update virtual devices in TFRT HostContext.
  AddDummyTfrtDevices(virtual_device_names, GetHostContext());

  // Update eager context's device manager.
  auto* local_device_mgr = dynamic_cast<tensorflow::DynamicDeviceMgr*>(
      GetEagerContext()->local_device_mgr());
  TF_RETURN_IF_ERROR(local_device_mgr->AddDevices(std::move(dummy_tf_devices)));

  return ::tensorflow::OkStatus();
}

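// Builds a RequestContext carrying the per-request state consumed by TFRT
// kernels: the step id, the kernel-fallback runner table, and, if present,
// distributed-runtime state.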
tensorflow::Status ContextInterface::BuildFunctionRequestContext(
    tensorflow::tfrt_stub::OpKernelRunnerTable* runner_table,
    RCReference<tfrt::RequestContext>* request_context) {
  auto* step_container = GetEagerContext()->StepContainer();
  RequestContextBuilder request_context_builder(
      GetHostContext(), GetResourceContext(), step_container->StepId());

  TF_RETURN_IF_ERROR(tensorflow::tfd::SetUpKernelFallbackCompatRequestContext(
      &request_context_builder, runner_table, GetEagerContext()));
  if (distributed_manager_ != nullptr) {
    down_cast<DistributedManagerContextInterface*>(distributed_manager_.get())
        ->UpdateRequestContextBuilder(&request_context_builder);
  }
  auto expected_request_context = std::move(request_context_builder).build();
  if (!expected_request_context) {
    return tensorflow::errors::Internal(
        StrCat(expected_request_context.takeError()));
  }
  *request_context = std::move(expected_request_context.get());
  return ::tensorflow::OkStatus();
}

tensorflow::Status ContextInterface::BuildOpRequestContext(
    RCReference<tfrt::RequestContext>* request_context) {
  return BuildFunctionRequestContext(/*runner_table=*/nullptr, request_context);
}

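// Copies a tensor handle to the named target device. The TF device name is
// first converted to its TFRT equivalent. Name-conversion and device-lookup
// failures are reported via `status` and returned as error tensor handles;
// later failures return nullptr.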
tensorflow::ImmediateExecutionTensorHandle*
ContextInterface::CopyTensorHandleToDevice(
    tensorflow::ImmediateExecutionTensorHandle* handle, const char* device_name,
    tensorflow::Status* status) {
  auto* host_ctx = GetHostContext();

  TensorHandle src_th = tfrt::tf::TensorHandleFromInterface(handle)->Handle();

  auto tfrt_device_name =
      ConvertTfDeviceNameToTfrt(device_name, GetEagerContext());
  if (!tfrt_device_name) {
    *status = tensorflow::errors::InvalidArgument(
        StrCat(tfrt_device_name.takeError()));
    RCReference<AsyncValue> error_av =
        MakeErrorAsyncValueRef(host_ctx, status->error_message());
    return new TensorHandleInterface(
        Value(TensorHandle::CreateError(std::move(error_av))),
        GetTfrtContext());
  }
  auto dst_device_ref = host_ctx->GetDeviceManager()->GetDeviceRef<Device>(
      tfrt_device_name.get());
  if (!dst_device_ref) {
    std::string error_message =
        tfrt::StrCat("Failed to find destination device with name: ",
                     tfrt_device_name.get());
    *status = tensorflow::errors::Internal(error_message);
    RCReference<AsyncValue> error_av =
        MakeErrorAsyncValueRef(host_ctx, error_message);
    return new TensorHandleInterface(
        Value(TensorHandle::CreateError(std::move(error_av))),
        GetTfrtContext());
  }

  RCReference<RequestContext> request_ctx;
  *status = BuildOpRequestContext(&request_ctx);
  if (!status->ok()) return nullptr;

  ExecutionContext exec_ctx{std::move(request_ctx)};

  auto target_th =
      src_th.TransferToInferredType(exec_ctx, std::move(dst_device_ref));

  auto target_av = target_th.GetAsyncTensor();
  if (target_av->IsError()) {
    *status = tensorflow::errors::Internal(
        tfrt::StrCat("Copying to device <", tfrt_device_name.get(),
                     "> failed: ", target_av->GetError().message));
    return nullptr;
  }
  return new TensorHandleInterface(Value(target_th.CopyRef()),
                                   GetTfrtContext());
}

tensorflow::Status ContextInterface::AddFunctionDef(
    const tensorflow::FunctionDef& fdef) {
  return GetEagerContext()->AddFunctionDef(fdef);
}

tensorflow::Status ContextInterface::AddFunctionDefWithStackTraces(
    const tensorflow::FunctionDef& fdef,
    const tensorflow::StackTracesMap& stack_traces) {
  return GetEagerContext()->AddFunctionDefWithStackTraces(fdef, stack_traces);
}

std::vector<std::string> ContextInterface::ListFunctionNames() {
  return GetEagerContext()->ListFunctionNames();
}

tensorflow::Status ContextInterface::RemoveFunction(const std::string& func) {
  // TODO(tfrt-devs): We need to ensure all invocations of this function are
  // finished before removing it.
  function_cache_.RemoveFunction(func);
  return GetEagerContext()->RemoveFunction(func);
}

const tensorflow::FunctionDef* ContextInterface::FindFunctionDef(
    const std::string& name) const {
  return GetEagerContext()->FindFunctionDef(name);
}

const tensorflow::DeviceNameUtils::ParsedName&
ContextInterface::HostCPUParsedName() const {
  return context_.HostCPUParsedName();
}

const std::string& ContextInterface::HostCPUName() const {
  return context_.GetEagerContext()->HostCPUName();
}

tensorflow::CustomDeviceOpHandler&
ContextInterface::GetCustomDeviceOpHandler() {
  return context_.GetEagerContext()->GetCustomDeviceOpHandler();
}

tensorflow::Status ContextInterface::RegisterCustomDevice(
    const std::string& name, std::unique_ptr<tensorflow::CustomDevice> device) {
  return context_.GetEagerContext()->RegisterCustomDevice(name,
                                                          std::move(device));
}

tensorflow::FunctionLibraryDefinition* ContextInterface::FuncLibDef() {
  return context_.GetEagerContext()->FuncLibDef();
}

void ContextInterface::SetReuseRendezvousForFunctions(
    bool reuse_rendezvous_for_functions) {
  // TODO(fishx): This feature doesn't work properly in TFRT yet. Fix it.
  context_.GetEagerContext()->SetReuseRendezvousForFunctions(
      reuse_rendezvous_for_functions);
}

void ContextInterface::ResetGlobalRendezvousForFunction() {
  context_.GetEagerContext()->ResetGlobalRendezvousForFunction();
}

std::vector<std::string> ContextInterface::GetLoggedOpsTestonly() {
  const auto& ret = GetHostContext()
                        ->GetOrCreateSharedContext<tensorflow::tfd::OpLogger>()
                        .GetLoggedOps();
  return std::vector<std::string>(ret.begin(), ret.end());
}

HostContext* ContextInterface::GetHostContext() {
  return GetCoreRuntime()->GetHostContext();
}

tensorflow::EagerContext* ContextInterface::GetEagerContext() {
  return context_.GetEagerContext();
}

const tensorflow::EagerContext* ContextInterface::GetEagerContext() const {
  return context_.GetEagerContext();
}

CoreRuntime* ContextInterface::GetCoreRuntime() {
  return context_.GetCoreRuntime();
}

TfrtContext* ContextInterface::GetTfrtContext() { return &context_; }

OpHandler* ContextInterface::GetFallbackOpHandler() {
  return context_.GetFallbackOpHandler();
}

ResourceContext* ContextInterface::GetResourceContext() {
  return context_.GetResourceContext();
}

tensorflow::Status ContextInterface::SelectOpHandlerFromArguments(
    const tensorflow::ImmediateExecutionOperation& op, OpHandler** op_handler) {
  return op_handler_selector_->SelectFromArguments(op, op_handler);
}

tensorflow::Status ContextInterface::SelectOpHandlerFromNodeDef(
    const tensorflow::ImmediateExecutionOperation& op, const NodeDef* node_def,
    OpHandler** op_handler) {
  return op_handler_selector_->SelectFromNodeDef(op, node_def, op_handler);
}

std::unique_ptr<tensorflow::RunMetadata> ContextInterface::ExportRunMetadata() {
  mutex_lock l(run_metadata_mu_);

  // NOTE(fishx): We need to merge run_metadata from the TF eager context
  // because right now we still use the current TF runtime to execute graphs
  // (e.g. tf.data via fallback).
  auto result = GetEagerContext()->ExportRunMetadata();
  result->MergeFrom(*run_metadata_);
  run_metadata_ = std::make_unique<tensorflow::RunMetadata>();

  return result;
}

tensorflow::Status ContextInterface::RunMetadataRecordFunction(
    const std::string& func_name) {
  const tensorflow::FunctionDef* fdef =
      GetEagerContext()->FindFunctionDef(func_name);
  if (fdef == nullptr) {
    return tensorflow::errors::InvalidArgument(
        "Failed to find function \"", func_name, "\" in function library");
  }
  std::unique_ptr<tensorflow::FunctionBody> fbody;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *fdef, tensorflow::AttrSlice(), GetEagerContext()->FuncLibDef(), &fbody));
  tensorflow::GraphDef def;
  fbody->graph->ToGraphDef(&def);
  *def.mutable_library() =
      GetEagerContext()->FuncLibDef()->ReachableDefinitions(def).ToProto();

  mutex_lock l(run_metadata_mu_);
  auto* function_graphs = run_metadata_->add_function_graphs();
  *function_graphs->mutable_pre_optimization_graph() = def;
  // TODO(b/171600738): Figure out a way to record the right post-optimization
  // graph and partition graph.
  *function_graphs->mutable_post_optimization_graph() = def;
  *function_graphs->add_partition_graphs() = def;
  *run_metadata_->add_partition_graphs() = def;
  return ::tensorflow::OkStatus();
}

void ContextInterface::SetExecutorForThread(
    tensorflow::EagerExecutor* executor) {
  GetEagerContext()->SetExecutorForThread(executor);
}

tfrt::Location AbortLocationHandler::GetCurrentLocation() {
  return tfrt::Location(this, GetNextLocationId());
}

void OpAttrsInterface::GetNameAttrList(
    tensorflow::NameAttrList* name_and_attrs) const {
  fallback_attrs_->FillAttrValueMap(name_and_attrs->mutable_attr());
  name_and_attrs->set_name(fallback_attrs_->op_name());
}

tensorflow::Status OpAttrsInterface::GetTypeList(
    absl::string_view attr_name,
    absl::InlinedVector<tensorflow::DataType, 4>* type_list) const {
  return tensorflow::errors::Unimplemented("OpAttrsInterface::GetTypeList");
}

bool OpAttrsInterface::GetInt(absl::string_view attr_name,
                              int64_t* result) const {
  return attrs_->Get<int64_t>({attr_name.data(), attr_name.size()}, result);
}

bool OpAttrsInterface::GetFloat(absl::string_view attr_name,
                                float* result) const {
  return attrs_->Get<float>({attr_name.data(), attr_name.size()}, result);
}

bool OpAttrsInterface::GetBool(absl::string_view attr_name,
                               bool* result) const {
  return attrs_->Get<bool>({attr_name.data(), attr_name.size()}, result);
}

bool OpAttrsInterface::GetType(absl::string_view attr_name,
                               tensorflow::DataType* result) const {
  auto optional_type =
      attrs_->GetOptional<OpAttrType>({attr_name.data(), attr_name.size()});
  if (!optional_type.hasValue()) return false;
  *result = tensorflow::tfd::ConvertToTfDataType(optional_type.getValue());
  return true;
}

OperationInterface::OperationInterface(ContextInterface* context)
    : ImmediateExecutionOperation(kTfrt),
      op_attrs_(&attrs_, &fallback_attrs_),
      context_(context) {}

tensorflow::Status OperationInterface::Reset(const char* op,
                                             const char* raw_device_name) {
  op_name_ = op;
  args_.clear();
  attrs_.Reset();
  custom_device_tensor_handle_count_ = 0;
  op_def_ = nullptr;
  fallback_attrs_.Reset(op);
  stack_trace_.reset();
  op_ = nullptr;
  function_state_.reset();
  tensorflow::Status s = tensorflow::OpDefForOp(op_name_, &op_def_);
  is_function_ = !s.ok();
  return SetDeviceName(raw_device_name);
}

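// Executes the op or function and populates `retvals`. Functions run through
// the BEF executor with a function request context on the TF thread pool
// work queue; eager ops run through the selected CoreRuntimeOp, with
// arguments transferred to the target device for native (non-fallback) ops.
// In sync mode this blocks on the side-effect chain and on each result
// before returning.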
tensorflow::Status OperationInterface::Execute(
    absl::Span<tensorflow::AbstractTensorHandle*> retvals, int* num_retvals) {
  tensorflow::profiler::TraceMe trace(
      [&] {
        return absl::StrCat("TFRT_Execute:", Name(), " device:", DeviceName());
      },
      tensorflow::profiler::TraceMeLevel::kInfo);
  if (custom_device_tensor_handle_count_ > 0) {
    return tensorflow::errors::InvalidArgument(
        "Cannot execute ops that contain unsupported args in TFRT.");
  }

  TF_RETURN_IF_ERROR(Initialize());
  assert(op_ != nullptr || function_state_);
  auto* corert = context_->GetCoreRuntime();
  auto* chain = context_->GetChain();
  auto* host = corert->GetHostContext();
  llvm::SmallVector<TensorHandle, 8> th_args;
  th_args.reserve(args_.size());

  llvm::SmallVector<TensorHandle, 8> result_ths;
  result_ths.resize(*num_retvals);

  if (function_state_) {
    // Set up arguments. Check argument dtype synchronously if available.
    auto arg_types = function_state_->GetArgTypes();
    if (args_.size() != arg_types.size()) {
      return tensorflow::errors::InvalidArgument("Expects ", arg_types.size(),
                                                 " arguments, but ",
                                                 args_.size(), " are provided");
    }
    auto args_size = args_.size();
    for (auto i = 0; i < args_size; ++i) {
      th_args.push_back(down_cast<TensorHandleInterface*>(args_[i].get())
                            ->Handle()
                            .CopyRef());
      // TODO(b/173556766): This dtype check is only needed for corert lowering.
      // In native lowering, compiler should obtain the argument dtype
      // information from FunctionBody directly and lower the op to the native
      // kernel that accepts the specified dtype.
      if (th_args[i].IsMetadataAvailable()) {
        auto arg_dtype = th_args[i].GetAvailableMetadata().dtype;
        if (arg_dtype != arg_types[i]) {
          return tensorflow::errors::InvalidArgument(
              "Expects arg[", i, "] to be ", arg_types[i], " but ", arg_dtype,
              " is provided");
        }
      }
    }

    RCReference<RequestContext> request_ctx;
    TF_RETURN_IF_ERROR(context_->BuildFunctionRequestContext(
        function_state_->GetRunnerTable(), &request_ctx));

    ExecutionContext exec_ctx{std::move(request_ctx),
                              abort_location_handler_.GetCurrentLocation()};
    // Make the BEF executor use TfThreadPoolWorkQueue to dispatch kernels.
    exec_ctx.set_work_queue(
        context_->GetTfrtContext()->GetTfThreadPoolWorkQueue());

    // Execute the function.
    function_state_->GetFunc()(exec_ctx, th_args, OpAttrsRef(attrs_),
                               result_ths, chain);
  } else {
    RCReference<RequestContext> request_ctx;
    TF_RETURN_IF_ERROR(context_->BuildOpRequestContext(&request_ctx));

    ExecutionContext exec_ctx{std::move(request_ctx),
                              abort_location_handler_.GetCurrentLocation()};
    for (auto& arg : args_) {
      th_args.push_back(
          down_cast<TensorHandleInterface*>(arg.get())->Handle().CopyRef());
    }
    // If the CoreRuntimeOp is a native TFRT op, transfer arguments to target
    // device if necessary.
    if (!op_->IsFallback()) {
      // Get the target device of the arguments that we want to implicitly copy
      // to.
      auto dst_device_ref = op_->GetDeviceRef();

      for (auto& th_arg : th_args) {
        th_arg =
            th_arg.TransferTo(exec_ctx, dst_device_ref, op_->GetTensorType());
      }
    }

    (*op_)(exec_ctx, th_args, OpAttrsRef(attrs_), result_ths, chain);
  }

  tensorflow::Status s = ::tensorflow::OkStatus();

  if (TF_PREDICT_FALSE(!this->context_->IsAsync() && !chain->IsAvailable()))
    host->Await({chain->CopyRCRef()});

  if (TF_PREDICT_FALSE(chain->IsError())) {
    s = CreateTfErrorStatus(chain->GetError());
    // TODO(tfrt-devs): Assess if we need an explicit API to clear the error.
    *chain = GetReadyChain();
  }

  for (size_t i = 0, e = result_ths.size(); i != e; ++i) {
    auto& th_ref = result_ths[i];
    if (TF_PREDICT_FALSE(!this->context_->IsAsync() &&
                         !th_ref.GetAsyncTensor()->IsAvailable()))
      host->Await(FormRef(th_ref.GetAsyncTensor()));

    // NOTE(fishx): In async mode, we won't report errors synchronously even
    // though it is possible in TFRT. This is intended to match the behavior
    // of current TF. However, in the future, we may want to update this
    // behavior since synchronous error reporting may improve user experience
    // in async mode.
    if (TF_PREDICT_FALSE(!this->context_->IsAsync() &&
                         th_ref.GetAsyncTensor()->IsError() && s.ok()))
      s = CreateTfErrorStatus(th_ref.GetAsyncTensor()->GetError());
1416
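    // Wrap each result in a TensorHandleInterface. For functions run in async
    // mode the return dtypes are known statically, so attach them up front;
    // otherwise the dtype is read from the tensor once it becomes available.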
    if (function_state_ && context_->IsAsync()) {
      retvals[i] = new TensorHandleInterface(function_state_->GetRetTypes()[i],
                                             Value(std::move(result_ths[i])),
                                             context_->GetTfrtContext());
    } else {
      retvals[i] = new TensorHandleInterface(Value(std::move(result_ths[i])),
                                             context_->GetTfrtContext());
    }
  }

  return s;
}

tensorflow::Status OperationInterface::Initialize() {
  CoreRuntime* corert = context_->GetCoreRuntime();
  if (!is_function_) {
    // Obtain input arguments' dtype attrs as part of the cache key.
    llvm::SmallVector<string_view, 4> dtypes;
    attrs_.IterateEntries([&](const OpAttrsRawEntry& entry) {
      if (entry.type == OpAttrType::DTYPE && !entry.IsArray())
        dtypes.push_back(
            GetNameString(*static_cast<const OpAttrType*>(entry.GetData())));
    });

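    // Select an op handler from the arguments' placement, then look up (or
    // create and cache) the CoreRuntimeOp for this op name, handler, device,
    // and dtype combination.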
    OpHandler* op_handler = nullptr;
    TF_RETURN_IF_ERROR(
        context_->SelectOpHandlerFromArguments(*this, &op_handler));
    Expected<CoreRuntimeOp*> expected_op = context_->GetOpCache().GetOrAddOp(
        op_name_, op_handler, device_name_, dtypes, this);
    if (!expected_op) {
      return tensorflow::errors::InvalidArgument(
          StrCat("Cannot obtain CoreRuntimeOp: ", op_name_,
                 " on device: ", device_name_, expected_op.takeError()));
    }
    op_ = expected_op.get();
    // Update the device name, since the op handler selector may choose an
    // op_handler that differs from the one the user specified.
    device_name_ = op_->DeviceName().str();
    return ::tensorflow::OkStatus();
  }

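  // From here on we are initializing a function. First check whether it is
  // marked for XLA compilation; if so, execute it through the fallback op
  // handler instead of compiling it to BEF.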
  bool compile_with_xla = false;
  GetFuncAttr(attrs_, op_name_, *context_->GetEagerContext()->FuncLibDef(),
              tensorflow::kXlaMustCompileAttr, &compile_with_xla);
  // If the function has compile_with_xla==true, use RuntimeFallback to
  // execute it, since TFRT does not support XLA yet.
  // TODO(tfrt-devs): Native support of compile_with_xla.
  if (compile_with_xla) {
    Expected<CoreRuntimeOp*> expected_op =
        context_->GetOpCache().GetOrAddXlaOp(op_name_, context_);
    if (!expected_op) {
      return tensorflow::errors::NotFound(
          StrCat("Cannot initialize XLA function ", op_name_,
                 " on the fallback op handler. ", expected_op.takeError()));
    }
    op_ = expected_op.get();
    return ::tensorflow::OkStatus();
  }

  // Note(fishx): We need the eager context for now because we need its
  // FunctionLibraryDefinition to convert a FunctionDef to the MLIR TF
  // dialect. In the future, when we can generate MLIR from TF Python, we
  // should get rid of this.
  // FunctionDef -> BEF.
  // Look up the cache. Compile to BEF and insert into the cache on a miss.
  tensorflow::DeviceSet dev_set;
  const DeviceMgr* device_mgr =
      context_->GetEagerContext()->local_device_mgr();
  if (device_mgr == nullptr)
    return tensorflow::errors::NotFound("Cannot find device manager");
  // TODO(tfrt-devs): support remote devices in TFRT.
  for (auto d : device_mgr->ListDevices()) dev_set.AddDevice(d);
  if (context_->GetDistributedManager() != nullptr &&
      context_->UseTfrtDistributedRuntime()) {
    down_cast<DistributedManagerContextInterface*>(
        context_->GetDistributedManager())
        ->PopulateRemoteDevices(&dev_set);
  }
  FunctionCache::FunctionCacheResult result;

  tensorflow::TfrtFunctionCompileOptions compile_options;

  // Use the host device if the user does not place the function on a specific
  // device.
  compile_options.default_device =
      device_name_.empty() ? context_->GetEagerContext()->HostCPUName()
                           : device_name_;

  // TODO(b/172659131): Do not use TFRT native ops for TF integration for now.
  // Re-enable once we have a concrete plan to implement feature-complete
  // TFRT native ops (kernels).
  compile_options.enable_native_ops = false;

  if (fallback_attrs_.NumAttributes() > 0) {
    const auto& ndef = NodeDef();
    // TODO(tfrt-devs): If we are to create more attributes, consider packing
    // them into a proto.
    {
      const auto& it = ndef.attr().find(kEnableNativeOpsAttr);
      if (it != ndef.attr().end()) {
        compile_options.enable_native_ops = it->second.b();
      }
    }

    {
      const auto& it = ndef.attr().find(kEnableGrapplerAttr);
      if (it != ndef.attr().end()) {
        compile_options.enable_grappler = it->second.b();
      }
    }
  }

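  // Compilation needs a concrete device for each input, so block until every
  // argument's device is resolved before collecting them.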
  llvm::SmallVector<const tfrt::Device*, 4> input_devices;
  input_devices.reserve(args_.size());
  for (auto& arg : args_) {
    auto arg_th = down_cast<TensorHandleInterface*>(arg.get())->Handle();
    if (!arg_th.IsDeviceAvailable()) {
      corert->GetHostContext()->Await(arg_th.GetAsyncDevice().CopyRCRef());
    }
    input_devices.push_back(down_cast<TensorHandleInterface*>(arg.get())
                                ->Handle()
                                .GetAvailableDevice()
                                .get());
  }
  TF_RETURN_IF_ERROR(context_->GetFunctionCache().GetOrAddFunction(
      op_name_, device_name_, dev_set, context_->GetEagerContext(), corert,
      /*request_ctx_fn=*/
      [this](tensorflow::tfrt_stub::OpKernelRunnerTable* runner_table,
             RCReference<RequestContext>* request_ctx) {
        return context_->BuildFunctionRequestContext(runner_table,
                                                     request_ctx);
      },
      abort_location_handler_.GetCurrentLocation(), compile_options,
      input_devices, &result));
  // TODO(tfrt-devs): Avoid calling EagerContext::ShouldStoreGraphs().
  if (result.is_cache_miss &&
      context_->GetEagerContext()->ShouldStoreGraphs()) {
    TF_RETURN_IF_ERROR(context_->RunMetadataRecordFunction(op_name_));
  }
  function_state_ = std::move(result.function_state);
  return ::tensorflow::OkStatus();
}

tensorflow::Status OperationInterface::SetDeviceName(const char* name) {
  if (op_ && name != device_name_) {
    return tensorflow::errors::Internal(
        "Failed to update device name. Right now TFRT cannot update the "
        "device name of a fallback op once it is initialized.");
  }
  device_name_ = name ? name : "";
  return ::tensorflow::OkStatus();
}

tensorflow::Status OperationInterface::AddInput(
    tensorflow::AbstractTensorHandle* input) {
  tensorflow::ImmediateExecutionTensorHandle* h =
      down_cast<tensorflow::ImmediateExecutionTensorHandle*>(input);
  // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
  // here.
  if (tensorflow::CustomDeviceTensorHandle::classof(h)) {
    custom_device_tensor_handle_count_++;
  }
  h->Ref();
  args_.push_back(
      tensorflow::core::RefCountPtr<tensorflow::ImmediateExecutionTensorHandle>(
          h));
  return ::tensorflow::OkStatus();
}

tensorflow::Status OperationInterface::SetInput(
    size_t index, tensorflow::ImmediateExecutionTensorHandle* input) {
  if (index >= args_.size()) {
    return tensorflow::errors::InvalidArgument("Index >= inputs.size: ", index,
                                               " >= ", args_.size());
  }
  // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
  // here.
  if (tensorflow::CustomDeviceTensorHandle::classof(args_[index].get())) {
    custom_device_tensor_handle_count_--;
  }
  if (tensorflow::CustomDeviceTensorHandle::classof(input)) {
    custom_device_tensor_handle_count_++;
  }
  input->Ref();
  args_[index] =
      tensorflow::core::RefCountPtr<tensorflow::ImmediateExecutionTensorHandle>(
          input);
  return ::tensorflow::OkStatus();
}

tensorflow::Status OperationInterface::AddInputList(
    absl::Span<tensorflow::AbstractTensorHandle* const> inputs) {
  return tensorflow::errors::Unimplemented(
      "Unimplemented OperationInterface::AddInputList");
}

absl::Span<tensorflow::ImmediateExecutionTensorHandle* const>
OperationInterface::GetInputs() const {
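  // Note: this cast assumes RefCountPtr<T> is layout-compatible with a raw
  // T* (i.e. it holds only the pointer), so the array of RefCountPtrs can be
  // viewed as an array of raw handle pointers.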
  return absl::MakeSpan(
      reinterpret_cast<tensorflow::ImmediateExecutionTensorHandle* const*>(
          args_.data()),
      args_.size());
}

tensorflow::Status OperationInterface::SetAttrString(const char* attr_name,
                                                     const char* data,
                                                     size_t length) {
  fallback_attrs_.Set(attr_name, tensorflow::StringPiece(data, length));
  if (attrs_.SetString(attr_name, string_view(data, length)))
    return ::tensorflow::OkStatus();
  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrString failed");
}

tensorflow::Status OperationInterface::SetAttrInt(const char* attr_name,
                                                  int64_t value) {
  fallback_attrs_.Set(attr_name, static_cast<int64_t>(value));
  if (attrs_.Set(attr_name, value)) return ::tensorflow::OkStatus();
  return tensorflow::errors::Internal("OperationInterface::SetAttrInt failed");
}

tensorflow::Status OperationInterface::SetAttrFloat(const char* attr_name,
                                                    float value) {
  fallback_attrs_.Set(attr_name, value);
  if (attrs_.Set(attr_name, value)) return ::tensorflow::OkStatus();
  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrFloat failed");
}

tensorflow::Status OperationInterface::SetAttrBool(const char* attr_name,
                                                   bool value) {
  fallback_attrs_.Set(attr_name, value);
  if (attrs_.Set(attr_name, value)) return ::tensorflow::OkStatus();
  return tensorflow::errors::Internal("OperationInterface::SetAttrBool failed");
}

tensorflow::Status OperationInterface::SetAttrType(const char* attr_name,
                                                   tensorflow::DataType value) {
  fallback_attrs_.Set(attr_name, value);
  if (value == tensorflow::DT_INVALID) {
    return tensorflow::errors::InvalidArgument(
        "OperationInterface::SetAttrType failed to set DT_INVALID");
  }
  if (attrs_.Set(attr_name,
                 tfrt::GetOpAttrTypeFromDType(
                     tensorflow::tfd::ConvertTfDataTypeToBefAttrType(value))))
    return ::tensorflow::OkStatus();
  // TODO(fishx): Remove this workaround once we support all dtypes in TF.
  // This is fine for now since the attributes "T", "U", and "Tidx" are not
  // used by TFRT native ops.
  if (std::strcmp(attr_name, "T") == 0 || std::strcmp(attr_name, "U") == 0 ||
      std::strcmp(attr_name, "Tidx") == 0) {
    return ::tensorflow::OkStatus();
  }
  return tensorflow::errors::Internal("OperationInterface::SetAttrType failed");
}

tensorflow::Status OperationInterface::SetAttrShape(const char* attr_name,
                                                    const int64_t* dims,
                                                    const int num_dims) {
  // NOTE: This is copied from EagerOperation::SetAttrShape.
  // TODO(b/154554118): Remove the duplication.
  if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
    return tensorflow::errors::InvalidArgument(
        "Value specified for `", attr_name, "` has ", num_dims,
        " dimensions which is over the limit of ",
        tensorflow::TensorShape::MaxDimensions(), ".");
  }

  tensorflow::TensorShapeProto proto;
  size_t offset;
  if (num_dims < 0) {
    proto.set_unknown_rank(true);

    // Set an unranked ShapeAttr.
    offset = bef_attr_encoder_.EncodeUnrankedShapeAttr();
  } else {
    for (int d = 0; d < num_dims; ++d) {
      proto.add_dim()->set_size(dims[d]);
    }

    // Set a RankedShapeAttr.
    offset = bef_attr_encoder_.EncodeRankedShapeAttr(
        llvm::makeArrayRef(dims, num_dims));
  }
  fallback_attrs_.Set(attr_name, proto);

  auto buf = bef_attr_encoder_.TakeResult();
  tfrt::ShapeAttr shape_attr(buf.data() + offset);
  // TODO(tfrt-devs): Avoid the copy.
  if (attrs_.Set(attr_name, shape_attr)) return ::tensorflow::OkStatus();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrShape failed");
}

tensorflow::Status OperationInterface::SetAttrFunction(
    const char* attr_name, const tensorflow::AbstractOperation* value) {
  auto* value_operation = down_cast<const OperationInterface*>(value);
  // TODO(b/165412867): Set fallback_attrs_ for eager device placement.
  // Consider removing this and relying on TFRT OpAttrs.
  tensorflow::AttrValue attr_value;
  tensorflow::NameAttrList* func = attr_value.mutable_func();
  func->set_name(value->Name());
  fallback_attrs_.Set(attr_name, attr_value);

  if (attrs_.SetFunc(attr_name, {string_view(value_operation->Name())}))
    return ::tensorflow::OkStatus();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrFunction failed");
}

tensorflow::Status OperationInterface::SetAttrFunctionName(
    const char* attr_name, const char* data, size_t length) {
  // TODO(b/165412867): Set fallback_attrs_ for eager device placement.
  // Consider removing this and relying on TFRT OpAttrs.
  tensorflow::AttrValue attr_value;
  tensorflow::NameAttrList* func = attr_value.mutable_func();
  func->set_name(data);
  fallback_attrs_.Set(attr_name, attr_value);

  if (attrs_.SetFunc(attr_name, {data})) return ::tensorflow::OkStatus();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrFunctionName failed");
}

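// Serializes `tensor` into a BEF DenseAttr via `encoder` and returns the
// offset of the encoded attribute within the encoder's result buffer.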
static size_t SerializeTFETensorToDenseAttr(
    tensorflow::AbstractTensorInterface* tensor,
    tfrt::BefAttrEncoder* encoder) {
  const auto element_type =
      tensorflow::tfd::ConvertTfDataTypeToBefAttrType(tensor->Type());
  llvm::SmallVector<int64_t, 4> shape;
  for (int i = 0; i < tensor->NumDims(); ++i) {
    shape.push_back(tensor->Dim(i));
  }
  auto elements = llvm::makeArrayRef(
      reinterpret_cast<const uint8_t*>(tensor->Data()), tensor->ByteSize());
  return encoder->EncodeDenseAttr(static_cast<DType>(element_type), shape,
                                  elements);
}

tensorflow::Status OperationInterface::SetAttrTensor(
    const char* attr_name, tensorflow::AbstractTensorInterface* tensor) {
  tfrt::BefAttrEncoder encoder;
  const size_t offset = SerializeTFETensorToDenseAttr(tensor, &encoder);
  auto buffer = encoder.TakeResult();
  DenseAttr dense_attr(buffer.data() + offset);
  if (attrs_.Set(attr_name, dense_attr)) return ::tensorflow::OkStatus();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrTensor failed");
}

tensorflow::Status OperationInterface::SetAttrStringList(
    const char* attr_name, const void* const* values, const size_t* lengths,
    int num_values) {
  std::vector<tensorflow::StringPiece> v(num_values);
  for (int i = 0; i < num_values; ++i) {
    v[i] = tensorflow::StringPiece(static_cast<const char*>(values[i]),
                                   lengths[i]);
  }
  fallback_attrs_.Set(attr_name, v);

  tfrt::BefAttrEncoder encoder;
  const size_t offset =
      encoder.EncodeStringListAttr(values, lengths, num_values);
  auto buf = encoder.TakeResult();
  tfrt::AggregateAttr aggr_attr(buf.data() + offset);
  // TODO(tfrt-devs): Avoid the copy.
  if (attrs_.Set(attr_name, aggr_attr)) return ::tensorflow::OkStatus();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrStringList failed");
}

tensorflow::Status OperationInterface::SetAttrFloatList(const char* attr_name,
                                                        const float* values,
                                                        int num_values) {
  fallback_attrs_.Set(
      attr_name, tensorflow::gtl::ArraySlice<const float>(values, num_values));

  if (attrs_.SetArray(attr_name, tfrt::ArrayRef<float>(values, num_values)))
    return ::tensorflow::OkStatus();
  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrFloatList failed");
}

tensorflow::Status OperationInterface::SetAttrIntList(const char* attr_name,
                                                      const int64_t* values,
                                                      int num_values) {
  fallback_attrs_.Set(
      attr_name,
      tensorflow::gtl::ArraySlice<const int64_t>(values, num_values));

  if (attrs_.SetArray(attr_name, tfrt::ArrayRef<int64_t>(values, num_values)))
    return ::tensorflow::OkStatus();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrIntList failed");
}

tensorflow::Status OperationInterface::SetAttrTypeList(
    const char* attr_name, const tensorflow::DataType* values,
    int num_values) {
  fallback_attrs_.Set(attr_name,
                      tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
                          values, num_values));
  // Convert the TF dtypes to TFRT dtypes first.
  llvm::SmallVector<tfrt::DType, 4> tfrt_dtypes;
  tfrt_dtypes.reserve(num_values);
  for (int i = 0; i < num_values; ++i) {
    tfrt_dtypes.push_back(
        tensorflow::tfd::ConvertTfDataTypeToBefAttrType(values[i]));
  }

  if (attrs_.SetRaw(attr_name, tfrt_dtypes.data(), tfrt::OpAttrType::DTYPE,
                    num_values, OpAttrsRawEntryType::kArray))
    return ::tensorflow::OkStatus();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrTypeList failed");
}

tensorflow::Status OperationInterface::SetAttrBoolList(
    const char* attr_name, const unsigned char* values, int num_values) {
  std::unique_ptr<bool[]> b(new bool[num_values]);
  for (int i = 0; i < num_values; ++i) {
    b[i] = values[i];
  }
  fallback_attrs_.Set(
      attr_name, tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));

  // Convert to bool first.
  llvm::SmallVector<bool, 4> bool_array;
  bool_array.reserve(num_values);
  for (int i = 0; i < num_values; ++i) {
    bool_array.push_back(static_cast<bool>(values[i]));
  }
  if (attrs_.SetArray(attr_name,
                      tfrt::ArrayRef<bool>(bool_array.data(), num_values)))
    return ::tensorflow::OkStatus();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrBoolList failed");
}

tensorflow::Status OperationInterface::SetAttrShapeList(const char* attr_name,
                                                        const int64_t** dims,
                                                        const int* num_dims,
                                                        int num_values) {
  std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
      new tensorflow::TensorShapeProto[num_values]);
  for (int i = 0; i < num_values; ++i) {
    const auto num_dims_i = num_dims[i];

    if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
      return tensorflow::errors::InvalidArgument(
          StrCat("Value specified for `", attr_name, "` has ", num_dims_i,
                 " dimensions which is over the limit of ",
                 tensorflow::TensorShape::MaxDimensions(), "."));
    }
    if (num_dims_i < 0) {
      proto[i].set_unknown_rank(true);
    } else {
      const int64_t* dims_i = dims[i];
      auto proto_i = &proto[i];
      for (int d = 0; d < num_dims_i; ++d) {
        proto_i->add_dim()->set_size(dims_i[d]);
      }
    }
  }
  fallback_attrs_.Set(attr_name,
                      tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
                          proto.get(), num_values));

  BefAttrEncoder encoder;
  const size_t offset =
      encoder.EncodeShapeListAttr(dims, num_dims, num_values);
  auto buf = encoder.TakeResult();
  tfrt::AggregateAttr aggr_attr(buf.data() + offset);
  if (attrs_.Set(attr_name, aggr_attr)) return ::tensorflow::OkStatus();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrShapeList failed");
}

tensorflow::Status OperationInterface::SetAttrFunctionList(
    const char* attr_name, absl::Span<const AbstractOperation*> values) {
  size_t num_values = values.size();
  std::vector<const void*> func_attrs(num_values);
  std::vector<size_t> lengths(num_values);

  for (size_t i = 0; i < num_values; ++i) {
    auto* value_operation = down_cast<const OperationInterface*>(values[i]);
    lengths[i] = value_operation->Name().length();
    func_attrs[i] = value_operation->Name().c_str();
  }

  // Encode the array of function attributes with the BEF typed attribute
  // encoder into a single aggregate attribute.
  BefAttrEncoder encoder;
  const size_t offset =
      encoder.EncodeFuncListAttr(func_attrs.data(), lengths.data(), num_values);
  auto buf = encoder.TakeResult();
  tfrt::AggregateAttr aggr_attr(buf.data() + offset);
  if (attrs_.Set(attr_name, aggr_attr)) return ::tensorflow::OkStatus();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrFunctionList failed");
}

tensorflow::Status OperationInterface::InputLength(const char* input_name,
                                                   int* length) {
  return tensorflow::errors::Unimplemented(
      "Unimplemented OperationInterface::InputLength");
}

tensorflow::Status OperationInterface::OutputLength(const char* output_name,
                                                    int* length) {
  return tensorflow::errors::Unimplemented(
      "Unimplemented OperationInterface::OutputLength");
}

const tensorflow::AbstractOpAttrs* OperationInterface::GetOpAttrs() const {
  return &op_attrs_;
}

void OperationInterface::AddAttrs(
    const tensorflow::AbstractOpAttrs* op_attrs) {
  auto* tfrt_op_attrs = down_cast<const OpAttrsInterface*>(op_attrs);
  tfrt_op_attrs->GetAttrs()->IterateEntries(
      [this](const OpAttrsRawEntry& entry) {
        attrs_.SetRaw(entry.name, entry.GetData(), entry.type,
                      entry.element_count, entry.entry_type);
      });
  fallback_attrs_.CopyAttributes(*tfrt_op_attrs->GetFallbackAttrs());
}

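// Best-effort inference of input dtype attributes (e.g. "T") from the
// arguments that have already been added, driven by the op definition.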
void OperationInterface::MaybeInferInputAttrs() {
  if (!op_def_) return;
  for (int i = 0; i < args_.size(); i++) {
    auto* handle = args_[i].get();
    const auto& input_def = op_def_->input_arg(i);
    if (!input_def.number_attr().empty() ||
        !input_def.type_list_attr().empty()) {
      // Some clients that still set their input attributes manually add an
      // input list to their op by calling `TFE_OpAddInput` for each of its
      // elements instead of calling `TFE_OpAddInputList`. When this happens,
      // we cannot detect the end of such a list and thus lose track of the
      // input arguments in the op definition. To guarantee backward
      // compatibility with those clients, disable automatic inference in this
      // case.
      return;
    }
    const std::string& type_attr = input_def.type_attr();
    if (!type_attr.empty()) {
      bool success = attrs_.Set(
          type_attr, tfrt::GetOpAttrTypeFromDType(
                         tensorflow::tfd::ConvertTfDataTypeToBefAttrType(
                             handle->DataType())));
      if (success) {
        fallback_attrs_.Set(type_attr, handle->DataType());
      }
    }
  }
}

}  // namespace tf
}  // namespace tfrt