/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TFRT_EAGER_C_API_TFRT_H_
#define TENSORFLOW_CORE_TFRT_EAGER_C_API_TFRT_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/c/eager/abstract_op_attrs.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/tfrt/eager/function_cache.h"
#include "tensorflow/core/tfrt/eager/op_cache.h"
#include "tensorflow/core/tfrt/eager/tfrt_context.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tfrt/bef_converter/bef_attr_encoder.h"  // from @tf_runtime
#include "tfrt/bef_executor/bef_file.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime_op.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
#include "tfrt/core_runtime/tensor_handle.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/value.h"  // from @tf_runtime
#include "tfrt/support/aligned_buffer.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime
#include "tfrt/tensor/tensor.h"  // from @tf_runtime
54
55 namespace tfrt {
56
57 class CoreRuntime;
58 class CoreRuntimeOp;
59 class DenseHostTensor;
60 class OpHandler;
61 class TensorHandle;
62 class TensorMetadata;
63
64 namespace tf {
65 class EagerOpHandlerSelector;
66
67 class ContextInterface : public tensorflow::ImmediateExecutionContext {
68 public:
69 ContextInterface(
70 const tensorflow::SessionOptions& opts,
71 tensorflow::ContextDevicePlacementPolicy default_device_placement_policy,
72 bool is_async, bool use_tfrt_distributed_runtime);
73 ~ContextInterface() override;
74
Release()75 void Release() override { delete this; }
76
77 tensorflow::AbstractTensorInterface* CreateInt64Scalar(
78 int64_t value) override;
79 tensorflow::AbstractTensorInterface* CreateUint64Scalar(
80 uint64_t value) override;
81 tensorflow::AbstractTensorInterface* CreateInt32Scalar(
82 int32_t value) override;
83 tensorflow::AbstractTensorInterface* CreateFloatScalar(float value) override;
84 tensorflow::AbstractTensorInterface* CreateDoubleScalar(
85 double value) override;
86 tensorflow::AbstractTensorInterface* CreateHalfScalar(
87 Eigen::half value) override;
88 tensorflow::AbstractTensorInterface* CreateStringScalar(
89 tensorflow::tstring value) override;
90 tensorflow::AbstractTensorInterface* CreateComplex128Scalar(
91 tensorflow::complex128 value) override;
92 tensorflow::AbstractTensorInterface* CreateBoolScalar(bool value) override;
93
94 tensorflow::AbstractTensorInterface* CreateTensor(
95 tensorflow::DataType dtype, absl::Span<const int64_t> dim_sizes) override;
96 tensorflow::AbstractTensorInterface* CreateTensor(
97 tensorflow::DataType dtype, const int64_t* dims, int num_dims, void* data,
98 size_t len, MemoryReleaser memory_releaser,
99 void* memory_releaser_arg) override;
100
101 tensorflow::ImmediateExecutionTensorHandle* CreateLocalHandle(
102 tensorflow::AbstractTensorInterface* t) override;
103 // Create an abstract tensor handle from tensorflow::Tensor.
104 tensorflow::ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor(
105 tensorflow::Tensor& t, const char* d_name) override;
106
107 // Convert a TFRT TensorHandle to tensorflow::TensorHandle.
108 tensorflow::ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
109 tensorflow::ImmediateExecutionTensorHandle* handle) override;
110
111 tensorflow::ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
112 tensorflow::ImmediateExecutionTensorHandle* handle,
113 const char* device_name, tensorflow::Status* status) override;
114
115 tensorflow::ImmediateExecutionOperation* CreateOperation() override;
116 tensorflow::Status RegisterFunction(tensorflow::AbstractFunction*) override;
117
118 tensorflow::CustomDeviceOpHandler& GetCustomDeviceOpHandler() override;
119
120 tensorflow::Status RegisterCustomDevice(
121 const std::string& name,
122 std::unique_ptr<tensorflow::CustomDevice> device) override;
123
124 tensorflow::FunctionLibraryDefinition* FuncLibDef() override;
125
126 void SetReuseRendezvousForFunctions(
127 bool reuse_rendezvous_for_functions) override;
128
129 void ResetGlobalRendezvousForFunction() override;
130
131 bool UsesTFRT() override;
132
133 void ListDevices(std::vector<tensorflow::DeviceAttributes>* devices) override;
134
ListLocalTfDevices()135 std::vector<tensorflow::Device*> ListLocalTfDevices() override {
136 return context_.GetEagerContext()->local_device_mgr()->ListDevices();
137 }
138
139 tensorflow::Status AddDevices(
140 std::vector<std::unique_ptr<tensorflow::Device>> devices) override;
141
142 void ClearCachesAndThreadExecutors() override;
143 void StartStep() override;
144 void EndStep() override;
145
AsyncWait()146 tensorflow::Status AsyncWait() override {
147 GetHostContext()->Quiesce();
148 return tensorflow::Status::OK();
149 }
150
151 tensorflow::Status AddFunctionDef(
152 const tensorflow::FunctionDef& fdef) override;
153 tensorflow::Status AddFunctionDefWithStackTraces(
154 const tensorflow::FunctionDef& fdef,
155 const tensorflow::StackTracesMap& stack_traces) override;
156 std::vector<std::string> ListFunctionNames() override;
157 tensorflow::Status RemoveFunction(const std::string& func) override;
158 const tensorflow::FunctionDef* FindFunctionDef(
159 const std::string& name) const override;
160
161 const tensorflow::DeviceNameUtils::ParsedName& HostCPUParsedName()
162 const override;
163 const std::string& HostCPUName() const override;
164
SetAllowSoftPlacement(bool enable)165 void SetAllowSoftPlacement(bool enable) override {
166 // TODO(tfrt-devs): Move this flag to a common place that can be shared
167 // by current TF and TFRT.
168 GetEagerContext()->SetAllowSoftPlacement(enable);
169 }
SetShouldStoreGraphs(bool value)170 void SetShouldStoreGraphs(bool value) override {
171 GetEagerContext()->SetShouldStoreGraphs(value);
172 }
173
174 tensorflow::Status EnableCollectiveOps(
175 const tensorflow::ServerDef& server_def) override;
176
177 std::unique_ptr<tensorflow::RunMetadata> ExportRunMetadata() override;
178
179 // Find the FunctionDef by the given name and record it in RunMetadata.
180 tensorflow::Status RunMetadataRecordFunction(const std::string& func_name);
181
SetLogDevicePlacement(bool enable)182 void SetLogDevicePlacement(bool enable) override {
183 // TODO(tfrt-devs): Move this flag to a common place that can be shared
184 // by current TF and TFRT.
185 GetEagerContext()->SetLogDevicePlacement(enable);
186 }
187
Executor()188 tensorflow::EagerExecutor& Executor() override {
189 return GetEagerContext()->Executor();
190 }
191 void SetExecutorForThread(tensorflow::EagerExecutor* executor) override;
192
SetThreadLocalDevicePlacementPolicy(tensorflow::ContextDevicePlacementPolicy policy)193 void SetThreadLocalDevicePlacementPolicy(
194 tensorflow::ContextDevicePlacementPolicy policy) override {
195 // TODO(tfrt-devs): Move this flag to a common place that can be shared
196 // by current TF and TFRT.
197 GetEagerContext()->SetThreadLocalDevicePlacementPolicy(policy);
198 }
GetDevicePlacementPolicy()199 tensorflow::ContextDevicePlacementPolicy GetDevicePlacementPolicy()
200 const override {
201 // TODO(tfrt-devs): Move this flag to a common place that can be shared
202 // by current TF and TFRT.
203 return GetEagerContext()->GetDevicePlacementPolicy();
204 }
205
206 CoreRuntime* GetCoreRuntime();
207 tensorflow::Status BuildFunctionRequestContext(
208 tensorflow::tfd::OpKernelRunnerTable* runner_table,
209 RCReference<tfrt::RequestContext>* request_context);
210 tensorflow::Status BuildOpRequestContext(
211 RCReference<tfrt::RequestContext>* request_context);
212 tensorflow::EagerContext* GetEagerContext();
213 const tensorflow::EagerContext* GetEagerContext() const;
214
215 // Selects the op handler to execute the op based on the arguments. This
216 // op handler selection is cheap. But it can be nullptr even it return OK
217 // status.
218 tensorflow::Status SelectOpHandlerFromArguments(
219 const tensorflow::ImmediateExecutionOperation& op,
220 OpHandler** op_handler);
221
222 // Selects the op handler to execute the op based on NodeDef. This op handler
223 // selection is expensive. It will never return nullptr unless there is an
224 // error. Please only invoke this method when the cheap version fails.
225 tensorflow::Status SelectOpHandlerFromNodeDef(
226 const tensorflow::ImmediateExecutionOperation& op,
227 const tensorflow::NodeDef* node_def, OpHandler** op_handler);
228
229 // Returns the chain for current thread.
230 AsyncValueRef<Chain>* GetChain();
231
232 // Indicates sync or async execution.
is_async()233 bool is_async() { return GetEagerContext()->Executor().Async(); }
234
235 // For LLVM style RTTI.
classof(const AbstractContext * op)236 static bool classof(const AbstractContext* op) {
237 return op->getKind() == kTfrt;
238 }
239
GetFunctionCache()240 FunctionCache& GetFunctionCache() { return function_cache_; }
241
GetOpCache()242 OpCache& GetOpCache() { return op_cache_; }
243
244 OpHandler* GetFallbackOpHandler();
245
246 std::vector<std::string> GetLoggedOpsTestonly() override;
247
UseTfrtDistributedRuntime()248 bool UseTfrtDistributedRuntime() { return use_tfrt_distributed_runtime_; }
249
250 #if !defined(IS_MOBILE_PLATFORM)
SetDistributedManager(std::unique_ptr<tensorflow::ImmediateExecutionDistributedManager> distributed)251 void SetDistributedManager(
252 std::unique_ptr<tensorflow::ImmediateExecutionDistributedManager>
253 distributed) override {
254 distributed_manager_ = std::move(distributed);
255 }
256
GetDistributedManager()257 tensorflow::ImmediateExecutionDistributedManager* GetDistributedManager()
258 override {
259 if (use_tfrt_distributed_runtime_) {
260 return distributed_manager_.get();
261 } else {
262 return context_.GetEagerContext()->GetDistributedManager();
263 }
264 }
265 #endif // !IS_MOBILE_PLATFORM
266
267 private:
268 HostContext* GetHostContext();
269 ResourceContext* GetResourceContext();
270
271 Expected<OpHandler*> GetOpHandler(const char* name);
272
273 TfrtContext context_;
274
275 mutable tensorflow::mutex chain_map_mu_;
276 // TODO(chuanhao): Hook it up with C API to allow user to manage it.
277 // Each caller thread will have its own chain to dispatch ops.
278 std::unordered_map<std::thread::id, AsyncValueRef<Chain>> thread_local_chain_
279 TF_GUARDED_BY(chain_map_mu_);
280
281 std::unique_ptr<EagerOpHandlerSelector> op_handler_selector_;
282
283 // The cache that stores functions (composite CoreRuntimeOps).
284 FunctionCache function_cache_;
285
286 // The cache that stores CoreRuntimeOps. It's separate from function cache
287 // since a primitive CoreRuntimeOp is essentially a stateless function
288 // pointer, and so it doesn't need ref-count to manage its lifetime.
289 OpCache op_cache_;
290
291 mutex run_metadata_mu_;
292 std::unique_ptr<tensorflow::RunMetadata> run_metadata_
293 TFRT_GUARDED_BY(run_metadata_mu_);
294
295 // Use TFRT's implementation of distributed manager.
296 bool use_tfrt_distributed_runtime_ = false;
297
298 // A distributed manager that helps setup, update, and check liveness of
299 // member tasks in the cluster.
300 std::unique_ptr<tensorflow::ImmediateExecutionDistributedManager>
301 distributed_manager_;
302 };
303
304 class TensorInterface : public tensorflow::AbstractTensorInterface {
305 public:
TensorInterface(AsyncValueRef<Tensor> t)306 explicit TensorInterface(AsyncValueRef<Tensor> t) : tensor_(std::move(t)) {}
TensorInterface(tensorflow::Tensor t)307 explicit TensorInterface(tensorflow::Tensor t) : tf_tensor_(std::move(t)) {}
~TensorInterface()308 ~TensorInterface() override {}
309
Release()310 void Release() override { delete this; }
311
312 tensorflow::DataType Type() const override;
313 int NumDims() const override;
314 int64_t Dim(int dim_index) const override;
315 int64_t NumElements() const override;
316 size_t ByteSize() const override;
317 void* Data() const override;
318 bool IsAligned() const override;
319 bool CanMove() const override;
IsTfTensor()320 bool IsTfTensor() const { return !tensor_; }
321 std::string SummarizeValue() const override;
322
323 AsyncValueRef<tfrt::Tensor> TensorRef() const;
TfTensor()324 tensorflow::Tensor& TfTensor() { return tf_tensor_; }
325
326 private:
327 AsyncValueRef<tfrt::Tensor> tensor_;
328 // NOTE(b/167608876): tensorflow::Tensor for handling non-scalar string
329 // tensors, for backward compatibility. This is a temporary workaround until
330 // we find a proper way to unify tensorflow::tstring and
331 // tfrt::StringHostTensor.
332 tensorflow::Tensor tf_tensor_;
333 };
334
335 class TensorHandleInterface
336 : public tensorflow::ImmediateExecutionTensorHandle {
337 public:
338 explicit TensorHandleInterface(Value&& v, CoreRuntime* corert);
339
Release()340 void Release() override { Unref(); }
341
342 tensorflow::DataType DataType() const override;
343 tensorflow::Status Shape(
344 tensorflow::PartialTensorShape* shape) const override;
345 tensorflow::Status NumDims(int* num_dims) const override;
346 tensorflow::Status NumElements(int64_t* num_elements) const override;
347 tensorflow::Status Dim(int dim_index, int64_t* dim) const override;
348
349 // DeviceName represents the device that creates the tensor handle.
350 // Currently the same with BackingDeviceName.
351 // TODO(b/169341326): unify device behavior between current TF and TFRT.
352 const char* DeviceName(tensorflow::Status* status) const override;
353
354 // BackingDeviceName represents the device where the tensor is physically
355 // placed. DeviceName and BackingDeviceName are the same for TFRT.
356 const char* BackingDeviceName(tensorflow::Status* status) const override;
357
358 const char* DeviceType(tensorflow::Status* status) const override;
359
DeviceId(tensorflow::Status * status)360 int DeviceId(tensorflow::Status* status) const override {
361 // TODO(tfrt-devs): implement for tfrt tensor handle.
362 llvm_unreachable("unimplemented method.");
363 }
364
365 tensorflow::AbstractTensorInterface* Resolve(
366 tensorflow::Status* status) override;
367
368 // TODO(b/161897666): Figure out if we can get rid of returning a new
369 // pointer here and just use Ref().
Copy()370 tensorflow::ImmediateExecutionTensorHandle* Copy() override {
371 Ref();
372 return this;
373 }
374
Handle()375 TensorHandle Handle() { return value_.get<TensorHandle>().CopyRef(); }
376
value()377 Value* value() { return &value_; }
378
379 // For LLVM style RTTI.
classof(const tensorflow::AbstractTensorHandle * ptr)380 static bool classof(const tensorflow::AbstractTensorHandle* ptr) {
381 return ptr->getKind() == kTfrt;
382 }
383
384 private:
385 llvm::Optional<const TensorMetadata*> Metadata() const;
386
387 CoreRuntime& corert_;
388
389 // Value of tfrt::TensorHandle.
390 Value value_;
391 };
392
393 template <typename T>
TensorHandleFromInterface(T * handle)394 inline TensorHandleInterface* TensorHandleFromInterface(T* handle) {
395 return tensorflow::down_cast<TensorHandleInterface*>(handle);
396 }
397
// TFRT location handler that simply prints the error and aborts the program
// on encountering any error. It's primarily for easy debugging.
// TODO(kkb): Handle errors properly, e.g. by raising a Python exception.
401 class AbortLocationHandler final : public tfrt::LocationHandler {
402 public:
403 tfrt::Location GetCurrentLocation();
404
405 private:
DecodeLocation(tfrt::Location loc)406 tfrt::DecodedLocation DecodeLocation(tfrt::Location loc) const override {
407 // Return a dummy decoded location.
408 return {};
409 }
410 };
411
412 class OpAttrsInterface : public tensorflow::AbstractOpAttrs {
413 public:
OpAttrsInterface(const OpAttrs * attrs,tensorflow::AttrBuilder * fallback_attrs)414 explicit OpAttrsInterface(const OpAttrs* attrs,
415 tensorflow::AttrBuilder* fallback_attrs)
416 : AbstractOpAttrs(
417 tensorflow::AbstractOpAttrs::AbstractOpAttrsKind::kTfrt),
418 attrs_(attrs),
419 fallback_attrs_(fallback_attrs) {}
~OpAttrsInterface()420 ~OpAttrsInterface() override {}
421
422 void GetNameAttrList(tensorflow::NameAttrList* name_and_attrs) const override;
423 tensorflow::Status GetTypeList(
424 absl::string_view attr_name,
425 absl::InlinedVector<tensorflow::DataType, 4>* type_list) const override;
426
427 bool GetInt(absl::string_view attr_name, int64_t* result) const override;
428 bool GetFloat(absl::string_view attr_name, float* result) const override;
429 bool GetBool(absl::string_view attr_name, bool* result) const override;
430 bool GetType(absl::string_view attr_name,
431 tensorflow::DataType* result) const override;
432
GetAttrs()433 const OpAttrs* GetAttrs() const { return attrs_; }
434
GetFallbackAttrs()435 const tensorflow::AttrBuilder* GetFallbackAttrs() const {
436 return fallback_attrs_;
437 }
438
439 private:
440 // TODO(fishx): Move ownership to here.
441 const OpAttrs* attrs_;
442
443 // TODO(tfrt-devs): Remove this field and generate NameAttrList from attrs_.
444 // Today it is fine since we will set both attrs and fallback_attrs.
445 const tensorflow::AttrBuilder* fallback_attrs_;
446 };
447
448 class OperationInterface : public tensorflow::ImmediateExecutionOperation {
449 public:
450 // All arguments come from ContextInterface.
451 explicit OperationInterface(ContextInterface* context);
~OperationInterface()452 ~OperationInterface() override {}
453
Release()454 void Release() override { delete this; }
455
Clear()456 void Clear() override { args_.clear(); }
457
458 tensorflow::Status Reset(const char* op,
459 const char* raw_device_name) override;
Name()460 const std::string& Name() const override { return op_name_; }
DeviceName()461 const std::string& DeviceName() const override { return device_name_; }
462 tensorflow::Status SetDeviceName(const char* name) override;
463
GetContext()464 tensorflow::ImmediateExecutionContext* GetContext() const override {
465 return context_;
466 }
HasCustomDeviceInput()467 bool HasCustomDeviceInput() const override {
468 return custom_device_tensor_handle_count_ > 0;
469 }
470
471 tensorflow::Status AddInput(tensorflow::AbstractTensorHandle* input) override;
472 tensorflow::Status AddInputList(
473 absl::Span<tensorflow::AbstractTensorHandle* const> inputs) override;
474 tensorflow::Status SetInput(
475 size_t index, tensorflow::ImmediateExecutionTensorHandle* input) override;
476 absl::Span<tensorflow::ImmediateExecutionTensorHandle* const> GetInputs()
477 const override;
478 tensorflow::Status Execute(
479 absl::Span<tensorflow::AbstractTensorHandle*> retvals,
480 int* num_retvals) override;
OpDef()481 const tensorflow::OpDef* OpDef() const override { return op_def_; }
NodeDef()482 const tensorflow::NodeDef NodeDef() { return fallback_attrs_.BuildNodeDef(); }
483
484 tensorflow::Status SetAttrString(const char* attr_name, const char* data,
485 size_t length) override;
486 tensorflow::Status SetAttrInt(const char* attr_name, int64_t value) override;
487 tensorflow::Status SetAttrFloat(const char* attr_name, float value) override;
488 tensorflow::Status SetAttrBool(const char* attr_name, bool value) override;
489 tensorflow::Status SetAttrType(const char* attr_name,
490 tensorflow::DataType value) override;
491 tensorflow::Status SetAttrShape(const char* attr_name, const int64_t* dims,
492 const int num_dims) override;
493 tensorflow::Status SetAttrFunction(const char* attr_name,
494 const AbstractOperation* value) override;
495 tensorflow::Status SetAttrFunctionName(const char* attr_name,
496 const char* data,
497 size_t length) override;
498 tensorflow::Status SetAttrTensor(
499 const char* attr_name,
500 tensorflow::AbstractTensorInterface* tensor) override;
501 tensorflow::Status SetAttrStringList(const char* attr_name,
502 const void* const* values,
503 const size_t* lengths,
504 int num_values) override;
505 tensorflow::Status SetAttrFloatList(const char* attr_name,
506 const float* values,
507 int num_values) override;
508 tensorflow::Status SetAttrIntList(const char* attr_name,
509 const int64_t* values,
510 int num_values) override;
511 tensorflow::Status SetAttrTypeList(const char* attr_name,
512 const tensorflow::DataType* values,
513 int num_values) override;
514 tensorflow::Status SetAttrBoolList(const char* attr_name,
515 const unsigned char* values,
516 int num_values) override;
517 tensorflow::Status SetAttrShapeList(const char* attr_name,
518 const int64_t** dims, const int* num_dims,
519 int num_values) override;
520 tensorflow::Status SetAttrFunctionList(
521 const char* attr_name,
522 absl::Span<const AbstractOperation*> values) override;
523
524 tensorflow::Status InputLength(const char* input_name, int* length) override;
525 tensorflow::Status OutputLength(const char* output_name,
526 int* length) override;
527
528 const tensorflow::AbstractOpAttrs* GetOpAttrs() const override;
529 void AddAttrs(const tensorflow::AbstractOpAttrs* op_attrs) override;
530
SetStackTrace(tensorflow::ManagedStackTrace stack_trace)531 void SetStackTrace(tensorflow::ManagedStackTrace stack_trace) override {
532 stack_trace_ = stack_trace;
533 }
534
SetCancellationManager(tensorflow::CancellationManager * cancellation_manager)535 void SetCancellationManager(
536 tensorflow::CancellationManager* cancellation_manager) override {
537 // TODO(b/181368626): Support cancellation.
538 }
539
GetStackTrace()540 absl::optional<tensorflow::ManagedStackTrace> GetStackTrace() override {
541 return stack_trace_;
542 }
543
544 // For LLVM style RTTI.
classof(const AbstractOperation * ptr)545 static bool classof(const AbstractOperation* ptr) {
546 return ptr->getKind() == kTfrt;
547 }
548
549 friend class OpCache;
550
551 private:
552 // Initialize op_ field. It can be either a trivial op or a composite op.
553 tensorflow::Status Initialize();
554
555 // Note(fishx): This method is copied from current TF. We use it to infer
556 // attribute like "T" in order to run device placement logic from current TF.
557 void MaybeInferInputAttrs();
558
559 // This field holds a primitive op. If the op represents a function, it
560 // will be held by function_state_ below, and this field will be empty.
561 CoreRuntimeOp* op_;
562 RCReference<FunctionState> function_state_;
563 std::string op_name_;
564 // The device user requested to place the op on.
565 std::string device_name_;
566 bool is_function_;
567 tfrt::BefAttrEncoder bef_attr_encoder_;
568 // TODO(b/165412867): Remove AttrBuilder.
569 tensorflow::AttrBuilder fallback_attrs_;
570 const tensorflow::OpDef* op_def_; // op definition from protobuf
571 OpAttrs attrs_;
572 OpAttrsInterface op_attrs_;
573 SmallVector<
574 tensorflow::core::RefCountPtr<tensorflow::ImmediateExecutionTensorHandle>,
575 8>
576 args_;
577 AbortLocationHandler abort_location_handler_;
578 ContextInterface* const context_;
579 // TODO(kkb): Use tfrt::Location and implement TFRT async stack tracing.
580 absl::optional<tensorflow::ManagedStackTrace> stack_trace_;
581
582 int custom_device_tensor_handle_count_ = 0;
583 };
584
585 } // namespace tf
586 } // namespace tfrt
587
588 #endif // TENSORFLOW_CORE_TFRT_EAGER_C_API_TFRT_H_
589