• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
17 #define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
18 
#include <memory>
#include <vector>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/types.h"
32 
33 namespace tensorflow {
34 
35 // Represents the results of the execution of an operation.
36 struct OutputList {
37   std::vector<AbstractTensorHandle*> outputs;
38   int expected_num_outputs = -1;
39 };
40 
41 namespace tracing {
42 
43 // =============================================================================
44 // Implementation detail for the unified execution APIs for Eager and tracing
45 // backends (graph/MLIR).
46 //
47 // This defines a set of abstract classes that are intended to provide the
48 // functionality of the opaque C types exposed in the public APIs defined in the
49 // `c_api_unified_experimental.h` header.
50 // =============================================================================
51 
52 // Represents either a MlirTensor or a GraphTensor.
53 // This base class does not expose any public methods other than to distinguish
54 // which subclass it actually is. The user is responsible to use the right
55 // type of AbstractTensor in their context (do not pass an MlirTensor to a
56 // GraphContext and vice-versa).
57 class TracingTensorHandle : public AbstractTensorHandle {
58  protected:
TracingTensorHandle(AbstractTensorHandleKind kind)59   explicit TracingTensorHandle(AbstractTensorHandleKind kind)
60       : AbstractTensorHandle(kind) {}
61 
62  public:
63   // For LLVM style RTTI.
classof(const AbstractTensorHandle * ptr)64   static bool classof(const AbstractTensorHandle* ptr) {
65     return ptr->getKind() == kGraph || ptr->getKind() == kMlir;
66   }
67 };
68 
69 // An abstract operation describes an operation by its type, name, and
70 // attributes. It can be "executed" by the context with some input tensors.
71 // It is allowed to reusing the same abstract operation for multiple execution
72 // on a given context, with the same or different input tensors.
73 class TracingOperation : public AbstractOperation {
74  protected:
TracingOperation(AbstractOperationKind kind)75   explicit TracingOperation(AbstractOperationKind kind)
76       : AbstractOperation(kind) {}
77 
78  public:
79   // Sets the name of the operation: this is an optional identifier that is
80   // not intended to carry semantics and preserved/propagated without
81   // guarantees.
82   virtual Status SetOpName(const char* op_name) = 0;
83 
84   // For LLVM style RTTI.
classof(const AbstractOperation * ptr)85   static bool classof(const AbstractOperation* ptr) {
86     return ptr->getKind() == kGraph || ptr->getKind() == kMlir;
87   }
88 };
89 
90 namespace internal {
91 struct TracingOperationDeleter {
operatorTracingOperationDeleter92   void operator()(TracingOperation* p) const {
93     if (p != nullptr) {
94       p->Release();
95     }
96   }
97 };
98 }  // namespace internal
99 
100 using TracingOperationPtr =
101     std::unique_ptr<TracingOperation, internal::TracingOperationDeleter>;
102 
103 // This holds the context for the execution: dispatching operations either to an
104 // MLIR implementation or to a graph implementation.
105 class TracingContext : public AbstractContext {
106  protected:
TracingContext(AbstractContextKind kind)107   explicit TracingContext(AbstractContextKind kind) : AbstractContext(kind) {}
108 
109  public:
110   // Add a function parameter and return the corresponding tensor.
111   virtual Status AddParameter(DataType dtype, const PartialTensorShape& shape,
112                               TracingTensorHandle**) = 0;
113 
114   // Finalize this context and make a function out of it. The context is in a
115   // invalid state after this call and must be destroyed.
116   virtual Status Finalize(OutputList* outputs, AbstractFunction**) = 0;
117 
118   // For LLVM style RTTI.
classof(const AbstractContext * ptr)119   static bool classof(const AbstractContext* ptr) {
120     return ptr->getKind() == kGraph || ptr->getKind() == kMlir;
121   }
122 };
123 
124 typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
125 Status SetDefaultTracingEngine(const char* name);
126 void RegisterTracingEngineFactory(const ::tensorflow::string& name,
127                                   FactoryFunction factory);
128 }  // namespace tracing
129 
// Each invocation pairs a C++ implementation type (first argument) with a
// TF_-prefixed C API type (second argument).
// NOTE(review): DEFINE_CONVERSION_FUNCTIONS is presumably provided by
// conversion_macros.h and presumably generates the helpers that convert
// between the two types of each pair -- confirm against that header.
DEFINE_CONVERSION_FUNCTIONS(AbstractContext, TF_ExecutionContext)
DEFINE_CONVERSION_FUNCTIONS(AbstractTensorHandle, TF_AbstractTensor)
DEFINE_CONVERSION_FUNCTIONS(AbstractFunction, TF_AbstractFunction)
DEFINE_CONVERSION_FUNCTIONS(AbstractOperation, TF_AbstractOp)
DEFINE_CONVERSION_FUNCTIONS(OutputList, TF_OutputList)
135 }  // namespace tensorflow
136 
137 #endif  // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
138