• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_TFRT_EAGER_C_API_TFRT_H_
16 #define TENSORFLOW_CORE_TFRT_EAGER_C_API_TFRT_H_
17 
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/c/eager/abstract_op_attrs.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/tfrt/eager/function_cache.h"
#include "tensorflow/core/tfrt/eager/op_cache.h"
#include "tensorflow/core/tfrt/eager/tfrt_context.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tfrt/bef_converter/bef_attr_encoder.h"  // from @tf_runtime
#include "tfrt/bef_executor/bef_file.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime_op.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
#include "tfrt/core_runtime/tensor_handle.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/value.h"  // from @tf_runtime
#include "tfrt/support/aligned_buffer.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime
#include "tfrt/tensor/tensor.h"  // from @tf_runtime
54 
55 namespace tfrt {
56 
57 class CoreRuntime;
58 class CoreRuntimeOp;
59 class DenseHostTensor;
60 class OpHandler;
61 class TensorHandle;
62 class TensorMetadata;
63 
64 namespace tf {
65 class EagerOpHandlerSelector;
66 
67 class ContextInterface : public tensorflow::ImmediateExecutionContext {
68  public:
69   ContextInterface(
70       const tensorflow::SessionOptions& opts,
71       tensorflow::ContextDevicePlacementPolicy default_device_placement_policy,
72       bool is_async, bool use_tfrt_distributed_runtime);
73   ~ContextInterface() override;
74 
Release()75   void Release() override { delete this; }
76 
77   tensorflow::AbstractTensorInterface* CreateInt64Scalar(
78       int64_t value) override;
79   tensorflow::AbstractTensorInterface* CreateUint64Scalar(
80       uint64_t value) override;
81   tensorflow::AbstractTensorInterface* CreateInt32Scalar(
82       int32_t value) override;
83   tensorflow::AbstractTensorInterface* CreateFloatScalar(float value) override;
84   tensorflow::AbstractTensorInterface* CreateDoubleScalar(
85       double value) override;
86   tensorflow::AbstractTensorInterface* CreateHalfScalar(
87       Eigen::half value) override;
88   tensorflow::AbstractTensorInterface* CreateStringScalar(
89       tensorflow::tstring value) override;
90   tensorflow::AbstractTensorInterface* CreateComplex128Scalar(
91       tensorflow::complex128 value) override;
92   tensorflow::AbstractTensorInterface* CreateBoolScalar(bool value) override;
93 
94   tensorflow::AbstractTensorInterface* CreateTensor(
95       tensorflow::DataType dtype, absl::Span<const int64_t> dim_sizes) override;
96   tensorflow::AbstractTensorInterface* CreateTensor(
97       tensorflow::DataType dtype, const int64_t* dims, int num_dims, void* data,
98       size_t len, MemoryReleaser memory_releaser,
99       void* memory_releaser_arg) override;
100 
101   tensorflow::ImmediateExecutionTensorHandle* CreateLocalHandle(
102       tensorflow::AbstractTensorInterface* t) override;
103   // Create an abstract tensor handle from tensorflow::Tensor.
104   tensorflow::ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor(
105       tensorflow::Tensor& t, const char* d_name) override;
106 
107   // Convert a TFRT TensorHandle to tensorflow::TensorHandle.
108   tensorflow::ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
109       tensorflow::ImmediateExecutionTensorHandle* handle) override;
110 
111   tensorflow::ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
112       tensorflow::ImmediateExecutionTensorHandle* handle,
113       const char* device_name, tensorflow::Status* status) override;
114 
115   tensorflow::ImmediateExecutionOperation* CreateOperation() override;
116   tensorflow::Status RegisterFunction(tensorflow::AbstractFunction*) override;
117 
118   tensorflow::CustomDeviceOpHandler& GetCustomDeviceOpHandler() override;
119 
120   tensorflow::Status RegisterCustomDevice(
121       const std::string& name,
122       std::unique_ptr<tensorflow::CustomDevice> device) override;
123 
124   tensorflow::FunctionLibraryDefinition* FuncLibDef() override;
125 
126   void SetReuseRendezvousForFunctions(
127       bool reuse_rendezvous_for_functions) override;
128 
129   void ResetGlobalRendezvousForFunction() override;
130 
131   bool UsesTFRT() override;
132 
133   void ListDevices(std::vector<tensorflow::DeviceAttributes>* devices) override;
134 
ListLocalTfDevices()135   std::vector<tensorflow::Device*> ListLocalTfDevices() override {
136     return context_.GetEagerContext()->local_device_mgr()->ListDevices();
137   }
138 
ListAllTfDevices()139   std::vector<tensorflow::Device*> ListAllTfDevices() override {
140     return context_.GetEagerContext()->ListAllTfDevices();
141   }
142 
143   tensorflow::Status AddDevices(
144       std::vector<std::unique_ptr<tensorflow::Device>> devices) override;
145 
146   void ClearCachesAndThreadExecutors() override;
147   void StartStep() override;
148   void EndStep() override;
149 
AsyncWait()150   tensorflow::Status AsyncWait() override {
151     TF_RETURN_IF_ERROR(GetEagerContext()->AsyncWait());
152     GetHostContext()->Quiesce();
153     return ::tensorflow::OkStatus();
154   }
155 
156   tensorflow::Status AddFunctionDef(
157       const tensorflow::FunctionDef& fdef) override;
158   tensorflow::Status AddFunctionDefWithStackTraces(
159       const tensorflow::FunctionDef& fdef,
160       const tensorflow::StackTracesMap& stack_traces) override;
161   std::vector<std::string> ListFunctionNames() override;
162   tensorflow::Status RemoveFunction(const std::string& func) override;
163   const tensorflow::FunctionDef* FindFunctionDef(
164       const std::string& name) const override;
165 
166   const tensorflow::DeviceNameUtils::ParsedName& HostCPUParsedName()
167       const override;
168   const std::string& HostCPUName() const override;
169 
SetAllowSoftPlacement(bool enable)170   void SetAllowSoftPlacement(bool enable) override {
171     // TODO(tfrt-devs): Move this flag to a common place that can be shared
172     // by current TF and TFRT.
173     GetEagerContext()->SetAllowSoftPlacement(enable);
174   }
SetShouldStoreGraphs(bool value)175   void SetShouldStoreGraphs(bool value) override {
176     GetEagerContext()->SetShouldStoreGraphs(value);
177   }
178 
179   tensorflow::Status EnableCollectiveOps(
180       const tensorflow::ServerDef& server_def) override;
181 
182   std::unique_ptr<tensorflow::RunMetadata> ExportRunMetadata() override;
183 
184   // Find the FunctionDef by the given name and record it in RunMetadata.
185   tensorflow::Status RunMetadataRecordFunction(const std::string& func_name);
186 
SetLogDevicePlacement(bool enable)187   void SetLogDevicePlacement(bool enable) override {
188     // TODO(tfrt-devs): Move this flag to a common place that can be shared
189     // by current TF and TFRT.
190     GetEagerContext()->SetLogDevicePlacement(enable);
191   }
192 
SetRunEagerOpAsFunction(bool enable)193   void SetRunEagerOpAsFunction(bool enable) override {
194     // TODO(tfrt-devs): Move this flag to a common place that can be shared
195     // by current TF and TFRT.
196     GetEagerContext()->SetRunEagerOpAsFunction(enable);
197   }
198 
SetJitCompileRewrite(bool enable)199   void SetJitCompileRewrite(bool enable) override {
200     // TODO(tfrt-devs): Move this flag to a common place that can be shared
201     // by current TF and TFRT.
202     GetEagerContext()->SetJitCompileRewrite(enable);
203   }
204 
Executor()205   tensorflow::EagerExecutor& Executor() override {
206     return GetEagerContext()->Executor();
207   }
208   void SetExecutorForThread(tensorflow::EagerExecutor* executor) override;
209 
SetThreadLocalDevicePlacementPolicy(tensorflow::ContextDevicePlacementPolicy policy)210   void SetThreadLocalDevicePlacementPolicy(
211       tensorflow::ContextDevicePlacementPolicy policy) override {
212     // TODO(tfrt-devs): Move this flag to a common place that can be shared
213     // by current TF and TFRT.
214     GetEagerContext()->SetThreadLocalDevicePlacementPolicy(policy);
215   }
GetDevicePlacementPolicy()216   tensorflow::ContextDevicePlacementPolicy GetDevicePlacementPolicy()
217       const override {
218     // TODO(tfrt-devs): Move this flag to a common place that can be shared
219     // by current TF and TFRT.
220     return GetEagerContext()->GetDevicePlacementPolicy();
221   }
222 
223   CoreRuntime* GetCoreRuntime();
224   tensorflow::Status BuildFunctionRequestContext(
225       tensorflow::tfrt_stub::OpKernelRunnerTable* runner_table,
226       RCReference<tfrt::RequestContext>* request_context);
227   tensorflow::Status BuildOpRequestContext(
228       RCReference<tfrt::RequestContext>* request_context);
229   tensorflow::EagerContext* GetEagerContext();
230   const tensorflow::EagerContext* GetEagerContext() const;
231   TfrtContext* GetTfrtContext();
232 
233   // Selects the op handler to execute the op based on the arguments. This
234   // op handler selection is cheap. But it can be nullptr even it return OK
235   // status.
236   tensorflow::Status SelectOpHandlerFromArguments(
237       const tensorflow::ImmediateExecutionOperation& op,
238       OpHandler** op_handler);
239 
240   // Selects the op handler to execute the op based on NodeDef. This op handler
241   // selection is expensive. It will never return nullptr unless there is an
242   // error. Please only invoke this method when the cheap version fails.
243   tensorflow::Status SelectOpHandlerFromNodeDef(
244       const tensorflow::ImmediateExecutionOperation& op,
245       const tensorflow::NodeDef* node_def, OpHandler** op_handler);
246 
247   // Returns the chain for current thread.
248   AsyncValueRef<Chain>* GetChain();
249 
250   // Indicates sync or async execution.
IsAsync()251   bool IsAsync() const { return context_.IsAsync(); }
252 
253   // For LLVM style RTTI.
classof(const AbstractContext * op)254   static bool classof(const AbstractContext* op) {
255     return op->getKind() == kTfrt;
256   }
257 
GetFunctionCache()258   FunctionCache& GetFunctionCache() { return function_cache_; }
259 
GetOpCache()260   OpCache& GetOpCache() { return op_cache_; }
261 
262   OpHandler* GetFallbackOpHandler();
263 
264   std::vector<std::string> GetLoggedOpsTestonly() override;
265 
UseTfrtDistributedRuntime()266   bool UseTfrtDistributedRuntime() { return use_tfrt_distributed_runtime_; }
267 
268 #if !defined(IS_MOBILE_PLATFORM)
SetDistributedManager(std::unique_ptr<tensorflow::ImmediateExecutionDistributedManager> distributed)269   void SetDistributedManager(
270       std::unique_ptr<tensorflow::ImmediateExecutionDistributedManager>
271           distributed) override {
272     distributed_manager_ = std::move(distributed);
273   }
274 
GetDistributedManager()275   tensorflow::ImmediateExecutionDistributedManager* GetDistributedManager()
276       override {
277     if (use_tfrt_distributed_runtime_) {
278       return distributed_manager_.get();
279     } else {
280       return context_.GetEagerContext()->GetDistributedManager();
281     }
282   }
283 #endif  // !IS_MOBILE_PLATFORM
284 
285  private:
286   HostContext* GetHostContext();
287   ResourceContext* GetResourceContext();
288 
289   Expected<OpHandler*> GetOpHandler(const char* name);
290 
291   TfrtContext context_;
292 
293   mutable tensorflow::mutex chain_map_mu_;
294   // TODO(chuanhao): Hook it up with C API to allow user to manage it.
295   // Each caller thread will have its own chain to dispatch ops.
296   std::unordered_map<std::thread::id, AsyncValueRef<Chain>> thread_local_chain_
297       TF_GUARDED_BY(chain_map_mu_);
298 
299   std::unique_ptr<EagerOpHandlerSelector> op_handler_selector_;
300 
301   // The cache that stores functions (composite CoreRuntimeOps).
302   FunctionCache function_cache_;
303 
304   // The cache that stores CoreRuntimeOps. It's separate from function cache
305   // since a primitive CoreRuntimeOp is essentially a stateless function
306   // pointer, and so it doesn't need ref-count to manage its lifetime.
307   OpCache op_cache_;
308 
309   mutex run_metadata_mu_;
310   std::unique_ptr<tensorflow::RunMetadata> run_metadata_
311       TFRT_GUARDED_BY(run_metadata_mu_);
312 
313   // Use TFRT's implementation of distributed manager.
314   bool use_tfrt_distributed_runtime_ = false;
315 
316   // A distributed manager that helps setup, update, and check liveness of
317   // member tasks in the cluster.
318   std::unique_ptr<tensorflow::ImmediateExecutionDistributedManager>
319       distributed_manager_;
320 };
321 
322 class TensorInterface : public tensorflow::AbstractTensorInterface {
323  public:
TensorInterface(AsyncValueRef<Tensor> t)324   explicit TensorInterface(AsyncValueRef<Tensor> t) : tensor_(std::move(t)) {}
TensorInterface(tensorflow::Tensor t)325   explicit TensorInterface(tensorflow::Tensor t) : tf_tensor_(std::move(t)) {}
~TensorInterface()326   ~TensorInterface() override {}
327 
Release()328   void Release() override { delete this; }
329 
330   tensorflow::DataType Type() const override;
331   int NumDims() const override;
332   int64_t Dim(int dim_index) const override;
333   int64_t NumElements() const override;
334   size_t ByteSize() const override;
335   void* Data() const override;
336   bool IsAligned() const override;
337   bool CanMove() const override;
IsTfTensor()338   bool IsTfTensor() const { return !tensor_; }
339   std::string SummarizeValue() const override;
340 
341   AsyncValueRef<tfrt::Tensor> TensorRef() const;
TfTensor()342   tensorflow::Tensor& TfTensor() { return tf_tensor_; }
343 
344  private:
345   AsyncValueRef<tfrt::Tensor> tensor_;
346   // NOTE(b/167608876): tensorflow::Tensor for handling non-scalar string
347   // tensors, for backward compatibility. This is a temporary workaround until
348   // we find a proper way to unify tensorflow::tstring and
349   // tfrt::StringHostTensor.
350   tensorflow::Tensor tf_tensor_;
351 };
352 
353 class TensorHandleInterface
354     : public tensorflow::ImmediateExecutionTensorHandle {
355  public:
356   explicit TensorHandleInterface(Value&& v, TfrtContext* context);
357 
358   explicit TensorHandleInterface(tensorflow::DataType dtype, Value&& v,
359                                  TfrtContext* context);
360 
Release()361   void Release() override { Unref(); }
362 
363   tensorflow::DataType DataType() const override;
364   tensorflow::Status TensorHandleStatus() const override;
365   tensorflow::Status Shape(
366       tensorflow::PartialTensorShape* shape) const override;
367   tensorflow::Status NumDims(int* num_dims) const override;
368   tensorflow::Status NumElements(int64_t* num_elements) const override;
369   tensorflow::Status Dim(int dim_index, int64_t* dim) const override;
370 
371   // DeviceName represents the device that creates the tensor handle.
372   // Currently the same with BackingDeviceName.
373   // TODO(b/169341326): unify device behavior between current TF and TFRT.
374   const char* DeviceName(tensorflow::Status* status) const override;
375 
376   // BackingDeviceName represents the device where the tensor is physically
377   // placed. DeviceName and BackingDeviceName are the same for TFRT.
378   const char* BackingDeviceName(tensorflow::Status* status) const override;
379 
380   const char* DeviceType(tensorflow::Status* status) const override;
381 
DeviceId(tensorflow::Status * status)382   int DeviceId(tensorflow::Status* status) const override {
383     // TODO(tfrt-devs): implement for tfrt tensor handle.
384     llvm_unreachable("unimplemented method.");
385   }
386 
387   tensorflow::AbstractTensorInterface* Resolve(
388       tensorflow::Status* status) override;
389 
390   // TODO(b/161897666): Figure out if we can get rid of returning a new
391   // pointer here and just use Ref().
Copy()392   tensorflow::ImmediateExecutionTensorHandle* Copy() override {
393     Ref();
394     return this;
395   }
396 
Handle()397   TensorHandle Handle() { return value_.get<TensorHandle>().CopyRef(); }
398 
value()399   Value* value() { return &value_; }
400 
401   // For LLVM style RTTI.
classof(const tensorflow::AbstractTensorHandle * ptr)402   static bool classof(const tensorflow::AbstractTensorHandle* ptr) {
403     return ptr->getKind() == kTfrt;
404   }
405 
406  private:
407   llvm::Optional<const TensorMetadata*> Metadata() const;
408 
409   tensorflow::StatusOr<tensorflow::DataType> ObtainDataTypeFromMetaData(
410       const TensorMetadata*) const;
411 
412   // If the tensor handle is generated as the result of a function, the datatype
413   // is known from the function output signature.
414   // Therefore, we can obtain the datatype earlier, before the function
415   // execution completes.
416   llvm::Optional<tensorflow::DataType> dtype_;
417 
418   TfrtContext& context_;
419 
420   // Value of tfrt::TensorHandle.
421   Value value_;
422 };
423 
424 template <typename T>
TensorHandleFromInterface(T * handle)425 inline TensorHandleInterface* TensorHandleFromInterface(T* handle) {
426   return tensorflow::down_cast<TensorHandleInterface*>(handle);
427 }
428 
429 // TFRT location handler class that simply prints the error and abort the
430 // program on encountering any error. It's primarily for easy debugging
431 // TODO(kkb): Handle errors probably by raising a Python exception.
432 class AbortLocationHandler final : public tfrt::LocationHandler {
433  public:
434   tfrt::Location GetCurrentLocation();
435 
436  private:
DecodeLocation(tfrt::Location loc)437   tfrt::DecodedLocation DecodeLocation(tfrt::Location loc) const override {
438     // Return a dummy decoded location.
439     return {};
440   }
441 };
442 
443 class OpAttrsInterface : public tensorflow::AbstractOpAttrs {
444  public:
OpAttrsInterface(const OpAttrs * attrs,tensorflow::AttrBuilder * fallback_attrs)445   explicit OpAttrsInterface(const OpAttrs* attrs,
446                             tensorflow::AttrBuilder* fallback_attrs)
447       : AbstractOpAttrs(
448             tensorflow::AbstractOpAttrs::AbstractOpAttrsKind::kTfrt),
449         attrs_(attrs),
450         fallback_attrs_(fallback_attrs) {}
~OpAttrsInterface()451   ~OpAttrsInterface() override {}
452 
453   void GetNameAttrList(tensorflow::NameAttrList* name_and_attrs) const override;
454   tensorflow::Status GetTypeList(
455       absl::string_view attr_name,
456       absl::InlinedVector<tensorflow::DataType, 4>* type_list) const override;
457 
458   bool GetInt(absl::string_view attr_name, int64_t* result) const override;
459   bool GetFloat(absl::string_view attr_name, float* result) const override;
460   bool GetBool(absl::string_view attr_name, bool* result) const override;
461   bool GetType(absl::string_view attr_name,
462                tensorflow::DataType* result) const override;
463 
GetAttrs()464   const OpAttrs* GetAttrs() const { return attrs_; }
465 
GetFallbackAttrs()466   const tensorflow::AttrBuilder* GetFallbackAttrs() const {
467     return fallback_attrs_;
468   }
469 
470  private:
471   // TODO(fishx): Move ownership to here.
472   const OpAttrs* attrs_;
473 
474   // TODO(tfrt-devs): Remove this field and generate NameAttrList from attrs_.
475   // Today it is fine since we will set both attrs and fallback_attrs.
476   const tensorflow::AttrBuilder* fallback_attrs_;
477 };
478 
479 class OperationInterface : public tensorflow::ImmediateExecutionOperation {
480  public:
481   // All arguments come from ContextInterface.
482   explicit OperationInterface(ContextInterface* context);
~OperationInterface()483   ~OperationInterface() override {}
484 
Release()485   void Release() override { delete this; }
486 
Clear()487   void Clear() override { args_.clear(); }
488 
489   tensorflow::Status Reset(const char* op,
490                            const char* raw_device_name) override;
Name()491   const std::string& Name() const override { return op_name_; }
DeviceName()492   const std::string& DeviceName() const override { return device_name_; }
493   tensorflow::Status SetDeviceName(const char* name) override;
494 
GetContext()495   tensorflow::ImmediateExecutionContext* GetContext() const override {
496     return context_;
497   }
HasCustomDeviceInput()498   bool HasCustomDeviceInput() const override {
499     return custom_device_tensor_handle_count_ > 0;
500   }
501 
502   tensorflow::Status AddInput(tensorflow::AbstractTensorHandle* input) override;
503   tensorflow::Status AddInputList(
504       absl::Span<tensorflow::AbstractTensorHandle* const> inputs) override;
505   tensorflow::Status SetInput(
506       size_t index, tensorflow::ImmediateExecutionTensorHandle* input) override;
507   absl::Span<tensorflow::ImmediateExecutionTensorHandle* const> GetInputs()
508       const override;
509   tensorflow::Status Execute(
510       absl::Span<tensorflow::AbstractTensorHandle*> retvals,
511       int* num_retvals) override;
OpDef()512   const tensorflow::OpDef* OpDef() const override { return op_def_; }
NodeDef()513   const tensorflow::NodeDef NodeDef() { return fallback_attrs_.BuildNodeDef(); }
514 
515   tensorflow::Status SetAttrString(const char* attr_name, const char* data,
516                                    size_t length) override;
517   tensorflow::Status SetAttrInt(const char* attr_name, int64_t value) override;
518   tensorflow::Status SetAttrFloat(const char* attr_name, float value) override;
519   tensorflow::Status SetAttrBool(const char* attr_name, bool value) override;
520   tensorflow::Status SetAttrType(const char* attr_name,
521                                  tensorflow::DataType value) override;
522   tensorflow::Status SetAttrShape(const char* attr_name, const int64_t* dims,
523                                   const int num_dims) override;
524   tensorflow::Status SetAttrFunction(const char* attr_name,
525                                      const AbstractOperation* value) override;
526   tensorflow::Status SetAttrFunctionName(const char* attr_name,
527                                          const char* data,
528                                          size_t length) override;
529   tensorflow::Status SetAttrTensor(
530       const char* attr_name,
531       tensorflow::AbstractTensorInterface* tensor) override;
532   tensorflow::Status SetAttrStringList(const char* attr_name,
533                                        const void* const* values,
534                                        const size_t* lengths,
535                                        int num_values) override;
536   tensorflow::Status SetAttrFloatList(const char* attr_name,
537                                       const float* values,
538                                       int num_values) override;
539   tensorflow::Status SetAttrIntList(const char* attr_name,
540                                     const int64_t* values,
541                                     int num_values) override;
542   tensorflow::Status SetAttrTypeList(const char* attr_name,
543                                      const tensorflow::DataType* values,
544                                      int num_values) override;
545   tensorflow::Status SetAttrBoolList(const char* attr_name,
546                                      const unsigned char* values,
547                                      int num_values) override;
548   tensorflow::Status SetAttrShapeList(const char* attr_name,
549                                       const int64_t** dims, const int* num_dims,
550                                       int num_values) override;
551   tensorflow::Status SetAttrFunctionList(
552       const char* attr_name,
553       absl::Span<const AbstractOperation*> values) override;
554 
555   tensorflow::Status InputLength(const char* input_name, int* length) override;
556   tensorflow::Status OutputLength(const char* output_name,
557                                   int* length) override;
558 
559   const tensorflow::AbstractOpAttrs* GetOpAttrs() const override;
560   void AddAttrs(const tensorflow::AbstractOpAttrs* op_attrs) override;
561 
SetStackTrace(tensorflow::ManagedStackTrace stack_trace)562   void SetStackTrace(tensorflow::ManagedStackTrace stack_trace) override {
563     stack_trace_ = stack_trace;
564   }
565 
SetCancellationManager(tensorflow::CancellationManager * cancellation_manager)566   void SetCancellationManager(
567       tensorflow::CancellationManager* cancellation_manager) override {
568     // TODO(b/181368626): Support cancellation.
569   }
570 
GetStackTrace()571   absl::optional<tensorflow::ManagedStackTrace> GetStackTrace() override {
572     return stack_trace_;
573   }
574 
575   // Currently not supported.
SetStepId(int64_t step_id)576   void SetStepId(int64_t step_id) override {}
577 
578   // For LLVM style RTTI.
classof(const AbstractOperation * ptr)579   static bool classof(const AbstractOperation* ptr) {
580     return ptr->getKind() == kTfrt;
581   }
582 
583   friend class OpCache;
584 
585  private:
586   // Initialize op_ field. It can be either a trivial op or a composite op.
587   tensorflow::Status Initialize();
588 
589   // Note(fishx): This method is copied from current TF. We use it to infer
590   // attribute like "T" in order to run device placement logic from current TF.
591   void MaybeInferInputAttrs();
592 
593   // This field holds a primitive op. If the op represents a function, it
594   // will be held by function_state_ below, and this field will be empty.
595   CoreRuntimeOp* op_;
596   RCReference<FunctionState> function_state_;
597   std::string op_name_;
598   // The device user requested to place the op on.
599   std::string device_name_;
600   bool is_function_;
601   tfrt::BefAttrEncoder bef_attr_encoder_;
602   // TODO(b/165412867): Remove AttrBuilder.
603   tensorflow::AttrBuilder fallback_attrs_;
604   const tensorflow::OpDef* op_def_;  // op definition from protobuf
605   OpAttrs attrs_;
606   OpAttrsInterface op_attrs_;
607   llvm::SmallVector<
608       tensorflow::core::RefCountPtr<tensorflow::ImmediateExecutionTensorHandle>,
609       8>
610       args_;
611   AbortLocationHandler abort_location_handler_;
612   ContextInterface* const context_;
613   // TODO(kkb): Use tfrt::Location and implement TFRT async stack tracing.
614   absl::optional<tensorflow::ManagedStackTrace> stack_trace_;
615 
616   int custom_device_tensor_handle_count_ = 0;
617 };
618 
619 }  // namespace tf
620 }  // namespace tfrt
621 
622 #endif  // TENSORFLOW_CORE_TFRT_EAGER_C_API_TFRT_H_
623