• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #ifndef TENSORFLOW_CORE_TFRT_EAGER_C_API_TFRT_H_
16 #define TENSORFLOW_CORE_TFRT_EAGER_C_API_TFRT_H_
17 
18 #include <memory>
19 #include <string>
20 #include <vector>
21 
22 #include "absl/types/optional.h"
23 #include "tensorflow/c/eager/abstract_op_attrs.h"
24 #include "tensorflow/c/eager/abstract_tensor_handle.h"
25 #include "tensorflow/c/eager/immediate_execution_context.h"
26 #include "tensorflow/c/eager/immediate_execution_operation.h"
27 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
28 #include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
29 #include "tensorflow/c/tensor_interface.h"
30 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
31 #include "tensorflow/core/common_runtime/eager/context.h"
32 #include "tensorflow/core/framework/cancellation.h"
33 #include "tensorflow/core/framework/tensor.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/platform/refcount.h"
36 #include "tensorflow/core/platform/status.h"
37 #include "tensorflow/core/protobuf/config.pb.h"
38 #include "tensorflow/core/public/session_options.h"
39 #include "tensorflow/core/tfrt/eager/function_cache.h"
40 #include "tensorflow/core/tfrt/eager/op_cache.h"
41 #include "tensorflow/core/tfrt/eager/tfrt_context.h"
42 #include "tensorflow/core/util/device_name_utils.h"
43 #include "tfrt/bef_converter/bef_attr_encoder.h"  // from @tf_runtime
44 #include "tfrt/bef_executor/bef_file.h"  // from @tf_runtime
45 #include "tfrt/core_runtime/core_runtime_op.h"  // from @tf_runtime
46 #include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
47 #include "tfrt/core_runtime/tensor_handle.h"  // from @tf_runtime
48 #include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
49 #include "tfrt/host_context/value.h"  // from @tf_runtime
50 #include "tfrt/support/aligned_buffer.h"  // from @tf_runtime
51 #include "tfrt/support/forward_decls.h"  // from @tf_runtime
52 #include "tfrt/support/ref_count.h"  // from @tf_runtime
53 #include "tfrt/tensor/tensor.h"  // from @tf_runtime
54 
55 namespace tfrt {
56 
// Forward declarations of TFRT types named by the declarations in this header.
class CoreRuntime;
class CoreRuntimeOp;
class DenseHostTensor;
class OpHandler;
class TensorHandle;
class TensorMetadata;
63 
64 namespace tf {
// Forward declaration only: this header holds the selector solely through a
// std::unique_ptr member of ContextInterface.
class EagerOpHandlerSelector;
66 
67 class ContextInterface : public tensorflow::ImmediateExecutionContext {
68  public:
69   ContextInterface(
70       const tensorflow::SessionOptions& opts,
71       tensorflow::ContextDevicePlacementPolicy default_device_placement_policy,
72       bool is_async, bool use_tfrt_distributed_runtime);
73   ~ContextInterface() override;
74 
Release()75   void Release() override { delete this; }
76 
77   tensorflow::AbstractTensorInterface* CreateInt64Scalar(
78       int64_t value) override;
79   tensorflow::AbstractTensorInterface* CreateUint64Scalar(
80       uint64_t value) override;
81   tensorflow::AbstractTensorInterface* CreateInt32Scalar(
82       int32_t value) override;
83   tensorflow::AbstractTensorInterface* CreateFloatScalar(float value) override;
84   tensorflow::AbstractTensorInterface* CreateDoubleScalar(
85       double value) override;
86   tensorflow::AbstractTensorInterface* CreateHalfScalar(
87       Eigen::half value) override;
88   tensorflow::AbstractTensorInterface* CreateStringScalar(
89       tensorflow::tstring value) override;
90   tensorflow::AbstractTensorInterface* CreateComplex128Scalar(
91       tensorflow::complex128 value) override;
92   tensorflow::AbstractTensorInterface* CreateBoolScalar(bool value) override;
93 
94   tensorflow::AbstractTensorInterface* CreateTensor(
95       tensorflow::DataType dtype, absl::Span<const int64_t> dim_sizes) override;
96   tensorflow::AbstractTensorInterface* CreateTensor(
97       tensorflow::DataType dtype, const int64_t* dims, int num_dims, void* data,
98       size_t len, MemoryReleaser memory_releaser,
99       void* memory_releaser_arg) override;
100 
101   tensorflow::ImmediateExecutionTensorHandle* CreateLocalHandle(
102       tensorflow::AbstractTensorInterface* t) override;
103   // Create an abstract tensor handle from tensorflow::Tensor.
104   tensorflow::ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor(
105       tensorflow::Tensor& t, const char* d_name) override;
106 
107   // Convert a TFRT TensorHandle to tensorflow::TensorHandle.
108   tensorflow::ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
109       tensorflow::ImmediateExecutionTensorHandle* handle) override;
110 
111   tensorflow::ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
112       tensorflow::ImmediateExecutionTensorHandle* handle,
113       const char* device_name, tensorflow::Status* status) override;
114 
115   tensorflow::ImmediateExecutionOperation* CreateOperation() override;
116   tensorflow::Status RegisterFunction(tensorflow::AbstractFunction*) override;
117 
118   tensorflow::CustomDeviceOpHandler& GetCustomDeviceOpHandler() override;
119 
120   tensorflow::Status RegisterCustomDevice(
121       const std::string& name,
122       std::unique_ptr<tensorflow::CustomDevice> device) override;
123 
124   tensorflow::FunctionLibraryDefinition* FuncLibDef() override;
125 
126   void SetReuseRendezvousForFunctions(
127       bool reuse_rendezvous_for_functions) override;
128 
129   void ResetGlobalRendezvousForFunction() override;
130 
131   bool UsesTFRT() override;
132 
133   void ListDevices(std::vector<tensorflow::DeviceAttributes>* devices) override;
134 
ListLocalTfDevices()135   std::vector<tensorflow::Device*> ListLocalTfDevices() override {
136     return context_.GetEagerContext()->local_device_mgr()->ListDevices();
137   }
138 
139   tensorflow::Status AddDevices(
140       std::vector<std::unique_ptr<tensorflow::Device>> devices) override;
141 
142   void ClearCachesAndThreadExecutors() override;
143   void StartStep() override;
144   void EndStep() override;
145 
AsyncWait()146   tensorflow::Status AsyncWait() override {
147     GetHostContext()->Quiesce();
148     return tensorflow::Status::OK();
149   }
150 
151   tensorflow::Status AddFunctionDef(
152       const tensorflow::FunctionDef& fdef) override;
153   tensorflow::Status AddFunctionDefWithStackTraces(
154       const tensorflow::FunctionDef& fdef,
155       const tensorflow::StackTracesMap& stack_traces) override;
156   std::vector<std::string> ListFunctionNames() override;
157   tensorflow::Status RemoveFunction(const std::string& func) override;
158   const tensorflow::FunctionDef* FindFunctionDef(
159       const std::string& name) const override;
160 
161   const tensorflow::DeviceNameUtils::ParsedName& HostCPUParsedName()
162       const override;
163   const std::string& HostCPUName() const override;
164 
SetAllowSoftPlacement(bool enable)165   void SetAllowSoftPlacement(bool enable) override {
166     // TODO(tfrt-devs): Move this flag to a common place that can be shared
167     // by current TF and TFRT.
168     GetEagerContext()->SetAllowSoftPlacement(enable);
169   }
SetShouldStoreGraphs(bool value)170   void SetShouldStoreGraphs(bool value) override {
171     GetEagerContext()->SetShouldStoreGraphs(value);
172   }
173 
174   tensorflow::Status EnableCollectiveOps(
175       const tensorflow::ServerDef& server_def) override;
176 
177   std::unique_ptr<tensorflow::RunMetadata> ExportRunMetadata() override;
178 
179   // Find the FunctionDef by the given name and record it in RunMetadata.
180   tensorflow::Status RunMetadataRecordFunction(const std::string& func_name);
181 
SetLogDevicePlacement(bool enable)182   void SetLogDevicePlacement(bool enable) override {
183     // TODO(tfrt-devs): Move this flag to a common place that can be shared
184     // by current TF and TFRT.
185     GetEagerContext()->SetLogDevicePlacement(enable);
186   }
187 
Executor()188   tensorflow::EagerExecutor& Executor() override {
189     return GetEagerContext()->Executor();
190   }
191   void SetExecutorForThread(tensorflow::EagerExecutor* executor) override;
192 
SetThreadLocalDevicePlacementPolicy(tensorflow::ContextDevicePlacementPolicy policy)193   void SetThreadLocalDevicePlacementPolicy(
194       tensorflow::ContextDevicePlacementPolicy policy) override {
195     // TODO(tfrt-devs): Move this flag to a common place that can be shared
196     // by current TF and TFRT.
197     GetEagerContext()->SetThreadLocalDevicePlacementPolicy(policy);
198   }
GetDevicePlacementPolicy()199   tensorflow::ContextDevicePlacementPolicy GetDevicePlacementPolicy()
200       const override {
201     // TODO(tfrt-devs): Move this flag to a common place that can be shared
202     // by current TF and TFRT.
203     return GetEagerContext()->GetDevicePlacementPolicy();
204   }
205 
206   CoreRuntime* GetCoreRuntime();
207   tensorflow::Status BuildFunctionRequestContext(
208       tensorflow::tfd::OpKernelRunnerTable* runner_table,
209       RCReference<tfrt::RequestContext>* request_context);
210   tensorflow::Status BuildOpRequestContext(
211       RCReference<tfrt::RequestContext>* request_context);
212   tensorflow::EagerContext* GetEagerContext();
213   const tensorflow::EagerContext* GetEagerContext() const;
214 
215   // Selects the op handler to execute the op based on the arguments. This
216   // op handler selection is cheap. But it can be nullptr even it return OK
217   // status.
218   tensorflow::Status SelectOpHandlerFromArguments(
219       const tensorflow::ImmediateExecutionOperation& op,
220       OpHandler** op_handler);
221 
222   // Selects the op handler to execute the op based on NodeDef. This op handler
223   // selection is expensive. It will never return nullptr unless there is an
224   // error. Please only invoke this method when the cheap version fails.
225   tensorflow::Status SelectOpHandlerFromNodeDef(
226       const tensorflow::ImmediateExecutionOperation& op,
227       const tensorflow::NodeDef* node_def, OpHandler** op_handler);
228 
229   // Returns the chain for current thread.
230   AsyncValueRef<Chain>* GetChain();
231 
232   // Indicates sync or async execution.
is_async()233   bool is_async() { return GetEagerContext()->Executor().Async(); }
234 
235   // For LLVM style RTTI.
classof(const AbstractContext * op)236   static bool classof(const AbstractContext* op) {
237     return op->getKind() == kTfrt;
238   }
239 
GetFunctionCache()240   FunctionCache& GetFunctionCache() { return function_cache_; }
241 
GetOpCache()242   OpCache& GetOpCache() { return op_cache_; }
243 
244   OpHandler* GetFallbackOpHandler();
245 
246   std::vector<std::string> GetLoggedOpsTestonly() override;
247 
UseTfrtDistributedRuntime()248   bool UseTfrtDistributedRuntime() { return use_tfrt_distributed_runtime_; }
249 
250 #if !defined(IS_MOBILE_PLATFORM)
SetDistributedManager(std::unique_ptr<tensorflow::ImmediateExecutionDistributedManager> distributed)251   void SetDistributedManager(
252       std::unique_ptr<tensorflow::ImmediateExecutionDistributedManager>
253           distributed) override {
254     distributed_manager_ = std::move(distributed);
255   }
256 
GetDistributedManager()257   tensorflow::ImmediateExecutionDistributedManager* GetDistributedManager()
258       override {
259     if (use_tfrt_distributed_runtime_) {
260       return distributed_manager_.get();
261     } else {
262       return context_.GetEagerContext()->GetDistributedManager();
263     }
264   }
265 #endif  // !IS_MOBILE_PLATFORM
266 
267  private:
268   HostContext* GetHostContext();
269   ResourceContext* GetResourceContext();
270 
271   Expected<OpHandler*> GetOpHandler(const char* name);
272 
273   TfrtContext context_;
274 
275   mutable tensorflow::mutex chain_map_mu_;
276   // TODO(chuanhao): Hook it up with C API to allow user to manage it.
277   // Each caller thread will have its own chain to dispatch ops.
278   std::unordered_map<std::thread::id, AsyncValueRef<Chain>> thread_local_chain_
279       TF_GUARDED_BY(chain_map_mu_);
280 
281   std::unique_ptr<EagerOpHandlerSelector> op_handler_selector_;
282 
283   // The cache that stores functions (composite CoreRuntimeOps).
284   FunctionCache function_cache_;
285 
286   // The cache that stores CoreRuntimeOps. It's separate from function cache
287   // since a primitive CoreRuntimeOp is essentially a stateless function
288   // pointer, and so it doesn't need ref-count to manage its lifetime.
289   OpCache op_cache_;
290 
291   mutex run_metadata_mu_;
292   std::unique_ptr<tensorflow::RunMetadata> run_metadata_
293       TFRT_GUARDED_BY(run_metadata_mu_);
294 
295   // Use TFRT's implementation of distributed manager.
296   bool use_tfrt_distributed_runtime_ = false;
297 
298   // A distributed manager that helps setup, update, and check liveness of
299   // member tasks in the cluster.
300   std::unique_ptr<tensorflow::ImmediateExecutionDistributedManager>
301       distributed_manager_;
302 };
303 
304 class TensorInterface : public tensorflow::AbstractTensorInterface {
305  public:
TensorInterface(AsyncValueRef<Tensor> t)306   explicit TensorInterface(AsyncValueRef<Tensor> t) : tensor_(std::move(t)) {}
TensorInterface(tensorflow::Tensor t)307   explicit TensorInterface(tensorflow::Tensor t) : tf_tensor_(std::move(t)) {}
~TensorInterface()308   ~TensorInterface() override {}
309 
Release()310   void Release() override { delete this; }
311 
312   tensorflow::DataType Type() const override;
313   int NumDims() const override;
314   int64_t Dim(int dim_index) const override;
315   int64_t NumElements() const override;
316   size_t ByteSize() const override;
317   void* Data() const override;
318   bool IsAligned() const override;
319   bool CanMove() const override;
IsTfTensor()320   bool IsTfTensor() const { return !tensor_; }
321   std::string SummarizeValue() const override;
322 
323   AsyncValueRef<tfrt::Tensor> TensorRef() const;
TfTensor()324   tensorflow::Tensor& TfTensor() { return tf_tensor_; }
325 
326  private:
327   AsyncValueRef<tfrt::Tensor> tensor_;
328   // NOTE(b/167608876): tensorflow::Tensor for handling non-scalar string
329   // tensors, for backward compatibility. This is a temporary workaround until
330   // we find a proper way to unify tensorflow::tstring and
331   // tfrt::StringHostTensor.
332   tensorflow::Tensor tf_tensor_;
333 };
334 
335 class TensorHandleInterface
336     : public tensorflow::ImmediateExecutionTensorHandle {
337  public:
338   explicit TensorHandleInterface(Value&& v, CoreRuntime* corert);
339 
Release()340   void Release() override { Unref(); }
341 
342   tensorflow::DataType DataType() const override;
343   tensorflow::Status Shape(
344       tensorflow::PartialTensorShape* shape) const override;
345   tensorflow::Status NumDims(int* num_dims) const override;
346   tensorflow::Status NumElements(int64_t* num_elements) const override;
347   tensorflow::Status Dim(int dim_index, int64_t* dim) const override;
348 
349   // DeviceName represents the device that creates the tensor handle.
350   // Currently the same with BackingDeviceName.
351   // TODO(b/169341326): unify device behavior between current TF and TFRT.
352   const char* DeviceName(tensorflow::Status* status) const override;
353 
354   // BackingDeviceName represents the device where the tensor is physically
355   // placed. DeviceName and BackingDeviceName are the same for TFRT.
356   const char* BackingDeviceName(tensorflow::Status* status) const override;
357 
358   const char* DeviceType(tensorflow::Status* status) const override;
359 
DeviceId(tensorflow::Status * status)360   int DeviceId(tensorflow::Status* status) const override {
361     // TODO(tfrt-devs): implement for tfrt tensor handle.
362     llvm_unreachable("unimplemented method.");
363   }
364 
365   tensorflow::AbstractTensorInterface* Resolve(
366       tensorflow::Status* status) override;
367 
368   // TODO(b/161897666): Figure out if we can get rid of returning a new
369   // pointer here and just use Ref().
Copy()370   tensorflow::ImmediateExecutionTensorHandle* Copy() override {
371     Ref();
372     return this;
373   }
374 
Handle()375   TensorHandle Handle() { return value_.get<TensorHandle>().CopyRef(); }
376 
value()377   Value* value() { return &value_; }
378 
379   // For LLVM style RTTI.
classof(const tensorflow::AbstractTensorHandle * ptr)380   static bool classof(const tensorflow::AbstractTensorHandle* ptr) {
381     return ptr->getKind() == kTfrt;
382   }
383 
384  private:
385   llvm::Optional<const TensorMetadata*> Metadata() const;
386 
387   CoreRuntime& corert_;
388 
389   // Value of tfrt::TensorHandle.
390   Value value_;
391 };
392 
393 template <typename T>
TensorHandleFromInterface(T * handle)394 inline TensorHandleInterface* TensorHandleFromInterface(T* handle) {
395   return tensorflow::down_cast<TensorHandleInterface*>(handle);
396 }
397 
398 // TFRT location handler class that simply prints the error and abort the
399 // program on encountering any error. It's primarily for easy debugging
400 // TODO(kkb): Handle errors probably by raising a Python exception.
401 class AbortLocationHandler final : public tfrt::LocationHandler {
402  public:
403   tfrt::Location GetCurrentLocation();
404 
405  private:
DecodeLocation(tfrt::Location loc)406   tfrt::DecodedLocation DecodeLocation(tfrt::Location loc) const override {
407     // Return a dummy decoded location.
408     return {};
409   }
410 };
411 
412 class OpAttrsInterface : public tensorflow::AbstractOpAttrs {
413  public:
OpAttrsInterface(const OpAttrs * attrs,tensorflow::AttrBuilder * fallback_attrs)414   explicit OpAttrsInterface(const OpAttrs* attrs,
415                             tensorflow::AttrBuilder* fallback_attrs)
416       : AbstractOpAttrs(
417             tensorflow::AbstractOpAttrs::AbstractOpAttrsKind::kTfrt),
418         attrs_(attrs),
419         fallback_attrs_(fallback_attrs) {}
~OpAttrsInterface()420   ~OpAttrsInterface() override {}
421 
422   void GetNameAttrList(tensorflow::NameAttrList* name_and_attrs) const override;
423   tensorflow::Status GetTypeList(
424       absl::string_view attr_name,
425       absl::InlinedVector<tensorflow::DataType, 4>* type_list) const override;
426 
427   bool GetInt(absl::string_view attr_name, int64_t* result) const override;
428   bool GetFloat(absl::string_view attr_name, float* result) const override;
429   bool GetBool(absl::string_view attr_name, bool* result) const override;
430   bool GetType(absl::string_view attr_name,
431                tensorflow::DataType* result) const override;
432 
GetAttrs()433   const OpAttrs* GetAttrs() const { return attrs_; }
434 
GetFallbackAttrs()435   const tensorflow::AttrBuilder* GetFallbackAttrs() const {
436     return fallback_attrs_;
437   }
438 
439  private:
440   // TODO(fishx): Move ownership to here.
441   const OpAttrs* attrs_;
442 
443   // TODO(tfrt-devs): Remove this field and generate NameAttrList from attrs_.
444   // Today it is fine since we will set both attrs and fallback_attrs.
445   const tensorflow::AttrBuilder* fallback_attrs_;
446 };
447 
448 class OperationInterface : public tensorflow::ImmediateExecutionOperation {
449  public:
450   // All arguments come from ContextInterface.
451   explicit OperationInterface(ContextInterface* context);
~OperationInterface()452   ~OperationInterface() override {}
453 
Release()454   void Release() override { delete this; }
455 
Clear()456   void Clear() override { args_.clear(); }
457 
458   tensorflow::Status Reset(const char* op,
459                            const char* raw_device_name) override;
Name()460   const std::string& Name() const override { return op_name_; }
DeviceName()461   const std::string& DeviceName() const override { return device_name_; }
462   tensorflow::Status SetDeviceName(const char* name) override;
463 
GetContext()464   tensorflow::ImmediateExecutionContext* GetContext() const override {
465     return context_;
466   }
HasCustomDeviceInput()467   bool HasCustomDeviceInput() const override {
468     return custom_device_tensor_handle_count_ > 0;
469   }
470 
471   tensorflow::Status AddInput(tensorflow::AbstractTensorHandle* input) override;
472   tensorflow::Status AddInputList(
473       absl::Span<tensorflow::AbstractTensorHandle* const> inputs) override;
474   tensorflow::Status SetInput(
475       size_t index, tensorflow::ImmediateExecutionTensorHandle* input) override;
476   absl::Span<tensorflow::ImmediateExecutionTensorHandle* const> GetInputs()
477       const override;
478   tensorflow::Status Execute(
479       absl::Span<tensorflow::AbstractTensorHandle*> retvals,
480       int* num_retvals) override;
OpDef()481   const tensorflow::OpDef* OpDef() const override { return op_def_; }
NodeDef()482   const tensorflow::NodeDef NodeDef() { return fallback_attrs_.BuildNodeDef(); }
483 
484   tensorflow::Status SetAttrString(const char* attr_name, const char* data,
485                                    size_t length) override;
486   tensorflow::Status SetAttrInt(const char* attr_name, int64_t value) override;
487   tensorflow::Status SetAttrFloat(const char* attr_name, float value) override;
488   tensorflow::Status SetAttrBool(const char* attr_name, bool value) override;
489   tensorflow::Status SetAttrType(const char* attr_name,
490                                  tensorflow::DataType value) override;
491   tensorflow::Status SetAttrShape(const char* attr_name, const int64_t* dims,
492                                   const int num_dims) override;
493   tensorflow::Status SetAttrFunction(const char* attr_name,
494                                      const AbstractOperation* value) override;
495   tensorflow::Status SetAttrFunctionName(const char* attr_name,
496                                          const char* data,
497                                          size_t length) override;
498   tensorflow::Status SetAttrTensor(
499       const char* attr_name,
500       tensorflow::AbstractTensorInterface* tensor) override;
501   tensorflow::Status SetAttrStringList(const char* attr_name,
502                                        const void* const* values,
503                                        const size_t* lengths,
504                                        int num_values) override;
505   tensorflow::Status SetAttrFloatList(const char* attr_name,
506                                       const float* values,
507                                       int num_values) override;
508   tensorflow::Status SetAttrIntList(const char* attr_name,
509                                     const int64_t* values,
510                                     int num_values) override;
511   tensorflow::Status SetAttrTypeList(const char* attr_name,
512                                      const tensorflow::DataType* values,
513                                      int num_values) override;
514   tensorflow::Status SetAttrBoolList(const char* attr_name,
515                                      const unsigned char* values,
516                                      int num_values) override;
517   tensorflow::Status SetAttrShapeList(const char* attr_name,
518                                       const int64_t** dims, const int* num_dims,
519                                       int num_values) override;
520   tensorflow::Status SetAttrFunctionList(
521       const char* attr_name,
522       absl::Span<const AbstractOperation*> values) override;
523 
524   tensorflow::Status InputLength(const char* input_name, int* length) override;
525   tensorflow::Status OutputLength(const char* output_name,
526                                   int* length) override;
527 
528   const tensorflow::AbstractOpAttrs* GetOpAttrs() const override;
529   void AddAttrs(const tensorflow::AbstractOpAttrs* op_attrs) override;
530 
SetStackTrace(tensorflow::ManagedStackTrace stack_trace)531   void SetStackTrace(tensorflow::ManagedStackTrace stack_trace) override {
532     stack_trace_ = stack_trace;
533   }
534 
SetCancellationManager(tensorflow::CancellationManager * cancellation_manager)535   void SetCancellationManager(
536       tensorflow::CancellationManager* cancellation_manager) override {
537     // TODO(b/181368626): Support cancellation.
538   }
539 
GetStackTrace()540   absl::optional<tensorflow::ManagedStackTrace> GetStackTrace() override {
541     return stack_trace_;
542   }
543 
544   // For LLVM style RTTI.
classof(const AbstractOperation * ptr)545   static bool classof(const AbstractOperation* ptr) {
546     return ptr->getKind() == kTfrt;
547   }
548 
549   friend class OpCache;
550 
551  private:
552   // Initialize op_ field. It can be either a trivial op or a composite op.
553   tensorflow::Status Initialize();
554 
555   // Note(fishx): This method is copied from current TF. We use it to infer
556   // attribute like "T" in order to run device placement logic from current TF.
557   void MaybeInferInputAttrs();
558 
559   // This field holds a primitive op. If the op represents a function, it
560   // will be held by function_state_ below, and this field will be empty.
561   CoreRuntimeOp* op_;
562   RCReference<FunctionState> function_state_;
563   std::string op_name_;
564   // The device user requested to place the op on.
565   std::string device_name_;
566   bool is_function_;
567   tfrt::BefAttrEncoder bef_attr_encoder_;
568   // TODO(b/165412867): Remove AttrBuilder.
569   tensorflow::AttrBuilder fallback_attrs_;
570   const tensorflow::OpDef* op_def_;  // op definition from protobuf
571   OpAttrs attrs_;
572   OpAttrsInterface op_attrs_;
573   SmallVector<
574       tensorflow::core::RefCountPtr<tensorflow::ImmediateExecutionTensorHandle>,
575       8>
576       args_;
577   AbortLocationHandler abort_location_handler_;
578   ContextInterface* const context_;
579   // TODO(kkb): Use tfrt::Location and implement TFRT async stack tracing.
580   absl::optional<tensorflow::ManagedStackTrace> stack_trace_;
581 
582   int custom_device_tensor_handle_count_ = 0;
583 };
584 
585 }  // namespace tf
586 }  // namespace tfrt
587 
588 #endif  // TENSORFLOW_CORE_TFRT_EAGER_C_API_TFRT_H_
589