1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_ 17 #define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_ 18 19 #include <vector> 20 21 #include "tensorflow/c/c_api.h" 22 #include "tensorflow/c/conversion_macros.h" 23 #include "tensorflow/c/eager/abstract_context.h" 24 #include "tensorflow/c/eager/abstract_operation.h" 25 #include "tensorflow/c/eager/abstract_tensor_handle.h" 26 #include "tensorflow/c/eager/c_api_unified_experimental.h" 27 #include "tensorflow/c/tf_datatype.h" 28 #include "tensorflow/c/tf_status.h" 29 #include "tensorflow/core/framework/tensor_shape.h" 30 #include "tensorflow/core/platform/casts.h" 31 #include "tensorflow/core/platform/types.h" 32 33 namespace tensorflow { 34 35 // Represents the results of the execution of an operation. 36 struct OutputList { 37 std::vector<AbstractTensorHandle*> outputs; 38 int expected_num_outputs = -1; 39 }; 40 41 namespace tracing { 42 43 // ============================================================================= 44 // Implementation detail for the unified execution APIs for Eager and tracing 45 // backends (graph/MLIR). 46 // 47 // This defines a set of abstract classes that are intended to provide the 48 // functionality of the opaque C types exposed in the public APIs defined in the 49 // `c_api_unified_experimental.h` header. 50 // ============================================================================= 51 52 // Represents either a MlirTensor or a GraphTensor. 53 // This base class does not expose any public methods other than to distinguish 54 // which subclass it actually is. The user is responsible to use the right 55 // type of AbstractTensor in their context (do not pass an MlirTensor to a 56 // GraphContext and vice-versa). 57 class TracingTensorHandle : public AbstractTensorHandle { 58 protected: TracingTensorHandle(AbstractTensorHandleKind kind)59 explicit TracingTensorHandle(AbstractTensorHandleKind kind) 60 : AbstractTensorHandle(kind) {} 61 62 public: 63 // For LLVM style RTTI. classof(const AbstractTensorHandle * ptr)64 static bool classof(const AbstractTensorHandle* ptr) { 65 return ptr->getKind() == kGraph || ptr->getKind() == kMlir; 66 } 67 }; 68 69 // An abstract operation describes an operation by its type, name, and 70 // attributes. It can be "executed" by the context with some input tensors. 71 // It is allowed to reusing the same abstract operation for multiple execution 72 // on a given context, with the same or different input tensors. 73 class TracingOperation : public AbstractOperation { 74 protected: TracingOperation(AbstractOperationKind kind)75 explicit TracingOperation(AbstractOperationKind kind) 76 : AbstractOperation(kind) {} 77 78 public: 79 // Sets the name of the operation: this is an optional identifier that is 80 // not intended to carry semantics and preserved/propagated without 81 // guarantees. 82 virtual Status SetOpName(const char* op_name) = 0; 83 84 // For LLVM style RTTI. classof(const AbstractOperation * ptr)85 static bool classof(const AbstractOperation* ptr) { 86 return ptr->getKind() == kGraph || ptr->getKind() == kMlir; 87 } 88 }; 89 90 namespace internal { 91 struct TracingOperationDeleter { operatorTracingOperationDeleter92 void operator()(TracingOperation* p) const { 93 if (p != nullptr) { 94 p->Release(); 95 } 96 } 97 }; 98 } // namespace internal 99 100 using TracingOperationPtr = 101 std::unique_ptr<TracingOperation, internal::TracingOperationDeleter>; 102 103 // This holds the context for the execution: dispatching operations either to an 104 // MLIR implementation or to a graph implementation. 105 class TracingContext : public AbstractContext { 106 protected: TracingContext(AbstractContextKind kind)107 explicit TracingContext(AbstractContextKind kind) : AbstractContext(kind) {} 108 109 public: 110 // Add a function parameter and return the corresponding tensor. 111 virtual Status AddParameter(DataType dtype, const PartialTensorShape& shape, 112 TracingTensorHandle**) = 0; 113 114 // Finalize this context and make a function out of it. The context is in a 115 // invalid state after this call and must be destroyed. 116 virtual Status Finalize(OutputList* outputs, AbstractFunction**) = 0; 117 118 // For LLVM style RTTI. classof(const AbstractContext * ptr)119 static bool classof(const AbstractContext* ptr) { 120 return ptr->getKind() == kGraph || ptr->getKind() == kMlir; 121 } 122 }; 123 124 typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*); 125 Status SetDefaultTracingEngine(const char* name); 126 void RegisterTracingEngineFactory(const ::tensorflow::string& name, 127 FactoryFunction factory); 128 } // namespace tracing 129 130 DEFINE_CONVERSION_FUNCTIONS(AbstractContext, TF_ExecutionContext) 131 DEFINE_CONVERSION_FUNCTIONS(AbstractTensorHandle, TF_AbstractTensor) 132 DEFINE_CONVERSION_FUNCTIONS(AbstractFunction, TF_AbstractFunction) 133 DEFINE_CONVERSION_FUNCTIONS(AbstractOperation, TF_AbstractOp) 134 DEFINE_CONVERSION_FUNCTIONS(OutputList, TF_OutputList) 135 } // namespace tensorflow 136 137 #endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_ 138