1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_TFRT_EAGER_C_API_TFRT_H_
16 #define TENSORFLOW_CORE_TFRT_EAGER_C_API_TFRT_H_
17
18 #include <memory>
19 #include <string>
20 #include <vector>
21
22 #include "absl/types/optional.h"
23 #include "tensorflow/c/eager/abstract_op_attrs.h"
24 #include "tensorflow/c/eager/abstract_tensor_handle.h"
25 #include "tensorflow/c/eager/immediate_execution_context.h"
26 #include "tensorflow/c/eager/immediate_execution_operation.h"
27 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
28 #include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
29 #include "tensorflow/c/tensor_interface.h"
30 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
31 #include "tensorflow/core/common_runtime/eager/context.h"
32 #include "tensorflow/core/framework/cancellation.h"
33 #include "tensorflow/core/framework/tensor.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/platform/refcount.h"
36 #include "tensorflow/core/platform/status.h"
37 #include "tensorflow/core/protobuf/config.pb.h"
38 #include "tensorflow/core/public/session_options.h"
39 #include "tensorflow/core/tfrt/eager/function_cache.h"
40 #include "tensorflow/core/tfrt/eager/op_cache.h"
41 #include "tensorflow/core/tfrt/eager/tfrt_context.h"
42 #include "tensorflow/core/util/device_name_utils.h"
43 #include "tfrt/bef_converter/bef_attr_encoder.h" // from @tf_runtime
44 #include "tfrt/bef_executor/bef_file.h" // from @tf_runtime
45 #include "tfrt/core_runtime/core_runtime_op.h" // from @tf_runtime
46 #include "tfrt/core_runtime/op_attrs.h" // from @tf_runtime
47 #include "tfrt/core_runtime/tensor_handle.h" // from @tf_runtime
48 #include "tfrt/host_context/async_value_ref.h" // from @tf_runtime
49 #include "tfrt/host_context/value.h" // from @tf_runtime
50 #include "tfrt/support/aligned_buffer.h" // from @tf_runtime
51 #include "tfrt/support/forward_decls.h" // from @tf_runtime
52 #include "tfrt/support/ref_count.h" // from @tf_runtime
53 #include "tfrt/tensor/tensor.h" // from @tf_runtime
54
55 namespace tfrt {
56
57 class CoreRuntime;
58 class CoreRuntimeOp;
59 class DenseHostTensor;
60 class OpHandler;
61 class TensorHandle;
62 class TensorMetadata;
63
64 namespace tf {
65 class EagerOpHandlerSelector;
66
// TFRT-backed implementation of TensorFlow's eager execution context. It
// owns a TfrtContext (which wraps a tensorflow::EagerContext) plus caches
// for compiled functions and CoreRuntime ops. Many configuration setters
// simply forward to the wrapped EagerContext.
class ContextInterface : public tensorflow::ImmediateExecutionContext {
 public:
  ContextInterface(
      const tensorflow::SessionOptions& opts,
      tensorflow::ContextDevicePlacementPolicy default_device_placement_policy,
      bool is_async, bool use_tfrt_distributed_runtime);
  ~ContextInterface() override;

  // Deletes this context; the object must not be used afterwards.
  void Release() override { delete this; }

  // Scalar tensor factories, one per supported element type.
  tensorflow::AbstractTensorInterface* CreateInt64Scalar(
      int64_t value) override;
  tensorflow::AbstractTensorInterface* CreateUint64Scalar(
      uint64_t value) override;
  tensorflow::AbstractTensorInterface* CreateInt32Scalar(
      int32_t value) override;
  tensorflow::AbstractTensorInterface* CreateFloatScalar(float value) override;
  tensorflow::AbstractTensorInterface* CreateDoubleScalar(
      double value) override;
  tensorflow::AbstractTensorInterface* CreateHalfScalar(
      Eigen::half value) override;
  tensorflow::AbstractTensorInterface* CreateStringScalar(
      tensorflow::tstring value) override;
  tensorflow::AbstractTensorInterface* CreateComplex128Scalar(
      tensorflow::complex128 value) override;
  tensorflow::AbstractTensorInterface* CreateBoolScalar(bool value) override;

  // Creates a tensor of the given dtype and shape.
  tensorflow::AbstractTensorInterface* CreateTensor(
      tensorflow::DataType dtype, absl::Span<const int64_t> dim_sizes) override;
  // Creates a tensor wrapping caller-provided memory; `memory_releaser` is
  // invoked with `memory_releaser_arg` to release `data`.
  tensorflow::AbstractTensorInterface* CreateTensor(
      tensorflow::DataType dtype, const int64_t* dims, int num_dims, void* data,
      size_t len, MemoryReleaser memory_releaser,
      void* memory_releaser_arg) override;

  tensorflow::ImmediateExecutionTensorHandle* CreateLocalHandle(
      tensorflow::AbstractTensorInterface* t) override;
  // Create an abstract tensor handle from tensorflow::Tensor.
  tensorflow::ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor(
      tensorflow::Tensor& t, const char* d_name) override;

  // Convert a TFRT TensorHandle to tensorflow::TensorHandle.
  tensorflow::ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
      tensorflow::ImmediateExecutionTensorHandle* handle) override;

  tensorflow::ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
      tensorflow::ImmediateExecutionTensorHandle* handle,
      const char* device_name, tensorflow::Status* status) override;

  tensorflow::ImmediateExecutionOperation* CreateOperation() override;
  tensorflow::Status RegisterFunction(tensorflow::AbstractFunction*) override;

  tensorflow::CustomDeviceOpHandler& GetCustomDeviceOpHandler() override;

  tensorflow::Status RegisterCustomDevice(
      const std::string& name,
      std::unique_ptr<tensorflow::CustomDevice> device) override;

  tensorflow::FunctionLibraryDefinition* FuncLibDef() override;

  void SetReuseRendezvousForFunctions(
      bool reuse_rendezvous_for_functions) override;

  void ResetGlobalRendezvousForFunction() override;

  bool UsesTFRT() override;

  void ListDevices(std::vector<tensorflow::DeviceAttributes>* devices) override;

  // Devices registered with the EagerContext's local device manager only.
  std::vector<tensorflow::Device*> ListLocalTfDevices() override {
    return context_.GetEagerContext()->local_device_mgr()->ListDevices();
  }

  // All devices known to the wrapped EagerContext.
  std::vector<tensorflow::Device*> ListAllTfDevices() override {
    return context_.GetEagerContext()->ListAllTfDevices();
  }

  tensorflow::Status AddDevices(
      std::vector<std::unique_ptr<tensorflow::Device>> devices) override;

  void ClearCachesAndThreadExecutors() override;
  void StartStep() override;
  void EndStep() override;

  // Waits for pending eager work to finish, then quiesces the TFRT host
  // context so outstanding async work completes before returning.
  tensorflow::Status AsyncWait() override {
    TF_RETURN_IF_ERROR(GetEagerContext()->AsyncWait());
    GetHostContext()->Quiesce();
    return ::tensorflow::OkStatus();
  }

  tensorflow::Status AddFunctionDef(
      const tensorflow::FunctionDef& fdef) override;
  tensorflow::Status AddFunctionDefWithStackTraces(
      const tensorflow::FunctionDef& fdef,
      const tensorflow::StackTracesMap& stack_traces) override;
  std::vector<std::string> ListFunctionNames() override;
  tensorflow::Status RemoveFunction(const std::string& func) override;
  const tensorflow::FunctionDef* FindFunctionDef(
      const std::string& name) const override;

  const tensorflow::DeviceNameUtils::ParsedName& HostCPUParsedName()
      const override;
  const std::string& HostCPUName() const override;

  void SetAllowSoftPlacement(bool enable) override {
    // TODO(tfrt-devs): Move this flag to a common place that can be shared
    // by current TF and TFRT.
    GetEagerContext()->SetAllowSoftPlacement(enable);
  }
  void SetShouldStoreGraphs(bool value) override {
    GetEagerContext()->SetShouldStoreGraphs(value);
  }

  tensorflow::Status EnableCollectiveOps(
      const tensorflow::ServerDef& server_def) override;

  std::unique_ptr<tensorflow::RunMetadata> ExportRunMetadata() override;

  // Find the FunctionDef by the given name and record it in RunMetadata.
  tensorflow::Status RunMetadataRecordFunction(const std::string& func_name);

  void SetLogDevicePlacement(bool enable) override {
    // TODO(tfrt-devs): Move this flag to a common place that can be shared
    // by current TF and TFRT.
    GetEagerContext()->SetLogDevicePlacement(enable);
  }

  void SetRunEagerOpAsFunction(bool enable) override {
    // TODO(tfrt-devs): Move this flag to a common place that can be shared
    // by current TF and TFRT.
    GetEagerContext()->SetRunEagerOpAsFunction(enable);
  }

  void SetJitCompileRewrite(bool enable) override {
    // TODO(tfrt-devs): Move this flag to a common place that can be shared
    // by current TF and TFRT.
    GetEagerContext()->SetJitCompileRewrite(enable);
  }

  tensorflow::EagerExecutor& Executor() override {
    return GetEagerContext()->Executor();
  }
  void SetExecutorForThread(tensorflow::EagerExecutor* executor) override;

  void SetThreadLocalDevicePlacementPolicy(
      tensorflow::ContextDevicePlacementPolicy policy) override {
    // TODO(tfrt-devs): Move this flag to a common place that can be shared
    // by current TF and TFRT.
    GetEagerContext()->SetThreadLocalDevicePlacementPolicy(policy);
  }
  tensorflow::ContextDevicePlacementPolicy GetDevicePlacementPolicy()
      const override {
    // TODO(tfrt-devs): Move this flag to a common place that can be shared
    // by current TF and TFRT.
    return GetEagerContext()->GetDevicePlacementPolicy();
  }

  CoreRuntime* GetCoreRuntime();
  // Builds a TFRT RequestContext for executing a function (with the given
  // kernel runner table) or a single op, respectively.
  tensorflow::Status BuildFunctionRequestContext(
      tensorflow::tfrt_stub::OpKernelRunnerTable* runner_table,
      RCReference<tfrt::RequestContext>* request_context);
  tensorflow::Status BuildOpRequestContext(
      RCReference<tfrt::RequestContext>* request_context);
  tensorflow::EagerContext* GetEagerContext();
  const tensorflow::EagerContext* GetEagerContext() const;
  TfrtContext* GetTfrtContext();

  // Selects the op handler to execute the op based on the arguments. This
  // op handler selection is cheap. But the returned op handler can be
  // nullptr even when this method returns an OK status.
  tensorflow::Status SelectOpHandlerFromArguments(
      const tensorflow::ImmediateExecutionOperation& op,
      OpHandler** op_handler);

  // Selects the op handler to execute the op based on NodeDef. This op handler
  // selection is expensive. It will never return nullptr unless there is an
  // error. Please only invoke this method when the cheap version fails.
  tensorflow::Status SelectOpHandlerFromNodeDef(
      const tensorflow::ImmediateExecutionOperation& op,
      const tensorflow::NodeDef* node_def, OpHandler** op_handler);

  // Returns the chain for current thread.
  AsyncValueRef<Chain>* GetChain();

  // Indicates sync or async execution.
  bool IsAsync() const { return context_.IsAsync(); }

  // For LLVM style RTTI.
  static bool classof(const AbstractContext* op) {
    return op->getKind() == kTfrt;
  }

  FunctionCache& GetFunctionCache() { return function_cache_; }

  OpCache& GetOpCache() { return op_cache_; }

  OpHandler* GetFallbackOpHandler();

  std::vector<std::string> GetLoggedOpsTestonly() override;

  // True when this context was constructed to use TFRT's distributed
  // runtime instead of the EagerContext's one.
  bool UseTfrtDistributedRuntime() { return use_tfrt_distributed_runtime_; }

#if !defined(IS_MOBILE_PLATFORM)
  void SetDistributedManager(
      std::unique_ptr<tensorflow::ImmediateExecutionDistributedManager>
          distributed) override {
    distributed_manager_ = std::move(distributed);
  }

  // Returns the TFRT-owned distributed manager when enabled; otherwise
  // falls back to the one owned by the wrapped EagerContext.
  tensorflow::ImmediateExecutionDistributedManager* GetDistributedManager()
      override {
    if (use_tfrt_distributed_runtime_) {
      return distributed_manager_.get();
    } else {
      return context_.GetEagerContext()->GetDistributedManager();
    }
  }
#endif  // !IS_MOBILE_PLATFORM

 private:
  HostContext* GetHostContext();
  ResourceContext* GetResourceContext();

  Expected<OpHandler*> GetOpHandler(const char* name);

  TfrtContext context_;

  mutable tensorflow::mutex chain_map_mu_;
  // TODO(chuanhao): Hook it up with C API to allow user to manage it.
  // Each caller thread will have its own chain to dispatch ops.
  std::unordered_map<std::thread::id, AsyncValueRef<Chain>> thread_local_chain_
      TF_GUARDED_BY(chain_map_mu_);

  std::unique_ptr<EagerOpHandlerSelector> op_handler_selector_;

  // The cache that stores functions (composite CoreRuntimeOps).
  FunctionCache function_cache_;

  // The cache that stores CoreRuntimeOps. It's separate from function cache
  // since a primitive CoreRuntimeOp is essentially a stateless function
  // pointer, and so it doesn't need ref-count to manage its lifetime.
  OpCache op_cache_;

  mutex run_metadata_mu_;
  std::unique_ptr<tensorflow::RunMetadata> run_metadata_
      TFRT_GUARDED_BY(run_metadata_mu_);

  // Use TFRT's implementation of distributed manager.
  bool use_tfrt_distributed_runtime_ = false;

  // A distributed manager that helps setup, update, and check liveness of
  // member tasks in the cluster.
  std::unique_ptr<tensorflow::ImmediateExecutionDistributedManager>
      distributed_manager_;
};
321
322 class TensorInterface : public tensorflow::AbstractTensorInterface {
323 public:
TensorInterface(AsyncValueRef<Tensor> t)324 explicit TensorInterface(AsyncValueRef<Tensor> t) : tensor_(std::move(t)) {}
TensorInterface(tensorflow::Tensor t)325 explicit TensorInterface(tensorflow::Tensor t) : tf_tensor_(std::move(t)) {}
~TensorInterface()326 ~TensorInterface() override {}
327
Release()328 void Release() override { delete this; }
329
330 tensorflow::DataType Type() const override;
331 int NumDims() const override;
332 int64_t Dim(int dim_index) const override;
333 int64_t NumElements() const override;
334 size_t ByteSize() const override;
335 void* Data() const override;
336 bool IsAligned() const override;
337 bool CanMove() const override;
IsTfTensor()338 bool IsTfTensor() const { return !tensor_; }
339 std::string SummarizeValue() const override;
340
341 AsyncValueRef<tfrt::Tensor> TensorRef() const;
TfTensor()342 tensorflow::Tensor& TfTensor() { return tf_tensor_; }
343
344 private:
345 AsyncValueRef<tfrt::Tensor> tensor_;
346 // NOTE(b/167608876): tensorflow::Tensor for handling non-scalar string
347 // tensors, for backward compatibility. This is a temporary workaround until
348 // we find a proper way to unify tensorflow::tstring and
349 // tfrt::StringHostTensor.
350 tensorflow::Tensor tf_tensor_;
351 };
352
// Ref-counted tensor handle backed by a tfrt::TensorHandle stored inside a
// tfrt::Value. The element dtype may be supplied up front (e.g. from a
// function's output signature) or derived later from tensor metadata.
class TensorHandleInterface
    : public tensorflow::ImmediateExecutionTensorHandle {
 public:
  explicit TensorHandleInterface(Value&& v, TfrtContext* context);

  // Variant used when the dtype is already known, so it can be reported
  // before the underlying tensor is available.
  explicit TensorHandleInterface(tensorflow::DataType dtype, Value&& v,
                                 TfrtContext* context);

  // Drops one reference; the handle is destroyed when the count reaches zero.
  void Release() override { Unref(); }

  tensorflow::DataType DataType() const override;
  tensorflow::Status TensorHandleStatus() const override;
  tensorflow::Status Shape(
      tensorflow::PartialTensorShape* shape) const override;
  tensorflow::Status NumDims(int* num_dims) const override;
  tensorflow::Status NumElements(int64_t* num_elements) const override;
  tensorflow::Status Dim(int dim_index, int64_t* dim) const override;

  // DeviceName represents the device that creates the tensor handle.
  // Currently the same with BackingDeviceName.
  // TODO(b/169341326): unify device behavior between current TF and TFRT.
  const char* DeviceName(tensorflow::Status* status) const override;

  // BackingDeviceName represents the device where the tensor is physically
  // placed. DeviceName and BackingDeviceName are the same for TFRT.
  const char* BackingDeviceName(tensorflow::Status* status) const override;

  const char* DeviceType(tensorflow::Status* status) const override;

  // Aborts the process if called: not yet supported for TFRT handles.
  int DeviceId(tensorflow::Status* status) const override {
    // TODO(tfrt-devs): implement for tfrt tensor handle.
    llvm_unreachable("unimplemented method.");
  }

  tensorflow::AbstractTensorInterface* Resolve(
      tensorflow::Status* status) override;

  // TODO(b/161897666): Figure out if we can get rid of returning a new
  // pointer here and just use Ref().
  // Returns `this` with an extra reference rather than a deep copy.
  tensorflow::ImmediateExecutionTensorHandle* Copy() override {
    Ref();
    return this;
  }

  // Returns a new reference to the underlying tfrt::TensorHandle.
  TensorHandle Handle() { return value_.get<TensorHandle>().CopyRef(); }

  Value* value() { return &value_; }

  // For LLVM style RTTI.
  static bool classof(const tensorflow::AbstractTensorHandle* ptr) {
    return ptr->getKind() == kTfrt;
  }

 private:
  llvm::Optional<const TensorMetadata*> Metadata() const;

  tensorflow::StatusOr<tensorflow::DataType> ObtainDataTypeFromMetaData(
      const TensorMetadata*) const;

  // If the tensor handle is generated as the result of a function, the datatype
  // is known from the function output signature.
  // Therefore, we can obtain the datatype earlier, before the function
  // execution completes.
  llvm::Optional<tensorflow::DataType> dtype_;

  TfrtContext& context_;

  // Value of tfrt::TensorHandle.
  Value value_;
};
423
424 template <typename T>
TensorHandleFromInterface(T * handle)425 inline TensorHandleInterface* TensorHandleFromInterface(T* handle) {
426 return tensorflow::down_cast<TensorHandleInterface*>(handle);
427 }
428
// TFRT location handler class that simply prints the error and aborts the
// program on encountering any error. It's primarily for easy debugging.
// TODO(kkb): Handle errors properly by raising a Python exception.
432 class AbortLocationHandler final : public tfrt::LocationHandler {
433 public:
434 tfrt::Location GetCurrentLocation();
435
436 private:
DecodeLocation(tfrt::Location loc)437 tfrt::DecodedLocation DecodeLocation(tfrt::Location loc) const override {
438 // Return a dummy decoded location.
439 return {};
440 }
441 };
442
443 class OpAttrsInterface : public tensorflow::AbstractOpAttrs {
444 public:
OpAttrsInterface(const OpAttrs * attrs,tensorflow::AttrBuilder * fallback_attrs)445 explicit OpAttrsInterface(const OpAttrs* attrs,
446 tensorflow::AttrBuilder* fallback_attrs)
447 : AbstractOpAttrs(
448 tensorflow::AbstractOpAttrs::AbstractOpAttrsKind::kTfrt),
449 attrs_(attrs),
450 fallback_attrs_(fallback_attrs) {}
~OpAttrsInterface()451 ~OpAttrsInterface() override {}
452
453 void GetNameAttrList(tensorflow::NameAttrList* name_and_attrs) const override;
454 tensorflow::Status GetTypeList(
455 absl::string_view attr_name,
456 absl::InlinedVector<tensorflow::DataType, 4>* type_list) const override;
457
458 bool GetInt(absl::string_view attr_name, int64_t* result) const override;
459 bool GetFloat(absl::string_view attr_name, float* result) const override;
460 bool GetBool(absl::string_view attr_name, bool* result) const override;
461 bool GetType(absl::string_view attr_name,
462 tensorflow::DataType* result) const override;
463
GetAttrs()464 const OpAttrs* GetAttrs() const { return attrs_; }
465
GetFallbackAttrs()466 const tensorflow::AttrBuilder* GetFallbackAttrs() const {
467 return fallback_attrs_;
468 }
469
470 private:
471 // TODO(fishx): Move ownership to here.
472 const OpAttrs* attrs_;
473
474 // TODO(tfrt-devs): Remove this field and generate NameAttrList from attrs_.
475 // Today it is fine since we will set both attrs and fallback_attrs.
476 const tensorflow::AttrBuilder* fallback_attrs_;
477 };
478
479 class OperationInterface : public tensorflow::ImmediateExecutionOperation {
480 public:
481 // All arguments come from ContextInterface.
482 explicit OperationInterface(ContextInterface* context);
~OperationInterface()483 ~OperationInterface() override {}
484
Release()485 void Release() override { delete this; }
486
Clear()487 void Clear() override { args_.clear(); }
488
489 tensorflow::Status Reset(const char* op,
490 const char* raw_device_name) override;
Name()491 const std::string& Name() const override { return op_name_; }
DeviceName()492 const std::string& DeviceName() const override { return device_name_; }
493 tensorflow::Status SetDeviceName(const char* name) override;
494
GetContext()495 tensorflow::ImmediateExecutionContext* GetContext() const override {
496 return context_;
497 }
HasCustomDeviceInput()498 bool HasCustomDeviceInput() const override {
499 return custom_device_tensor_handle_count_ > 0;
500 }
501
502 tensorflow::Status AddInput(tensorflow::AbstractTensorHandle* input) override;
503 tensorflow::Status AddInputList(
504 absl::Span<tensorflow::AbstractTensorHandle* const> inputs) override;
505 tensorflow::Status SetInput(
506 size_t index, tensorflow::ImmediateExecutionTensorHandle* input) override;
507 absl::Span<tensorflow::ImmediateExecutionTensorHandle* const> GetInputs()
508 const override;
509 tensorflow::Status Execute(
510 absl::Span<tensorflow::AbstractTensorHandle*> retvals,
511 int* num_retvals) override;
OpDef()512 const tensorflow::OpDef* OpDef() const override { return op_def_; }
NodeDef()513 const tensorflow::NodeDef NodeDef() { return fallback_attrs_.BuildNodeDef(); }
514
515 tensorflow::Status SetAttrString(const char* attr_name, const char* data,
516 size_t length) override;
517 tensorflow::Status SetAttrInt(const char* attr_name, int64_t value) override;
518 tensorflow::Status SetAttrFloat(const char* attr_name, float value) override;
519 tensorflow::Status SetAttrBool(const char* attr_name, bool value) override;
520 tensorflow::Status SetAttrType(const char* attr_name,
521 tensorflow::DataType value) override;
522 tensorflow::Status SetAttrShape(const char* attr_name, const int64_t* dims,
523 const int num_dims) override;
524 tensorflow::Status SetAttrFunction(const char* attr_name,
525 const AbstractOperation* value) override;
526 tensorflow::Status SetAttrFunctionName(const char* attr_name,
527 const char* data,
528 size_t length) override;
529 tensorflow::Status SetAttrTensor(
530 const char* attr_name,
531 tensorflow::AbstractTensorInterface* tensor) override;
532 tensorflow::Status SetAttrStringList(const char* attr_name,
533 const void* const* values,
534 const size_t* lengths,
535 int num_values) override;
536 tensorflow::Status SetAttrFloatList(const char* attr_name,
537 const float* values,
538 int num_values) override;
539 tensorflow::Status SetAttrIntList(const char* attr_name,
540 const int64_t* values,
541 int num_values) override;
542 tensorflow::Status SetAttrTypeList(const char* attr_name,
543 const tensorflow::DataType* values,
544 int num_values) override;
545 tensorflow::Status SetAttrBoolList(const char* attr_name,
546 const unsigned char* values,
547 int num_values) override;
548 tensorflow::Status SetAttrShapeList(const char* attr_name,
549 const int64_t** dims, const int* num_dims,
550 int num_values) override;
551 tensorflow::Status SetAttrFunctionList(
552 const char* attr_name,
553 absl::Span<const AbstractOperation*> values) override;
554
555 tensorflow::Status InputLength(const char* input_name, int* length) override;
556 tensorflow::Status OutputLength(const char* output_name,
557 int* length) override;
558
559 const tensorflow::AbstractOpAttrs* GetOpAttrs() const override;
560 void AddAttrs(const tensorflow::AbstractOpAttrs* op_attrs) override;
561
SetStackTrace(tensorflow::ManagedStackTrace stack_trace)562 void SetStackTrace(tensorflow::ManagedStackTrace stack_trace) override {
563 stack_trace_ = stack_trace;
564 }
565
SetCancellationManager(tensorflow::CancellationManager * cancellation_manager)566 void SetCancellationManager(
567 tensorflow::CancellationManager* cancellation_manager) override {
568 // TODO(b/181368626): Support cancellation.
569 }
570
GetStackTrace()571 absl::optional<tensorflow::ManagedStackTrace> GetStackTrace() override {
572 return stack_trace_;
573 }
574
575 // Currently not supported.
SetStepId(int64_t step_id)576 void SetStepId(int64_t step_id) override {}
577
578 // For LLVM style RTTI.
classof(const AbstractOperation * ptr)579 static bool classof(const AbstractOperation* ptr) {
580 return ptr->getKind() == kTfrt;
581 }
582
583 friend class OpCache;
584
585 private:
586 // Initialize op_ field. It can be either a trivial op or a composite op.
587 tensorflow::Status Initialize();
588
589 // Note(fishx): This method is copied from current TF. We use it to infer
590 // attribute like "T" in order to run device placement logic from current TF.
591 void MaybeInferInputAttrs();
592
593 // This field holds a primitive op. If the op represents a function, it
594 // will be held by function_state_ below, and this field will be empty.
595 CoreRuntimeOp* op_;
596 RCReference<FunctionState> function_state_;
597 std::string op_name_;
598 // The device user requested to place the op on.
599 std::string device_name_;
600 bool is_function_;
601 tfrt::BefAttrEncoder bef_attr_encoder_;
602 // TODO(b/165412867): Remove AttrBuilder.
603 tensorflow::AttrBuilder fallback_attrs_;
604 const tensorflow::OpDef* op_def_; // op definition from protobuf
605 OpAttrs attrs_;
606 OpAttrsInterface op_attrs_;
607 llvm::SmallVector<
608 tensorflow::core::RefCountPtr<tensorflow::ImmediateExecutionTensorHandle>,
609 8>
610 args_;
611 AbortLocationHandler abort_location_handler_;
612 ContextInterface* const context_;
613 // TODO(kkb): Use tfrt::Location and implement TFRT async stack tracing.
614 absl::optional<tensorflow::ManagedStackTrace> stack_trace_;
615
616 int custom_device_tensor_handle_count_ = 0;
617 };
618
619 } // namespace tf
620 } // namespace tfrt
621
622 #endif // TENSORFLOW_CORE_TFRT_EAGER_C_API_TFRT_H_
623