/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"

#include <cstddef>
#include <memory>
#include <vector>

#include "tensorflow/c/eager/abstract_function.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/mlir/tfrt/function/function.h"
#include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h"
#include "tensorflow/core/runtime_fallback/runtime/op_logger.h"
#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_op_handler.h"
#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.h"
#include "tensorflow/core/runtime_fallback/util/attr_util.h"
#include "tensorflow/core/runtime_fallback/util/tensor_util.h"
#include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed_interface.h"
#include "tensorflow/core/tfrt/eager/core_runtime/op_handler_registry.h"
#include "tensorflow/core/tfrt/eager/core_runtime/op_handler_selector.h"
#include "tensorflow/core/tfrt/eager/virtual_device.h"
#include "tensorflow/core/tfrt/utils/error_util.h"
#include "tensorflow/core/tfrt/utils/utils.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tfrt/common/compat/eigen/eigen_dtype.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime_op.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attr_type.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_handler.h"  // from @tf_runtime
#include "tfrt/core_runtime/tensor_handle.h"  // from @tf_runtime
#include "tfrt/dtype/dtype.h"  // from @tf_runtime
#include "tfrt/host_context/async_value.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/attribute_utils.h"  // from @tf_runtime
#include "tfrt/host_context/chain.h"  // from @tf_runtime
#include "tfrt/host_context/concurrent_work_queue.h"  // from @tf_runtime
#include "tfrt/host_context/device.h"  // from @tf_runtime
#include "tfrt/host_context/diagnostic.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/function.h"  // from @tf_runtime
#include "tfrt/host_context/host_allocator.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/host_context/location.h"  // from @tf_runtime
#include "tfrt/host_context/resource_context.h"  // from @tf_runtime
#include "tfrt/metrics/common_metrics.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime
#include "tfrt/support/string_util.h"  // from @tf_runtime
#include "tfrt/tensor/conversion_registry.h"  // from @tf_runtime
#include "tfrt/tensor/dense_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/dense_host_tensor_view.h"  // from @tf_runtime
#include "tfrt/tensor/scalar_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/string_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_metadata.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_serialize_utils.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_type_registration.h"  // from @tf_runtime

namespace tfrt {
namespace tf {

namespace {

using tensorflow::down_cast;

constexpr char kGpuDeviceName[] = "GPU";
constexpr char kEnableNativeOpsAttr[] = "TFRT_TEST_enable_native_ops";
constexpr char kEnableGrapplerAttr[] = "TFRT_TEST_enable_grappler";

TensorMetadata CreateMetadata(DType dtype,
                              absl::Span<const ssize_t> dim_sizes) {
  return TensorMetadata(DType(dtype),
                        TensorShape(llvm::ArrayRef<ssize_t>(
                            reinterpret_cast<const ssize_t*>(dim_sizes.data()),
                            dim_sizes.size())));
}

tensorflow::DataType ConvertDType(DType kind) {
  switch (kind) {
    case DType::UI8:
      return tensorflow::DT_UINT8;
    case DType::UI16:
      return tensorflow::DT_UINT16;
    case DType::UI32:
      return tensorflow::DT_UINT32;
    case DType::UI64:
      return tensorflow::DT_UINT64;
    case DType::I8:
      return tensorflow::DT_INT8;
    case DType::I16:
      return tensorflow::DT_INT16;
    case DType::I32:
      return tensorflow::DT_INT32;
    case DType::I64:
      return tensorflow::DT_INT64;
    case DType::BF16:
      return tensorflow::DT_BFLOAT16;
    case DType::F16:
      return tensorflow::DT_HALF;
    case DType::F32:
      return tensorflow::DT_FLOAT;
    case DType::F64:
      return tensorflow::DT_DOUBLE;
    case DType::I1:
      return tensorflow::DT_BOOL;
    case DType::Complex64:
      return tensorflow::DT_COMPLEX64;
    case DType::Complex128:
      return tensorflow::DT_COMPLEX128;
    case DType::String:
      return tensorflow::DT_STRING;
    case DType::Resource:
      return tensorflow::DT_RESOURCE;
    case DType::Variant:
      return tensorflow::DT_VARIANT;
    case DType::QUI8:
      return tensorflow::DT_QUINT8;
    case DType::QUI16:
      return tensorflow::DT_QUINT16;
    case DType::QI8:
      return tensorflow::DT_QINT8;
    case DType::QI16:
      return tensorflow::DT_QINT16;
    case DType::QI32:
      return tensorflow::DT_QINT32;
    default:
      LOG(ERROR) << "Unsupported kind " << kind;
      return tensorflow::DT_INVALID;
  }
}

DType ConvertDType(tensorflow::DataType dtype) {
  switch (dtype) {
    case tensorflow::DT_UINT8:
      return static_cast<DType>(DType::UI8);
    case tensorflow::DT_UINT16:
      return static_cast<DType>(DType::UI16);
    case tensorflow::DT_UINT32:
      return static_cast<DType>(DType::UI32);
    case tensorflow::DT_UINT64:
      return static_cast<DType>(DType::UI64);
    case tensorflow::DT_INT8:
      return static_cast<DType>(DType::I8);
    case tensorflow::DT_INT16:
      return static_cast<DType>(DType::I16);
    case tensorflow::DT_INT32:
      return static_cast<DType>(DType::I32);
    case tensorflow::DT_INT64:
      return static_cast<DType>(DType::I64);
    case tensorflow::DT_BFLOAT16:
      return static_cast<DType>(DType::BF16);
    case tensorflow::DT_HALF:
      return static_cast<DType>(DType::F16);
    case tensorflow::DT_FLOAT:
      return static_cast<DType>(DType::F32);
    case tensorflow::DT_DOUBLE:
      return static_cast<DType>(DType::F64);
    case tensorflow::DT_BOOL:
      return static_cast<DType>(DType::I1);
    case tensorflow::DT_STRING:
      return static_cast<DType>(DType::String);
    case tensorflow::DT_COMPLEX64:
      return static_cast<DType>(DType::Complex64);
    case tensorflow::DT_COMPLEX128:
      return static_cast<DType>(DType::Complex128);
    case tensorflow::DT_RESOURCE:
      return static_cast<DType>(DType::Resource);
    case tensorflow::DT_VARIANT:
      return static_cast<DType>(DType::Variant);
    case tensorflow::DT_QUINT8:
      return static_cast<DType>(DType::QUI8);
    case tensorflow::DT_QUINT16:
      return static_cast<DType>(DType::QUI16);
    case tensorflow::DT_QINT8:
      return static_cast<DType>(DType::QI8);
    case tensorflow::DT_QINT16:
      return static_cast<DType>(DType::QI16);
    case tensorflow::DT_QINT32:
      return static_cast<DType>(DType::QI32);
    default:
      LOG(FATAL) << "Unsupported dtype " << dtype;
  }
}

OpAttrType ConvertDTypeToOpAttrType(tensorflow::DataType dtype) {
  switch (dtype) {
    case tensorflow::DT_UINT8:
      return OpAttrType::UI8;
    case tensorflow::DT_UINT16:
      return OpAttrType::UI16;
    case tensorflow::DT_UINT32:
      return OpAttrType::UI32;
    case tensorflow::DT_UINT64:
      return OpAttrType::UI64;
    case tensorflow::DT_INT8:
      return OpAttrType::I8;
    case tensorflow::DT_INT16:
      return OpAttrType::I16;
    case tensorflow::DT_INT32:
      return OpAttrType::I32;
    case tensorflow::DT_INT64:
      return OpAttrType::I64;
    case tensorflow::DT_BFLOAT16:
      return OpAttrType::BF16;
    case tensorflow::DT_HALF:
      return OpAttrType::F16;
    case tensorflow::DT_FLOAT:
      return OpAttrType::F32;
    case tensorflow::DT_DOUBLE:
      return OpAttrType::F64;
    case tensorflow::DT_BOOL:
      return OpAttrType::BOOL;
    case tensorflow::DT_COMPLEX64:
      return OpAttrType::COMPLEX64;
    case tensorflow::DT_COMPLEX128:
      return OpAttrType::COMPLEX128;
    default:
      LOG(FATAL) << "Unsupported dtype " << dtype;
  }
}

// This method first looks at the calling op's attrs and then at the
// function def's attrs to find the attribute value.
void GetFuncAttr(const OpAttrs& op_attrs, const std::string& op_name,
                 const tensorflow::FunctionLibraryDefinition& func_lib_def,
                 string_view attr_name, bool* value) {
  bool success = op_attrs.Get(attr_name, value);
  if (success) {
    DVLOG(2) << "Caller explicitly specifies " << attr_name.str()
             << (*value ? "=true" : "=false");
    return;
  }

  const tensorflow::FunctionDef* function_def = func_lib_def.Find(op_name);
  if (function_def == nullptr) {
    return;
  }

  tensorflow::Status status =
      GetNodeAttr(tensorflow::AttrSlice(&function_def->attr()),
                  {attr_name.data(), attr_name.size()}, value);
  if (status.ok()) {
    DVLOG(2) << "Function definition explicitly specifies " << attr_name.str()
             << (*value ? "=true" : "=false");
    return;
  }
}
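
// Example (hedged sketch): resolving a boolean attribute that may come either
// from the call site's OpAttrs or from the FunctionDef itself. This mirrors
// the kXlaMustCompileAttr lookup in OperationInterface::Initialize() below;
// `attrs` and "my_function" are illustrative only:
//
//   bool compile_with_xla = false;  // keeps its default if neither source
//                                   // sets the attribute
//   GetFuncAttr(attrs, "my_function", *ctx->FuncLibDef(),
//               tensorflow::kXlaMustCompileAttr, &compile_with_xla);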

int64_t GetNextLocationId() {
  static std::atomic<int64_t> id(0);
  return id.fetch_add(1, std::memory_order_relaxed);
}
}  // namespace

tensorflow::DataType TensorInterface::Type() const {
  auto kind = tensor_.get().metadata().dtype;
  if (kind == DType::Unsupported) {
    assert(llvm::isa<tensorflow::tfd::RuntimeFallbackTensor>(tensor_.get()));
    return tensor_.get<tensorflow::tfd::RuntimeFallbackTensor>()
        .GetTensorHandle()
        ->DataType();
  }
  return ConvertDType(kind);
}

int TensorInterface::NumDims() const { return tensor_.get().shape().GetRank(); }

int64_t TensorInterface::Dim(int dim_index) const {
  return tensor_.get().shape().GetDimensionSize(dim_index);
}

int64_t TensorInterface::NumElements() const {
  if (!tensor_) {
    return static_cast<int64_t>(tf_tensor_.NumElements());
  }
  return tensor_.get().shape().GetNumElements();
}

size_t TensorInterface::ByteSize() const {
  return tensor_.get().metadata().GetHostSizeInBytes();
}

void* TensorInterface::Data() const {
  if (!tensor_) {
    return tensorflow::TensorCApi::Buffer(tf_tensor_)->data();
  } else {
    auto& tensor = tensor_.get<DenseHostTensor>();
    return tensor.data();
  }
}

// TFRT DenseHostTensor is always aligned.
bool TensorInterface::IsAligned() const { return true; }

bool TensorInterface::CanMove() const {
  // It is safe to move the Tensor if and only if we own the unique reference
  // to the tensor buffer.
  auto& dht = tensor_.get<DenseHostTensor>();
  return tensor_.IsUnique() && dht.buffer()->IsUnique();
}

std::string TensorInterface::SummarizeValue() const {
  if (!tensor_) {
    return tf_tensor_.SummarizeValue(/*max_entries=*/3, /*print_v2=*/true);
  } else {
    std::string result;
    llvm::raw_string_ostream result_ostream(result);
    tensor_->Print(result_ostream);
    return result;
  }
}

AsyncValueRef<Tensor> TensorInterface::TensorRef() const {
  return tensor_.CopyRef();
}

TensorHandleInterface::TensorHandleInterface(Value&& v, CoreRuntime* corert)
    : ImmediateExecutionTensorHandle(kTfrt),
      corert_(*corert),
      value_(std::move(v)) {}

tensorflow::DataType TensorHandleInterface::DataType() const {
  auto metadata = Metadata();
  if (!metadata.hasValue()) {
    LOG(ERROR)
        << "Failed to get DataType due to error metadata: "
        << value_.get<TensorHandle>().GetAsyncMetadata().GetError().message;
    return tensorflow::DT_INVALID;
  }
  auto kind = metadata.getValue()->dtype;
  if (kind == DType::Unsupported) {
    AsyncValue* async_tensor = value_.get<TensorHandle>().GetAsyncTensor();
    if (!async_tensor->IsAvailable()) {
      corert_.GetHostContext()->Await(FormRef(async_tensor));
    }

    if (async_tensor->IsError()) {
      LOG(ERROR) << "Failed to get DataType from an error tensor "
                 << async_tensor->GetError().message;
      return tensorflow::DT_INVALID;
    }
    assert(async_tensor->IsType<tensorflow::tfd::RuntimeFallbackTensor>());
    return async_tensor->get<tensorflow::tfd::RuntimeFallbackTensor>()
        .GetTensorHandle()
        ->DataType();
  }
  return ConvertDType(kind);
}

tensorflow::Status TensorHandleInterface::Shape(
    tensorflow::PartialTensorShape* shape) const {
  auto metadata = Metadata();
  if (!metadata.hasValue()) {
    return CreateTfErrorStatus(
        value_.get<TensorHandle>().GetAsyncMetadata().GetError());
  }
  int num_dims = metadata.getValue()->shape.GetRank();
  if (num_dims == -1) {
    return tensorflow::Status::OK();
  }
  SmallVector<ssize_t, 8> dims;
  metadata.getValue()->shape.GetDimensions(&dims);
  TF_RETURN_IF_ERROR(tensorflow::TensorShapeUtils::MakeShape(dims, shape));
  return tensorflow::Status::OK();
}

tensorflow::Status TensorHandleInterface::NumDims(int* num_dims) const {
  auto metadata = Metadata();
  if (!metadata.hasValue()) {
    return CreateTfErrorStatus(
        value_.get<TensorHandle>().GetAsyncMetadata().GetError());
  }
  *num_dims = metadata.getValue()->shape.GetRank();

  return tensorflow::Status::OK();
}

tensorflow::Status TensorHandleInterface::NumElements(
    int64_t* num_elements) const {
  auto metadata = Metadata();
  if (!metadata.hasValue()) {
    return CreateTfErrorStatus(
        value_.get<TensorHandle>().GetAsyncMetadata().GetError());
  }
  *num_elements = metadata.getValue()->shape.GetNumElements();

  return tensorflow::Status::OK();
}

tensorflow::Status TensorHandleInterface::Dim(int dim_index,
                                              int64_t* dim) const {
  auto metadata = Metadata();
  if (!metadata.hasValue()) {
    return CreateTfErrorStatus(
        value_.get<TensorHandle>().GetAsyncMetadata().GetError());
  }
  *dim = metadata.getValue()->shape.GetDimensionSize(dim_index);

  return tensorflow::Status::OK();
}

const char* TensorHandleInterface::DeviceName(
    tensorflow::Status* status) const {
  auto& th = value_.get<TensorHandle>();
  if (!th.IsDeviceAvailable()) {
    corert_.GetHostContext()->Await(th.GetAsyncDevice().CopyRCRef());
  }
  if (th.IsDeviceError()) {
    *status = CreateTfErrorStatus(th.GetAsyncDevice().GetError());
    return nullptr;
  }
  return th.GetAvailableDevice()->name().data();
}

const char* TensorHandleInterface::BackingDeviceName(
    tensorflow::Status* status) const {
  return DeviceName(status);
}

const char* TensorHandleInterface::DeviceType(
    tensorflow::Status* status) const {
  auto& th = value_.get<TensorHandle>();
  if (!th.IsDeviceAvailable()) {
    corert_.GetHostContext()->Await(th.GetAsyncDevice().CopyRCRef());
  }
  if (th.IsDeviceError()) {
    *status = CreateTfErrorStatus(th.GetAsyncDevice().GetError());
    return nullptr;
  }
  return th.GetAvailableDevice()->type().name().data();
}

tensorflow::AbstractTensorInterface* TensorHandleInterface::Resolve(
    tensorflow::Status* status) {
  auto* host_ctx = corert_.GetHostContext();
  auto host_device_ref = host_ctx->GetHostDeviceRef();
  auto& th = value_.get<TensorHandle>();

  auto tensor_av = th.GetAsyncTensor();
  if (!tensor_av->IsAvailable()) {
    host_ctx->Await(FormRef(tensor_av));
  }
  if (auto* error = tensor_av->GetErrorIfPresent()) {
    *status = CreateTfErrorStatus(*error);
    return nullptr;
  }
  assert(th.IsMetadataAvailable());

  if (th.GetAsyncTensor()->get<Tensor>().tensor_type() ==
      StringHostTensor::kTensorType) {
    tensorflow::Tensor tf_tensor =
        tensorflow::tfd::CopyShtToTfTensor(tensor_av->get<StringHostTensor>());
    return new tensorflow::TensorInterface(tf_tensor);
  }

  // Convert the tensor to DenseHostTensor.
  auto req_ctx =
      tfrt::RequestContextBuilder(host_ctx, /*resource_context=*/nullptr)
          .build();
  if (!req_ctx) {
    *status = tensorflow::Status(
        tensorflow::error::Code::UNKNOWN,
        StrCat("Failed to build a RequestContext: ", req_ctx.takeError()));
    return nullptr;
  }
  tfrt::ExecutionContext exec_ctx{std::move(*req_ctx)};
  auto target_th = th.TransferTo(exec_ctx, std::move(host_device_ref),
                                 DenseHostTensor::kTensorType);

  auto target_av = target_th.GetAsyncTensor();
  if (!target_av->IsAvailable()) {
    host_ctx->Await(FormRef(target_av));
  }
  if (target_av->IsError()) {
    *status = tensorflow::Status(
        tensorflow::error::Code::UNKNOWN,
        StrCat("Cannot resolve tensor: ", target_av->GetError().message));
    return nullptr;
  }
  auto host_tensor_ref = target_th.ReleaseTensorRef();
  return new TensorInterface(std::move(host_tensor_ref));
}

llvm::Optional<const TensorMetadata*> TensorHandleInterface::Metadata() const {
  auto& th = value_.get<TensorHandle>();
  if (!th.IsMetadataAvailable()) {
    corert_.GetHostContext()->Await(th.GetAsyncMetadata().CopyRCRef());
  }
  if (th.IsMetadataError()) {
    return llvm::None;
  }
  return &th.GetAvailableMetadata();
}

ContextInterface::ContextInterface(
    const tensorflow::SessionOptions& opts,
    tensorflow::ContextDevicePlacementPolicy default_device_placement_policy,
    bool is_async, bool use_tfrt_distributed_runtime)
    : ImmediateExecutionContext(kTfrt),
      context_(opts, default_device_placement_policy, is_async),
      use_tfrt_distributed_runtime_(use_tfrt_distributed_runtime) {
  LOG(INFO) << "TFRT Enabled";
  metrics::AddTFRTVersionMetric();

  op_handler_selector_ = std::make_unique<EagerOpHandlerSelector>(
      GetCoreRuntime(), GetEagerContext(), GetFallbackOpHandler(),
      GetEagerContext()->PinSmallOpsToCPU());

  run_metadata_ = std::make_unique<tensorflow::RunMetadata>();
}

ContextInterface::~ContextInterface() {}

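// GetChain() maintains one ready Chain per calling thread, created lazily
// under a double-checked lock: a shared (reader) lock for the fast path and
// an exclusive lock only on first use from a given thread. The returned
// pointer stays stable as long as entries are never erased from the map
// (no code in this file erases them).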
AsyncValueRef<Chain>* ContextInterface::GetChain() {
  auto thread_id = std::this_thread::get_id();
  {
    tensorflow::tf_shared_lock l(chain_map_mu_);
    auto it = thread_local_chain_.find(thread_id);
    if (it != thread_local_chain_.end()) {
      return &it->second;
    }
  }
  {
    tensorflow::mutex_lock l(chain_map_mu_);
    if (thread_local_chain_.find(thread_id) == thread_local_chain_.end()) {
      auto chain = GetReadyChain(GetHostContext());
      thread_local_chain_[thread_id] = std::move(chain);
    }
    return &thread_local_chain_[thread_id];
  }
}

template <typename T>
static TensorInterface* MakeScalarTensor(T value, HostContext* host) {
  // The TensorInterface implementation assumes the tensor is a
  // DenseHostTensor, so we need to use a DenseHostTensor to represent a
  // scalar tensor.
  TensorMetadata md(GetDType<T>(), {});
  auto t = DenseHostTensor::CreateUninitialized(md, host);
  if (!t) {
    LOG(ERROR) << "Failed to create DenseHostTensor";
    return nullptr;
  }
  auto& dht = t.getValue();
  MutableDHTArrayView<T> view{&dht};
  view.Elements()[0] = value;

  return new TensorInterface(
      MakeAvailableAsyncValueRef<DenseHostTensor>(host, std::move(dht)));
}
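
// Example (hedged sketch): every Create*Scalar method below funnels through
// MakeScalarTensor, producing a rank-0 DenseHostTensor wrapped in an
// already-available AsyncValueRef. So, with `ctx` an illustrative
// ContextInterface*:
//
//   tensorflow::AbstractTensorInterface* t = ctx->CreateFloatScalar(1.5f);
//   // expected: t->NumDims() == 0 and t->NumElements() == 1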

tensorflow::AbstractTensorInterface* ContextInterface::CreateInt64Scalar(
    int64_t value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateUint64Scalar(
    uint64_t value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateInt32Scalar(
    int32_t value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateFloatScalar(
    float value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateDoubleScalar(
    double value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateHalfScalar(
    Eigen::half value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateStringScalar(
    tensorflow::tstring value) {
  auto* host = GetHostContext();
  TensorMetadata md(DType(DType::String), {});
  auto t = StringHostTensor::MakeConstructedAsyncValueRef(md, host);
  if (t.IsError()) {
    LOG(ERROR) << "Failed to create StringHostTensor";
    return nullptr;
  }
  t->strings()[0] = value;

  t.SetStateConcrete();
  return new TensorInterface(std::move(t));
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateComplex128Scalar(
    tensorflow::complex128 value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateBoolScalar(
    bool value) {
  return MakeScalarTensor(value, GetHostContext());
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateTensor(
    tensorflow::DataType dtype, absl::Span<const int64_t> dim_sizes) {
  std::vector<ssize_t> dimvec(dim_sizes.size());
  for (int i = 0; i < dim_sizes.size(); ++i) {
    dimvec[i] = static_cast<ssize_t>(dim_sizes[i]);
  }

  TensorMetadata md;
  switch (dtype) {
    case tensorflow::DT_UINT8:
      md = CreateMetadata(DType::UI8, dimvec);
      break;
    case tensorflow::DT_INT8:
      md = CreateMetadata(DType::I8, dimvec);
      break;
    case tensorflow::DT_INT16:
      md = CreateMetadata(DType::I16, dimvec);
      break;
    case tensorflow::DT_INT32:
      md = CreateMetadata(DType::I32, dimvec);
      break;
    case tensorflow::DT_INT64:
      md = CreateMetadata(DType::I64, dimvec);
      break;
    case tensorflow::DT_HALF:
      md = CreateMetadata(DType::F16, dimvec);
      break;
    case tensorflow::DT_FLOAT:
      md = CreateMetadata(DType::F32, dimvec);
      break;
    case tensorflow::DT_DOUBLE:
      md = CreateMetadata(DType::F64, dimvec);
      break;
    case tensorflow::DT_BOOL:
      md = CreateMetadata(DType::I1, dimvec);
      break;
    case tensorflow::DT_COMPLEX64:
      md = CreateMetadata(DType::Complex64, dimvec);
      break;
    case tensorflow::DT_COMPLEX128:
      md = CreateMetadata(DType::Complex128, dimvec);
      break;
    case tensorflow::DT_VARIANT:
      // Note: the TF Python API can create variant tensors for ragged tensors.
      md = CreateMetadata(DType::Variant, dimvec);
      break;
    case tensorflow::DT_STRING:
      // No TFRT Metadata needed for non-scalar string tensors.
      break;
    default:
      LOG(ERROR) << "Cannot create tensor with dtype: " << dtype;
      return nullptr;
  }

  if (dtype == tensorflow::DT_STRING) {
    // Create a TensorFlow Tensor as a buffer for tstrings.
    return new TensorInterface(
        tensorflow::Tensor(dtype, tensorflow::TensorShape(dim_sizes)));
  } else {
    auto t = DenseHostTensor::CreateUninitialized(md, GetHostContext());
    return new TensorInterface(MakeAvailableAsyncValueRef<DenseHostTensor>(
        GetHostContext(), std::move(t.getValue())));
  }
}

tensorflow::AbstractTensorInterface* ContextInterface::CreateTensor(
    tensorflow::DataType dtype, const int64_t* dims, int num_dims, void* data,
    size_t len, MemoryReleaser memory_releaser, void* memory_releaser_arg) {
  TensorMetadata metadata(ConvertDType(dtype),
                          {dims, static_cast<size_t>(num_dims)});
  RCReference<HostBuffer> buffer = HostBuffer::CreateFromExternal(
      data, len,
      [memory_releaser, memory_releaser_arg](void* data, size_t len) {
        memory_releaser(data, len, memory_releaser_arg);
      });
  AsyncValueRef<DenseHostTensor> dht =
      MakeConstructedAsyncValueRef<DenseHostTensor>(GetHostContext(), metadata,
                                                    std::move(buffer));

  dht.SetStateConcrete();
  return new TensorInterface(std::move(dht));
}
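
// Example (hedged sketch): wrapping caller-owned memory without a copy. The
// releaser runs when the last reference to the underlying HostBuffer is
// dropped; the buffer and deleter below are illustrative:
//
//   int64_t dims[] = {2, 2};
//   float* data = new float[4]{1, 2, 3, 4};
//   auto* t = ctx->CreateTensor(
//       tensorflow::DT_FLOAT, dims, /*num_dims=*/2, data,
//       /*len=*/4 * sizeof(float),
//       [](void* p, size_t, void*) { delete[] static_cast<float*>(p); },
//       /*memory_releaser_arg=*/nullptr);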

bool ContextInterface::UsesTFRT() { return true; }

tensorflow::ImmediateExecutionTensorHandle* ContextInterface::CreateLocalHandle(
    tensorflow::AbstractTensorInterface* t) {
  auto* tensor_interface = down_cast<TensorInterface*>(t);
  auto* host = GetHostContext();

  // Create a RuntimeFallbackTensor from a TF Tensor, and then create the
  // corresponding TensorHandleInterface.
  if (tensor_interface->IsTfTensor()) {
    tensorflow::tfd::OwnedTensorHandle tf_tensor_handle{
        tensorflow::TensorHandle::CreateLocalHandle(
            tensor_interface->TfTensor())};

    auto expected_result_tensor =
        tensorflow::tfd::CreateRuntimeFallbackTensorFromTfTensorHandle(
            std::move(tf_tensor_handle), GetHostContext());

    if (expected_result_tensor) {
      return new TensorHandleInterface(
          Value(TensorHandle(
              host->GetHostDeviceRef(), expected_result_tensor.get().metadata(),
              MakeAvailableAsyncValueRef<
                  tensorflow::tfd::RuntimeFallbackTensor>(
                  host, std::move(expected_result_tensor.get())))),
          GetCoreRuntime());
    } else {
      return new TensorHandleInterface(
          Value(TensorHandle::CreateError(MakeErrorAsyncValueRef(
              GetHostContext(), StrCat(expected_result_tensor.takeError())))),
          GetCoreRuntime());
    }
  }

  auto tensor_av = tensor_interface->TensorRef();
  const TensorMetadata& md = tensor_av.get<Tensor>().metadata();

  // NOTE(fishx): The following logic is needed for TF-TFRT to fully reach
  // performance parity with current TF. This API is used by tf.constant to
  // convert a Python object to a **CPU** tensor. tf.constant in current TF
  // heavily depends on the Tensor Mirroring feature for good performance, but
  // TFRT does not have that feature. In order to use Tensor Mirroring from
  // the current TF runtime, we convert the result of tf.constant to a
  // Fallback Tensor.

  if (tensor_av.IsAvailable()) {
    if (auto* dht = llvm::dyn_cast<DenseHostTensor>(&tensor_av.get<Tensor>())) {
      return new TensorHandleInterface(
          Value(TensorHandle(
              host->GetHostDeviceRef(), md,
              MakeAvailableAsyncValueRef<
                  tensorflow::tfd::RuntimeFallbackTensor>(
                  host, tensorflow::tfd::CopyRefDHTToRuntimeFallbackTensor(
                            *dht, host)))),
          GetCoreRuntime());
    }
  } else {
    auto result_tensor = MakeIndirectAsyncValue(host);
    tensor_av.AndThen([host, result_tensor = result_tensor.CopyRef(),
                       tensor_av = tensor_av.CopyRef()]() {
      if (auto* dht =
              llvm::dyn_cast<DenseHostTensor>(&tensor_av.get<Tensor>())) {
        result_tensor->ForwardTo(
            MakeAvailableAsyncValueRef<tensorflow::tfd::RuntimeFallbackTensor>(
                host, tensorflow::tfd::CopyRefDHTToRuntimeFallbackTensor(
                          *dht, host)));
      } else {
        result_tensor->ForwardTo(tensor_av.CopyRef());
      }
    });
    return new TensorHandleInterface(
        Value(TensorHandle(host->GetHostDeviceRef(), md,
                           AsyncValueRef<Tensor>(std::move(result_tensor)))),
        GetCoreRuntime());
  }
  return new TensorHandleInterface(
      Value(TensorHandle(host->GetHostDeviceRef(), md, std::move(tensor_av))),
      GetCoreRuntime());
}
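
// Example (hedged sketch): a common round trip through the two calls above.
// A host scalar is first materialized as a TensorInterface and then promoted
// to a handle that op execution can consume (`ctx` is illustrative):
//
//   tensorflow::AbstractTensorInterface* t = ctx->CreateInt32Scalar(42);
//   tensorflow::ImmediateExecutionTensorHandle* h = ctx->CreateLocalHandle(t);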

tensorflow::ImmediateExecutionTensorHandle*
ContextInterface::CreateLocalHandleFromTFTensor(tensorflow::Tensor& t,
                                                const char* d_name) {
  auto* host = GetHostContext();
  // Create a RuntimeFallbackTensor from a TF Tensor, and then create the
  // corresponding TensorHandleInterface.
  tensorflow::tfd::OwnedTensorHandle tf_tensor_handle{
      tensorflow::TensorHandle::CreateLocalHandle(std::move(t))};

  tfrt::Expected<tensorflow::tfd::RuntimeFallbackTensor>
      expected_result_tensor =
          tensorflow::tfd::CreateRuntimeFallbackTensorFromTfTensorHandle(
              std::move(tf_tensor_handle), GetHostContext());

  if (expected_result_tensor) {
    return new TensorHandleInterface(
        Value(TensorHandle(
            host->GetHostDeviceRef(), expected_result_tensor.get().metadata(),
            MakeAvailableAsyncValueRef<tensorflow::tfd::RuntimeFallbackTensor>(
                host, std::move(expected_result_tensor.get())))),
        GetCoreRuntime());
  } else {
    return new TensorHandleInterface(
        Value(TensorHandle::CreateError(MakeErrorAsyncValueRef(
            GetHostContext(), StrCat(expected_result_tensor.takeError())))),
        GetCoreRuntime());
  }
}

tensorflow::ImmediateExecutionTensorHandle*
ContextInterface::TFTensorHandleFromInterface(
    tensorflow::ImmediateExecutionTensorHandle* handle) {
  TensorHandle th = tfrt::tf::TensorHandleFromInterface(handle)->Handle();
  AsyncValue* tensor_av = th.GetAsyncTensor();
  if (tensor_av->IsUnavailable()) GetHostContext()->Await(FormRef(tensor_av));

  auto& tensor = th.GetAsyncTensor()->get<Tensor>();

  if (auto* rtfbt =
          llvm::dyn_cast<tensorflow::tfd::RuntimeFallbackTensor>(&tensor))
    return rtfbt->GetTensorHandle();

  if (auto* dht = llvm::dyn_cast<tfrt::DenseHostTensor>(&tensor)) {
    return tensorflow::TensorHandle::CreateLocalHandle(
        tensorflow::tfd::MoveHostBufferToTfTensor(dht->buffer().CopyRef(),
                                                  dht->dtype(), dht->shape()));
  }

  if (auto* sht = llvm::dyn_cast<tfrt::StringHostTensor>(&tensor)) {
    return tensorflow::TensorHandle::CreateLocalHandle(
        tensorflow::tfd::CopyShtToTfTensor(*sht));
  }

  LOG(ERROR) << "Unsupported tensor type";
  return nullptr;
}

tensorflow::ImmediateExecutionOperation* ContextInterface::CreateOperation() {
  return new OperationInterface(this);
}

// TODO(srbs): Change this to directly fetch the MLIR function once that is
// supported.
tensorflow::Status ContextInterface::RegisterFunction(
    tensorflow::AbstractFunction* f) {
  tensorflow::FunctionDef* fdef;
  TF_RETURN_IF_ERROR(f->GetFunctionDef(&fdef));
  if (!fdef) {
    return tensorflow::errors::InvalidArgument(
        "GetFunctionDef returned nullptr.");
  }
  return AddFunctionDef(*fdef);
}

void ContextInterface::ListDevices(
    std::vector<tensorflow::DeviceAttributes>* devices) {
  context_.GetEagerContext()->ListDevices(devices);
}

tensorflow::Status ContextInterface::AddDevices(
    std::vector<std::unique_ptr<tensorflow::Device>> devices) {
  if (!devices.empty() && devices[0]->device_type() != "CPU")
    return tensorflow::errors::InvalidArgument(
        "Device: ", devices[0]->device_type(), " is not allowed to be added ",
        "after the context is initialized. Currently allowed device: CPU. ",
        "May update this API to allow adding more types of devices.");

  for (const auto& d : devices) {
    GetHostContext()->GetDeviceManager()->MaybeAddDevice(
        TakeRef(new CpuDevice(d->name())));
  }
  TF_RETURN_IF_ERROR(GetEagerContext()->AddDevices(std::move(devices)));

  return tensorflow::Status::OK();
}

void ContextInterface::ClearCachesAndThreadExecutors() {
  GetEagerContext()->ClearCachesAndThreadExecutors();
  GetHostContext()->Quiesce();
}

void ContextInterface::StartStep() { GetEagerContext()->StartStep(); }

void ContextInterface::EndStep() { GetEagerContext()->EndStep(); }

tensorflow::Status ContextInterface::EnableCollectiveOps(
    const tensorflow::ServerDef& server_def) {
  if (use_tfrt_distributed_runtime_) {
    return distributed_manager_->EnableCollectiveOps(server_def);
  }
  // Preserve the local virtual device names, since local virtual devices are
  // added by TFRT and we need to add them back after the worker server is
  // initialized. Currently one such use case is the TPU_SYSTEM device, which
  // is a virtual device specifically used to initialize TPUs.
  std::vector<std::string> virtual_device_names;

  for (const auto& d :
       GetHostContext()->GetDeviceManager()->ListDevices<Device>()) {
    if (d->IsDeviceType(tfrt::VirtualDevice::kDeviceType)) {
      tensorflow::DeviceNameUtils::ParsedName p;
      if (!tensorflow::DeviceNameUtils::ParseFullName(d->name().str(), &p)) {
        return tensorflow::errors::InvalidArgument(
            "Invalid local virtual device name: ", d->name().str());
      }

      virtual_device_names.push_back(tensorflow::DeviceNameUtils::FullName(
          server_def.job_name(), /*replica=*/0, server_def.task_index(), p.type,
          p.id));
    }
  }

  TF_RETURN_IF_ERROR(GetEagerContext()->EnableCollectiveOps(server_def));

  // Create new devices with the updated device names.
  std::vector<std::unique_ptr<tensorflow::Device>> dummy_tf_devices;
  CreateDummyTfDevices(virtual_device_names, &dummy_tf_devices);

  std::string name_prefix =
      absl::StrCat("/job:", server_def.job_name(),
                   "/replica:0/task:", server_def.task_index());

  // Update the host device in the TFRT HostContext.
  GetHostContext()->ResetHostDevice(
      GetHostContext()
          ->GetDeviceManager()
          ->MaybeAddDevice(TakeRef(
              new CpuDevice(absl::StrCat(name_prefix, "/device:CPU:0"))))
          .release());

  // Update the virtual devices in the TFRT HostContext.
  AddDummyTfrtDevices(virtual_device_names, GetHostContext());

  // Update the eager context's device manager.
  auto* local_device_mgr = dynamic_cast<tensorflow::DynamicDeviceMgr*>(
      GetEagerContext()->local_device_mgr());
  TF_RETURN_IF_ERROR(local_device_mgr->AddDevices(std::move(dummy_tf_devices)));

  return tensorflow::Status::OK();
}

tensorflow::Status ContextInterface::BuildFunctionRequestContext(
    tensorflow::tfd::OpKernelRunnerTable* runner_table,
    RCReference<tfrt::RequestContext>* request_context) {
  auto* step_container = GetEagerContext()->StepContainer();
  RequestContextBuilder request_context_builder(
      GetHostContext(), GetResourceContext(), step_container->StepId());

  TF_RETURN_IF_ERROR(tensorflow::tfd::SetUpKernelFallbackCompatRequestContext(
      &request_context_builder, runner_table, GetEagerContext()));
  if (distributed_manager_ != nullptr) {
    down_cast<DistributedManagerContextInterface*>(distributed_manager_.get())
        ->UpdateRequestContextBuilder(&request_context_builder);
  }
  auto expected_request_context = std::move(request_context_builder).build();
  if (!expected_request_context) {
    return tensorflow::errors::Internal(
        StrCat(expected_request_context.takeError()));
  }
  *request_context = std::move(expected_request_context.get());
  return tensorflow::Status::OK();
}

tensorflow::Status ContextInterface::BuildOpRequestContext(
    RCReference<tfrt::RequestContext>* request_context) {
  return BuildFunctionRequestContext(/*runner_table=*/nullptr, request_context);
}
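
// Example (hedged sketch): the typical pattern for callers of the two
// builders above, as also used in CopyTensorHandleToDevice and
// OperationInterface::Execute below:
//
//   RCReference<RequestContext> request_ctx;
//   TF_RETURN_IF_ERROR(BuildOpRequestContext(&request_ctx));
//   ExecutionContext exec_ctx{std::move(request_ctx)};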

tensorflow::ImmediateExecutionTensorHandle*
ContextInterface::CopyTensorHandleToDevice(
    tensorflow::ImmediateExecutionTensorHandle* handle, const char* device_name,
    tensorflow::Status* status) {
  auto* host_ctx = GetHostContext();

  TensorHandle src_th = tfrt::tf::TensorHandleFromInterface(handle)->Handle();

  auto tfrt_device_name =
      ConvertTfDeviceNameToTfrt(device_name, GetEagerContext());
  if (!tfrt_device_name) {
    *status = tensorflow::errors::InvalidArgument(
        StrCat(tfrt_device_name.takeError()));
    RCReference<AsyncValue> error_av =
        MakeErrorAsyncValueRef(host_ctx, status->error_message());
    return new TensorHandleInterface(
        Value(TensorHandle::CreateError(std::move(error_av))),
        GetCoreRuntime());
  }
  auto dst_device_ref = host_ctx->GetDeviceManager()->GetDeviceRef<Device>(
      tfrt_device_name.get());
  if (!dst_device_ref) {
    std::string error_message =
        tfrt::StrCat("Failed to find destination device with name: ",
                     tfrt_device_name.get());
    *status = tensorflow::errors::Internal(error_message);
    RCReference<AsyncValue> error_av =
        MakeErrorAsyncValueRef(host_ctx, error_message);
    return new TensorHandleInterface(
        Value(TensorHandle::CreateError(std::move(error_av))),
        GetCoreRuntime());
  }

  RCReference<RequestContext> request_ctx;
  *status = BuildOpRequestContext(&request_ctx);
  if (!status->ok()) return nullptr;

  ExecutionContext exec_ctx{std::move(request_ctx)};

  auto target_th =
      src_th.TransferToInferredType(exec_ctx, std::move(dst_device_ref));

  auto target_av = target_th.GetAsyncTensor();
  if (target_av->IsError()) {
    *status = tensorflow::errors::Internal(
        tfrt::StrCat("Copying to device <", tfrt_device_name.get(),
                     "> failed: ", target_av->GetError().message));
    return nullptr;
  }
  return new TensorHandleInterface(Value(target_th.CopyRef()),
                                   GetCoreRuntime());
}
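
// Example (hedged sketch): copying a handle onto a named device. The device
// string is a full TF device name, which ConvertTfDeviceNameToTfrt maps to
// the TFRT device; `handle` and the device string are illustrative:
//
//   tensorflow::Status s;
//   auto* gpu_handle = ctx->CopyTensorHandleToDevice(
//       handle, "/job:localhost/replica:0/task:0/device:GPU:0", &s);
//   if (!s.ok()) { /* gpu_handle is null or carries an error handle */ }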

tensorflow::Status ContextInterface::AddFunctionDef(
    const tensorflow::FunctionDef& fdef) {
  return GetEagerContext()->AddFunctionDef(fdef);
}

tensorflow::Status ContextInterface::AddFunctionDefWithStackTraces(
    const tensorflow::FunctionDef& fdef,
    const tensorflow::StackTracesMap& stack_traces) {
  return GetEagerContext()->AddFunctionDefWithStackTraces(fdef, stack_traces);
}

std::vector<std::string> ContextInterface::ListFunctionNames() {
  return GetEagerContext()->ListFunctionNames();
}

tensorflow::Status ContextInterface::RemoveFunction(const std::string& func) {
  // TODO(tfrt-devs): We need to ensure all invocations of this function are
  // finished before removing it.
  function_cache_.RemoveFunction(func);
  return GetEagerContext()->RemoveFunction(func);
}

const tensorflow::FunctionDef* ContextInterface::FindFunctionDef(
    const std::string& name) const {
  return GetEagerContext()->FindFunctionDef(name);
}

const tensorflow::DeviceNameUtils::ParsedName&
ContextInterface::HostCPUParsedName() const {
  return context_.HostCPUParsedName();
}

const std::string& ContextInterface::HostCPUName() const {
  return context_.GetEagerContext()->HostCPUName();
}

tensorflow::CustomDeviceOpHandler&
ContextInterface::GetCustomDeviceOpHandler() {
  return context_.GetEagerContext()->GetCustomDeviceOpHandler();
}

tensorflow::Status ContextInterface::RegisterCustomDevice(
    const std::string& name, std::unique_ptr<tensorflow::CustomDevice> device) {
  return context_.GetEagerContext()->RegisterCustomDevice(name,
                                                          std::move(device));
}

tensorflow::FunctionLibraryDefinition* ContextInterface::FuncLibDef() {
  return context_.GetEagerContext()->FuncLibDef();
}

void ContextInterface::SetReuseRendezvousForFunctions(
    bool reuse_rendezvous_for_functions) {
  // TODO(fishx): This feature doesn't work properly in TFRT yet. Fix it.
  context_.GetEagerContext()->SetReuseRendezvousForFunctions(
      reuse_rendezvous_for_functions);
}

void ContextInterface::ResetGlobalRendezvousForFunction() {
  context_.GetEagerContext()->ResetGlobalRendezvousForFunction();
}

std::vector<std::string> ContextInterface::GetLoggedOpsTestonly() {
  const auto& ret = GetHostContext()
                        ->GetOrCreateSharedContext<tensorflow::tfd::OpLogger>()
                        .GetLoggedOps();
  return std::vector<std::string>(ret.begin(), ret.end());
}

HostContext* ContextInterface::GetHostContext() {
  return GetCoreRuntime()->GetHostContext();
}

tensorflow::EagerContext* ContextInterface::GetEagerContext() {
  return context_.GetEagerContext();
}

const tensorflow::EagerContext* ContextInterface::GetEagerContext() const {
  return context_.GetEagerContext();
}

CoreRuntime* ContextInterface::GetCoreRuntime() {
  return context_.GetCoreRuntime();
}

OpHandler* ContextInterface::GetFallbackOpHandler() {
  return context_.GetFallbackOpHandler();
}

ResourceContext* ContextInterface::GetResourceContext() {
  return context_.GetResourceContext();
}

tensorflow::Status ContextInterface::SelectOpHandlerFromArguments(
    const tensorflow::ImmediateExecutionOperation& op, OpHandler** op_handler) {
  return op_handler_selector_->SelectFromArguments(op, op_handler);
}

tensorflow::Status ContextInterface::SelectOpHandlerFromNodeDef(
    const tensorflow::ImmediateExecutionOperation& op, const NodeDef* node_def,
    OpHandler** op_handler) {
  return op_handler_selector_->SelectFromNodeDef(op, node_def, op_handler);
}

std::unique_ptr<tensorflow::RunMetadata> ContextInterface::ExportRunMetadata() {
  tensorflow::mutex_lock l(run_metadata_mu_);

  // NOTE(fishx): We need to merge run_metadata from the TF Eager Context
  // because right now we still use the current TF runtime to execute graphs
  // (e.g. tf.data via fallback).
  auto result = GetEagerContext()->ExportRunMetadata();
  result->MergeFrom(*run_metadata_);
  run_metadata_ = std::make_unique<tensorflow::RunMetadata>();

  return result;
}

tensorflow::Status ContextInterface::RunMetadataRecordFunction(
    const std::string& func_name) {
  const tensorflow::FunctionDef* fdef =
      GetEagerContext()->FindFunctionDef(func_name);
  if (fdef == nullptr) {
    return tensorflow::errors::InvalidArgument(
        "Failed to find function \"", func_name, "\" in function library");
  }
  std::unique_ptr<tensorflow::FunctionBody> fbody;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *fdef, tensorflow::AttrSlice(), GetEagerContext()->FuncLibDef(), &fbody));
  tensorflow::GraphDef def;
  fbody->graph->ToGraphDef(&def);
  *def.mutable_library() =
      GetEagerContext()->FuncLibDef()->ReachableDefinitions(def).ToProto();

  tensorflow::mutex_lock l(run_metadata_mu_);
  auto* function_graphs = run_metadata_->add_function_graphs();
  *function_graphs->mutable_pre_optimization_graph() = def;
  // TODO(b/171600738): Figure out a way to record the right post-optimization
  // graph and partition graph.
  *function_graphs->mutable_post_optimization_graph() = def;
  *function_graphs->add_partition_graphs() = def;
  *run_metadata_->add_partition_graphs() = def;
  return tensorflow::Status::OK();
}

void ContextInterface::SetExecutorForThread(
    tensorflow::EagerExecutor* executor) {
  GetEagerContext()->SetExecutorForThread(executor);
}

tfrt::Location AbortLocationHandler::GetCurrentLocation() {
  return tfrt::Location(this, GetNextLocationId());
}

void OpAttrsInterface::GetNameAttrList(
    tensorflow::NameAttrList* name_and_attrs) const {
  fallback_attrs_->FillAttrValueMap(name_and_attrs->mutable_attr());
  name_and_attrs->set_name(fallback_attrs_->op_name());
}

tensorflow::Status OpAttrsInterface::GetTypeList(
    absl::string_view attr_name,
    absl::InlinedVector<tensorflow::DataType, 4>* type_list) const {
  return tensorflow::errors::Unimplemented("OpAttrsInterface::GetTypeList");
}

bool OpAttrsInterface::GetInt(absl::string_view attr_name,
                              int64_t* result) const {
  return attrs_->Get<int64_t>({attr_name.data(), attr_name.size()}, result);
}

bool OpAttrsInterface::GetFloat(absl::string_view attr_name,
                                float* result) const {
  return attrs_->Get<float>({attr_name.data(), attr_name.size()}, result);
}

bool OpAttrsInterface::GetBool(absl::string_view attr_name,
                               bool* result) const {
  return attrs_->Get<bool>({attr_name.data(), attr_name.size()}, result);
}

bool OpAttrsInterface::GetType(absl::string_view attr_name,
                               tensorflow::DataType* result) const {
  auto optional_type =
      attrs_->GetOptional<OpAttrType>({attr_name.data(), attr_name.size()});
  if (!optional_type.hasValue()) return false;
  *result = tensorflow::tfd::ConvertToTfDataType(optional_type.getValue());
  return true;
}
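
// Example (hedged sketch): reading typed attributes through the getters
// above. Each returns false when the attribute is absent or has a different
// type; `op_attrs` and the attribute names are illustrative:
//
//   int64_t axis = 0;
//   tensorflow::DataType dtype = tensorflow::DT_INVALID;
//   if (op_attrs.GetInt("axis", &axis) && op_attrs.GetType("T", &dtype)) {
//     // Both attributes were present with the expected types.
//   }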

OperationInterface::OperationInterface(ContextInterface* context)
    : ImmediateExecutionOperation(kTfrt),
      op_attrs_(&attrs_, &fallback_attrs_),
      context_(context) {}

tensorflow::Status OperationInterface::Reset(const char* op,
                                             const char* raw_device_name) {
  op_name_ = op;
  args_.clear();
  attrs_.Reset();
  custom_device_tensor_handle_count_ = 0;
  op_def_ = nullptr;
  fallback_attrs_.Reset(op);
  stack_trace_.reset();
  op_ = nullptr;
  function_state_.reset();
  tensorflow::Status s = tensorflow::OpDefForOp(op_name_, &op_def_);
  is_function_ = !s.ok();
  return SetDeviceName(raw_device_name);
}

tensorflow::Status OperationInterface::Execute(
    absl::Span<tensorflow::AbstractTensorHandle*> retvals, int* num_retvals) {
  if (custom_device_tensor_handle_count_ > 0) {
    return tensorflow::errors::InvalidArgument(
        "Cannot execute ops that contain unsupported args in TFRT.");
  }

  TF_RETURN_IF_ERROR(Initialize());
  assert(op_ != nullptr || function_state_);
  auto* corert = context_->GetCoreRuntime();
  auto* chain = context_->GetChain();
  auto* host = corert->GetHostContext();
  SmallVector<TensorHandle, 8> th_args;
  th_args.reserve(args_.size());

  SmallVector<TensorHandle, 8> result_ths;
  result_ths.resize(*num_retvals);

  if (function_state_) {
    // Set up arguments. Check argument dtypes synchronously if available.
    auto arg_types = function_state_->GetArgTypes();
    if (args_.size() != arg_types.size()) {
      return tensorflow::errors::InvalidArgument("Expects ", arg_types.size(),
                                                 " arguments, but ",
                                                 args_.size(), " are provided");
    }
    auto args_size = args_.size();
    for (auto i = 0; i < args_size; ++i) {
      th_args.push_back(down_cast<TensorHandleInterface*>(args_[i].get())
                            ->Handle()
                            .CopyRef());
      // TODO(b/173556766): This dtype check is only needed for corert
      // lowering. In native lowering, the compiler should obtain the argument
      // dtype information from the FunctionBody directly and lower the op to
      // the native kernel that accepts the specified dtype.
      if (th_args[i].IsMetadataAvailable()) {
        auto arg_dtype = th_args[i].GetAvailableMetadata().dtype;
        if (arg_dtype != arg_types[i]) {
          return tensorflow::errors::InvalidArgument(
              "Expects arg[", i, "] to be ", arg_types[i], " but ", arg_dtype,
              " is provided");
        }
      }
    }

    RCReference<RequestContext> request_ctx;
    TF_RETURN_IF_ERROR(context_->BuildFunctionRequestContext(
        function_state_->GetRunnerTable(), &request_ctx));

    ExecutionContext exec_ctx{std::move(request_ctx),
                              abort_location_handler_.GetCurrentLocation()};
    // Execute the function.
    function_state_->GetFunc()(exec_ctx, th_args, OpAttrsRef(attrs_),
                               result_ths, chain);
  } else {
    RCReference<RequestContext> request_ctx;
    TF_RETURN_IF_ERROR(context_->BuildOpRequestContext(&request_ctx));

    ExecutionContext exec_ctx{std::move(request_ctx),
                              abort_location_handler_.GetCurrentLocation()};
    for (auto& arg : args_) {
      th_args.push_back(
          down_cast<TensorHandleInterface*>(arg.get())->Handle().CopyRef());
    }
    // If the CoreRuntimeOp is a native TFRT op, transfer arguments to the
    // target device if necessary.
    if (!op_->IsFallback()) {
      // Get the target device of the arguments that we want to implicitly
      // copy to.
      auto dst_device_ref = op_->GetDeviceRef();

      for (auto& th_arg : th_args) {
        th_arg = th_arg.TransferTo(exec_ctx, dst_device_ref.CopyRef(),
                                   op_->GetTensorType());
      }
    }

    (*op_)(exec_ctx, th_args, OpAttrsRef(attrs_), result_ths, chain);
  }

  tensorflow::Status s = tensorflow::Status::OK();

  if (TF_PREDICT_FALSE(!this->context_->is_async() && !chain->IsAvailable()))
    host->Await({chain->CopyRCRef()});

  if (TF_PREDICT_FALSE(chain->IsError())) {
    s = CreateTfErrorStatus(chain->GetError());
    // TODO(tfrt-devs): Assess if we need an explicit API to clear errors.
    *chain = GetReadyChain(host);
  }

  for (size_t i = 0, e = result_ths.size(); i != e; ++i) {
    auto& th_ref = result_ths[i];
    if (TF_PREDICT_FALSE(!this->context_->is_async() &&
                         !th_ref.GetAsyncTensor()->IsAvailable()))
      host->Await(FormRef(th_ref.GetAsyncTensor()));

    // NOTE(fishx): In async mode, we won't report errors synchronously even
    // though it is possible in TFRT. This is intended to match the behavior
    // of current TF. However, in the future, we may want to update this
    // behavior, since synchronous errors may improve the user experience in
    // async mode.
    if (TF_PREDICT_FALSE(!this->context_->is_async() &&
                         th_ref.GetAsyncTensor()->IsError() && s.ok()))
      s = CreateTfErrorStatus(th_ref.GetAsyncTensor()->GetError());

    retvals[i] =
        new TensorHandleInterface(Value(std::move(result_ths[i])), corert);
  }

  return s;
}
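
// Example (hedged sketch): driving an OperationInterface end to end. AddInput
// comes from the AbstractOperation interface; the op name, inputs, and `ctx`
// are illustrative:
//
//   auto* op = ctx->CreateOperation();
//   TF_RETURN_IF_ERROR(op->Reset("MatMul", /*raw_device_name=*/nullptr));
//   TF_RETURN_IF_ERROR(op->AddInput(a));
//   TF_RETURN_IF_ERROR(op->AddInput(b));
//   tensorflow::AbstractTensorHandle* retvals[1];
//   int num_retvals = 1;
//   TF_RETURN_IF_ERROR(op->Execute(absl::MakeSpan(retvals), &num_retvals));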
1370
Initialize()1371 tensorflow::Status OperationInterface::Initialize() {
1372 CoreRuntime* corert = context_->GetCoreRuntime();
1373 if (!is_function_) {
1374 // Obtain input arguments' dtype attrs as part of the cache key.
1375 llvm::SmallVector<string_view, 4> dtypes;
1376 attrs_.IterateEntries([&](const OpAttrsRawEntry& entry) {
1377 if (entry.type == OpAttrType::DTYPE && !entry.IsArray())
1378 dtypes.push_back(
1379 GetNameString(*static_cast<const OpAttrType*>(entry.GetData())));
1380 });

    OpHandler* op_handler = nullptr;
    TF_RETURN_IF_ERROR(
        context_->SelectOpHandlerFromArguments(*this, &op_handler));
    Expected<CoreRuntimeOp*> expected_op = context_->GetOpCache().GetOrAddOp(
        op_name_, op_handler, device_name_, dtypes, this);
    if (!expected_op) {
      return tensorflow::errors::InvalidArgument(
          StrCat("Cannot obtain CoreRuntimeOp: ", op_name_,
                 " on device: ", device_name_, expected_op.takeError()));
    }
    op_ = expected_op.get();
    // Update the device name, since the op handler selector may choose an
    // op_handler that differs from what the user specified.
    device_name_ = op_->DeviceName().str();
    return tensorflow::Status::OK();
  }

  bool compile_with_xla = false;
  GetFuncAttr(attrs_, op_name_, *context_->GetEagerContext()->FuncLibDef(),
              tensorflow::kXlaMustCompileAttr, &compile_with_xla);
  // If the function has compile_with_xla==true, use RuntimeFallback to
  // execute it, since TFRT does not support XLA yet.
  // TODO(tfrt-devs): Native support of compile_with_xla.
  if (compile_with_xla) {
    Expected<CoreRuntimeOp*> expected_op =
        context_->GetOpCache().GetOrAddXlaOp(op_name_, context_);
    if (!expected_op) {
      return tensorflow::errors::NotFound(
          StrCat("Cannot initialize XLA function ", op_name_,
                 " on fallback op handler.", expected_op.takeError()));
    }
    op_ = expected_op.get();
    return tensorflow::Status::OK();
  }

  // Note(fishx): We need the eager context for now because we need its
  // FunctionLibraryDefinition to convert a FunctionDef to the MLIR TF
  // dialect. In the future, when we can generate MLIR from TF Python, we
  // should get rid of this.
  // FunctionDef -> BEF.
  // Look up the cache; compile to BEF and insert into the cache on a miss.
  tensorflow::DeviceSet dev_set;
  const DeviceMgr* device_mgr =
      context_->GetEagerContext()->local_device_mgr();
  if (device_mgr == nullptr)
    return tensorflow::errors::NotFound("Cannot find device manager");
  // TODO(tfrt-devs): Support remote devices in TFRT.
  for (auto d : device_mgr->ListDevices()) dev_set.AddDevice(d);
  if (context_->GetDistributedManager() != nullptr &&
      context_->UseTfrtDistributedRuntime()) {
    down_cast<DistributedManagerContextInterface*>(
        context_->GetDistributedManager())
        ->PopulateRemoteDevices(&dev_set);
  }
  FunctionCache::FunctionCacheResult result;

  tensorflow::TfrtFunctionCompileOptions compile_options;
  // TODO(tfrt-devs): Formalize the convention for converting TF's device name
  // to the MLIR compilation option's device name.
  std::string tf_device_name = device_name_;
  compile_options.default_device = std::move(tf_device_name);
  // TODO(b/172659131): Do not use TFRT native ops for TF integration for now.
  // Re-enable once we have a concrete plan to implement feature-complete
  // TFRT native ops (kernels).
  compile_options.enable_native_ops = false;

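  // Per-op attributes can override the compilation defaults above. As an
  // illustrative sketch (hypothetical client code, not part of this file): a
  // caller that sets a boolean attribute on this operation under the name
  // held by kEnableGrapplerAttr would see it surface in NodeDef() below and
  // flip compile_options.enable_grappler for this compilation.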
  if (fallback_attrs_.NumAttributes() > 0) {
    const auto& ndef = NodeDef();
    // TODO(tfrt-devs): If we are to create more attributes, consider packing
    // them into a proto.
    {
      const auto& it = ndef.attr().find(kEnableNativeOpsAttr);
      if (it != ndef.attr().end()) {
        compile_options.enable_native_ops = it->second.b();
      }
    }

    {
      const auto& it = ndef.attr().find(kEnableGrapplerAttr);
      if (it != ndef.attr().end()) {
        compile_options.enable_grappler = it->second.b();
      }
    }
  }

  tfrt::SmallVector<const tfrt::Device*, 4> input_devices;
  input_devices.reserve(args_.size());
  for (auto& arg : args_) {
    input_devices.push_back(down_cast<TensorHandleInterface*>(arg.get())
                                ->Handle()
                                .GetAvailableDevice()
                                .get());
  }
  TF_RETURN_IF_ERROR(context_->GetFunctionCache().GetOrAddFunction(
      op_name_, device_name_, dev_set, context_->GetEagerContext(), corert,
      /*request_ctx_fn=*/
      [this](tensorflow::tfd::OpKernelRunnerTable* runner_table,
             RCReference<RequestContext>* request_ctx) {
        return context_->BuildFunctionRequestContext(runner_table, request_ctx);
      },
      abort_location_handler_.GetCurrentLocation(), compile_options,
      input_devices, &result));
  // TODO(tfrt-devs): Avoid calling EagerContext::ShouldStoreGraphs().
  if (result.is_cache_miss &&
      context_->GetEagerContext()->ShouldStoreGraphs()) {
    TF_RETURN_IF_ERROR(context_->RunMetadataRecordFunction(op_name_));
  }
  function_state_ = std::move(result.function_state);
  return tensorflow::Status::OK();
}

tensorflow::Status OperationInterface::SetDeviceName(const char* name) {
  if (op_ && name != device_name_) {
    return tensorflow::errors::Internal(
        "Failed to update device name. Right now TFRT cannot update the "
        "device name of a fallback op once it is initialized.");
  }
  device_name_ = name ? name : "";
  return tensorflow::Status::OK();
}

tensorflow::Status OperationInterface::AddInput(
    tensorflow::AbstractTensorHandle* input) {
  tensorflow::ImmediateExecutionTensorHandle* h =
      down_cast<tensorflow::ImmediateExecutionTensorHandle*>(input);
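  // Keep a running count of inputs that are custom device tensor handles, so
  // later code can cheaply detect ops that need custom-device handling.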
  // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
  // here.
  if (tensorflow::CustomDeviceTensorHandle::classof(h)) {
    custom_device_tensor_handle_count_++;
  }
  h->Ref();
  args_.push_back(
      tensorflow::core::RefCountPtr<tensorflow::ImmediateExecutionTensorHandle>(
          h));
  return tensorflow::Status::OK();
}

tensorflow::Status OperationInterface::SetInput(
    size_t index, tensorflow::ImmediateExecutionTensorHandle* input) {
  if (index >= args_.size()) {
    return tensorflow::errors::InvalidArgument("Index >= inputs.size: ", index,
                                               " >= ", args_.size());
  }
  // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
  // here.
  if (tensorflow::CustomDeviceTensorHandle::classof(args_[index].get())) {
    custom_device_tensor_handle_count_--;
  }
  if (tensorflow::CustomDeviceTensorHandle::classof(input)) {
    custom_device_tensor_handle_count_++;
  }
  input->Ref();
  args_[index] =
      tensorflow::core::RefCountPtr<tensorflow::ImmediateExecutionTensorHandle>(
          input);
  return tensorflow::Status::OK();
}

tensorflow::Status OperationInterface::AddInputList(
    absl::Span<tensorflow::AbstractTensorHandle* const> inputs) {
  return tensorflow::errors::Unimplemented(
      "Unimplemented OperationInterface::AddInputList");
}

absl::Span<tensorflow::ImmediateExecutionTensorHandle* const>
OperationInterface::GetInputs() const {
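  // RefCountPtr holds a single raw pointer, so the contiguous array of
  // RefCountPtr elements in args_ can be reinterpreted as an array of raw
  // handle pointers for the returned span.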
  return absl::MakeSpan(
      reinterpret_cast<tensorflow::ImmediateExecutionTensorHandle* const*>(
          args_.data()),
      args_.size());
}

tensorflow::Status OperationInterface::SetAttrString(const char* attr_name,
                                                     const char* data,
                                                     size_t length) {
  fallback_attrs_.Set(attr_name, tensorflow::StringPiece(data, length));
  if (attrs_.SetString(attr_name, string_view(data, length)))
    return tensorflow::Status::OK();
  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrString failed");
}

tensorflow::Status OperationInterface::SetAttrInt(const char* attr_name,
                                                  int64_t value) {
  fallback_attrs_.Set(attr_name, static_cast<int64_t>(value));
  if (attrs_.Set(attr_name, value)) return tensorflow::Status::OK();
  return tensorflow::errors::Internal("OperationInterface::SetAttrInt failed");
}

tensorflow::Status OperationInterface::SetAttrFloat(const char* attr_name,
                                                    float value) {
  fallback_attrs_.Set(attr_name, value);
  if (attrs_.Set(attr_name, value)) return tensorflow::Status::OK();
  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrFloat failed");
}

tensorflow::Status OperationInterface::SetAttrBool(const char* attr_name,
                                                   bool value) {
  fallback_attrs_.Set(attr_name, value);
  if (attrs_.Set(attr_name, value)) return tensorflow::Status::OK();
  return tensorflow::errors::Internal("OperationInterface::SetAttrBool failed");
}

tensorflow::Status OperationInterface::SetAttrType(const char* attr_name,
                                                   tensorflow::DataType value) {
  fallback_attrs_.Set(attr_name, value);
  if (value == tensorflow::DT_INVALID) {
    return tensorflow::errors::InvalidArgument(
        "OperationInterface::SetAttrType failed to set DT_INVALID");
  }
  if (attrs_.Set(attr_name,
                 tfrt::GetOpAttrTypeFromDType(
                     tensorflow::tfd::ConvertTfDataTypeToBefAttrType(value))))
    return tensorflow::Status::OK();
  // TODO(fishx): Remove this workaround once we support all dtypes in TF.
  // This is fine for now, since the attributes "T", "U", and "Tidx" are not
  // used by TFRT native ops.
  if (std::strcmp(attr_name, "T") == 0 || std::strcmp(attr_name, "U") == 0 ||
      std::strcmp(attr_name, "Tidx") == 0) {
    return tensorflow::Status::OK();
  }
  return tensorflow::errors::Internal("OperationInterface::SetAttrType failed");
}

tensorflow::Status OperationInterface::SetAttrShape(const char* attr_name,
                                                    const int64_t* dims,
                                                    const int num_dims) {
  // NOTE: This is copied from EagerOperation::SetAttrShape.
  // TODO(b/154554118): Remove the duplication.
  if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
    return tensorflow::errors::InvalidArgument(
        "Value specified for `", attr_name, "` has ", num_dims,
        " dimensions, which is over the limit of ",
        tensorflow::TensorShape::MaxDimensions(), ".");
  }

  tensorflow::TensorShapeProto proto;
  size_t offset;
  if (num_dims < 0) {
    proto.set_unknown_rank(true);

    // Encode an unranked ShapeAttr.
    offset = bef_attr_encoder_.EncodeUnrankedShapeAttr();
  } else {
    for (int d = 0; d < num_dims; ++d) {
      proto.add_dim()->set_size(dims[d]);
    }

    // Encode a ranked ShapeAttr.
    offset = bef_attr_encoder_.EncodeRankedShapeAttr(
        llvm::makeArrayRef(dims, num_dims));
  }
  fallback_attrs_.Set(attr_name, proto);

  auto buf = bef_attr_encoder_.TakeResult();
  tfrt::ShapeAttr shape_attr(buf.data() + offset);
  // TODO(tfrt-devs): Avoid the copy.
  if (attrs_.Set(attr_name, shape_attr)) return tensorflow::Status::OK();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrShape failed");
}
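
// Illustrative usage of SetAttrShape (hypothetical caller, not part of this
// file):
//   int64_t dims[] = {2, -1};                // rank 2, unknown second dim
//   op->SetAttrShape("shape", dims, 2);      // ranked shape
//   op->SetAttrShape("shape", nullptr, -1);  // unknown rank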

tensorflow::Status OperationInterface::SetAttrFunction(
    const char* attr_name, const tensorflow::AbstractOperation* value) {
  auto* value_operation = down_cast<const OperationInterface*>(value);
  // TODO(b/165412867): Set fallback_attrs_ for eager device placement.
  // Consider removing this and relying on TFRT OpAttrs instead.
  tensorflow::AttrValue attr_value;
  tensorflow::NameAttrList* func = attr_value.mutable_func();
  func->set_name(value->Name());
  fallback_attrs_.Set(attr_name, attr_value);

  if (attrs_.SetFunc(attr_name, {string_view(value_operation->Name())}))
    return tensorflow::Status::OK();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrFunction failed");
}

tensorflow::Status OperationInterface::SetAttrFunctionName(
    const char* attr_name, const char* data, size_t length) {
  // TODO(b/165412867): Set fallback_attrs_ for eager device placement.
  // Consider removing this and relying on TFRT OpAttrs instead.
  tensorflow::AttrValue attr_value;
  tensorflow::NameAttrList* func = attr_value.mutable_func();
  func->set_name(data);
  fallback_attrs_.Set(attr_name, attr_value);

  if (attrs_.SetFunc(attr_name, {data})) return tensorflow::Status::OK();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrFunctionName failed");
}

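// Serializes the given tensor into a BEF dense attribute and returns the
// offset of the encoded attribute within the encoder's result buffer.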
static size_t SerializeTFETensorToDenseAttr(
    tensorflow::AbstractTensorInterface* tensor,
    tfrt::BefAttrEncoder* encoder) {
  const auto element_type =
      tensorflow::tfd::ConvertTfDataTypeToBefAttrType(tensor->Type());
  SmallVector<int64_t, 4> shape;
  for (int i = 0; i < tensor->NumDims(); ++i) {
    shape.push_back(tensor->Dim(i));
  }
  auto elements = llvm::makeArrayRef(
      reinterpret_cast<const uint8_t*>(tensor->Data()), tensor->ByteSize());
  return encoder->EncodeDenseAttr(static_cast<DType>(element_type), shape,
                                  elements);
}

tensorflow::Status OperationInterface::SetAttrTensor(
    const char* attr_name, tensorflow::AbstractTensorInterface* tensor) {
  tfrt::BefAttrEncoder encoder;
  const size_t offset = SerializeTFETensorToDenseAttr(tensor, &encoder);
  auto buffer = encoder.TakeResult();
  DenseAttr dense_attr(buffer.data() + offset);
  if (attrs_.Set(attr_name, dense_attr)) return tensorflow::Status::OK();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrTensor failed");
}

tensorflow::Status OperationInterface::SetAttrStringList(
    const char* attr_name, const void* const* values, const size_t* lengths,
    int num_values) {
  std::vector<tensorflow::StringPiece> v(num_values);
  for (int i = 0; i < num_values; ++i) {
    v[i] = tensorflow::StringPiece(static_cast<const char*>(values[i]),
                                   lengths[i]);
  }
  fallback_attrs_.Set(attr_name, v);

  tfrt::BefAttrEncoder encoder;
  const size_t offset =
      encoder.EncodeStringListAttr(values, lengths, num_values);
  auto buf = encoder.TakeResult();
  tfrt::AggregateAttr aggr_attr(buf.data() + offset);
  // TODO(tfrt-devs): Avoid the copy.
  if (attrs_.Set(attr_name, aggr_attr)) return tensorflow::Status::OK();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrStringList failed");
}

tensorflow::Status OperationInterface::SetAttrFloatList(const char* attr_name,
                                                        const float* values,
                                                        int num_values) {
  fallback_attrs_.Set(
      attr_name, tensorflow::gtl::ArraySlice<const float>(values, num_values));

  if (attrs_.SetArray(attr_name, tfrt::ArrayRef<float>(values, num_values)))
    return tensorflow::Status::OK();
  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrFloatList failed");
}

tensorflow::Status OperationInterface::SetAttrIntList(const char* attr_name,
                                                      const int64_t* values,
                                                      int num_values) {
  fallback_attrs_.Set(
      attr_name,
      tensorflow::gtl::ArraySlice<const int64_t>(values, num_values));

  if (attrs_.SetArray(attr_name, tfrt::ArrayRef<int64_t>(values, num_values)))
    return tensorflow::Status::OK();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrIntList failed");
}

tensorflow::Status OperationInterface::SetAttrTypeList(
    const char* attr_name, const tensorflow::DataType* values, int num_values) {
  fallback_attrs_.Set(attr_name,
                      tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
                          values, num_values));
  // Convert the TF dtypes to TFRT dtypes first.
  SmallVector<tfrt::DType, 4> tfrt_dtypes;
  tfrt_dtypes.reserve(num_values);
  for (int i = 0; i < num_values; ++i) {
    tfrt_dtypes.push_back(
        tensorflow::tfd::ConvertTfDataTypeToBefAttrType(values[i]));
  }

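  // Store the converted dtypes as a raw array entry tagged as DTYPE.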
  if (attrs_.SetRaw(attr_name, tfrt_dtypes.data(), tfrt::OpAttrType::DTYPE,
                    num_values, OpAttrsRawEntryType::kArray))
    return tensorflow::Status::OK();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrTypeList failed");
}

tensorflow::Status OperationInterface::SetAttrBoolList(
    const char* attr_name, const unsigned char* values, int num_values) {
  std::unique_ptr<bool[]> b(new bool[num_values]);
  for (int i = 0; i < num_values; ++i) {
    b[i] = values[i];
  }
  fallback_attrs_.Set(
      attr_name, tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));

  // Convert the unsigned char values to bools first.
  SmallVector<bool, 4> bool_array;
  bool_array.reserve(num_values);
  for (int i = 0; i < num_values; ++i) {
    bool_array.push_back(static_cast<bool>(values[i]));
  }
  if (attrs_.SetArray(attr_name,
                      tfrt::ArrayRef<bool>(bool_array.data(), num_values)))
    return tensorflow::Status::OK();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrBoolList failed");
}

tensorflow::Status OperationInterface::SetAttrShapeList(const char* attr_name,
                                                        const int64_t** dims,
                                                        const int* num_dims,
                                                        int num_values) {
  std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
      new tensorflow::TensorShapeProto[num_values]);
  for (int i = 0; i < num_values; ++i) {
    const auto num_dims_i = num_dims[i];

    if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
      return tensorflow::errors::InvalidArgument(
          StrCat("Value specified for `", attr_name, "` has ", num_dims_i,
                 " dimensions, which is over the limit of ",
                 tensorflow::TensorShape::MaxDimensions(), "."));
    }
    if (num_dims_i < 0) {
      proto[i].set_unknown_rank(true);
    } else {
      const int64_t* dims_i = dims[i];
      auto proto_i = &proto[i];
      for (int d = 0; d < num_dims_i; ++d) {
        proto_i->add_dim()->set_size(dims_i[d]);
      }
    }
  }
  fallback_attrs_.Set(attr_name,
                      tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
                          proto.get(), num_values));

  BefAttrEncoder encoder;
  const size_t offset = encoder.EncodeShapeListAttr(dims, num_dims, num_values);
  auto buf = encoder.TakeResult();
  tfrt::AggregateAttr aggr_attr(buf.data() + offset);
  if (attrs_.Set(attr_name, aggr_attr)) return tensorflow::Status::OK();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrShapeList failed");
}

tensorflow::Status OperationInterface::SetAttrFunctionList(
    const char* attr_name, absl::Span<const AbstractOperation*> values) {
  size_t num_values = values.size();
  std::vector<const void*> func_attrs(num_values);
  std::vector<size_t> lengths(num_values);

  for (size_t i = 0; i < num_values; ++i) {
    auto* value_operation = down_cast<const OperationInterface*>(values[i]);
    lengths[i] = value_operation->Name().length();
    func_attrs[i] = value_operation->Name().c_str();
  }

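  // Note: func_attrs holds pointers into each operation's name string, so
  // those strings must remain alive until the encoder consumes them below.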
  // Encode the array of function attributes into an aggregate attribute with
  // the BEF typed attribute encoder.
  BefAttrEncoder encoder;
  const size_t offset =
      encoder.EncodeFuncListAttr(func_attrs.data(), lengths.data(), num_values);
  auto buf = encoder.TakeResult();
  tfrt::AggregateAttr aggr_attr(buf.data() + offset);
  if (attrs_.Set(attr_name, aggr_attr)) return tensorflow::Status::OK();

  return tensorflow::errors::Internal(
      "OperationInterface::SetAttrFunctionList failed");
}

tensorflow::Status OperationInterface::InputLength(const char* input_name,
                                                   int* length) {
  return tensorflow::errors::Unimplemented(
      "Unimplemented OperationInterface::InputLength");
}

tensorflow::Status OperationInterface::OutputLength(const char* output_name,
                                                    int* length) {
  return tensorflow::errors::Unimplemented(
      "Unimplemented OperationInterface::OutputLength");
}

const tensorflow::AbstractOpAttrs* OperationInterface::GetOpAttrs() const {
  return &op_attrs_;
}

void OperationInterface::AddAttrs(const tensorflow::AbstractOpAttrs* op_attrs) {
  auto* tfrt_op_attrs = down_cast<const OpAttrsInterface*>(op_attrs);
  tfrt_op_attrs->GetAttrs()->IterateEntries(
      [this](const OpAttrsRawEntry& entry) {
        attrs_.SetRaw(entry.name, entry.GetData(), entry.type,
                      entry.element_count, entry.entry_type);
      });
  fallback_attrs_.CopyAttributes(*tfrt_op_attrs->GetFallbackAttrs());
}

void OperationInterface::MaybeInferInputAttrs() {
  if (!op_def_) return;
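  // For each positional input whose op definition declares a plain type
  // attribute (e.g. "T"), infer that attribute from the handle's dtype so
  // clients do not have to set it explicitly.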
  for (int i = 0; i < args_.size(); i++) {
    auto* handle = args_[i].get();
    const auto& input_def = op_def_->input_arg(i);
    if (!input_def.number_attr().empty() ||
        !input_def.type_list_attr().empty()) {
      // Some clients that still set their input attributes manually add an
      // input list to their op by calling `TFE_OpAddInput` for each of its
      // elements instead of calling `TFE_OpAddInputList`. When this happens,
      // we cannot detect the end of such a list and thus lose track of the
      // input arguments in the op definition. To guarantee backward
      // compatibility with those clients, disable automatic inference in this
      // case.
      return;
    }
    const std::string& type_attr = input_def.type_attr();
    if (!type_attr.empty()) {
      bool success = attrs_.Set(
          type_attr, tfrt::GetOpAttrTypeFromDType(
                         tensorflow::tfd::ConvertTfDataTypeToBefAttrType(
                             handle->DataType())));
      if (success) {
        fallback_attrs_.Set(type_attr, handle->DataType());
      }
    }
  }
}

}  // namespace tf
}  // namespace tfrt