1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/c/eager/c_api_unified_experimental.h"

#include <set>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
30
31 using tensorflow::string;
32
33 namespace tensorflow {
34 namespace tracing {
35 typedef absl::flat_hash_map<std::string, tracing::FactoryFunction> FactoriesMap;
36
GetFactories()37 static FactoriesMap& GetFactories() {
38 static FactoriesMap* factories = new FactoriesMap;
39 return *factories;
40 }
41
42 static tracing::FactoryFunction default_factory;
43
RegisterTracingEngineFactory(const string & name,FactoryFunction factory)44 void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
45 assert((!GetFactories().count(name)) ||
46 (GetFactories()[name] == factory) &&
47 "Duplicate tracing factory registration");
48 GetFactories()[name] = factory;
49 }
50
SetDefaultTracingEngine(const char * name)51 Status SetDefaultTracingEngine(const char* name) {
52 auto entry = GetFactories().find(name);
53 if (entry != GetFactories().end()) {
54 default_factory = GetFactories().find(name)->second;
55 return Status::OK();
56 }
57 string msg = absl::StrCat(
58 "No tracing engine factory has been registered with the key '", name,
59 "' (available: ");
60 // Ensure deterministic (sorted) order in the error message
61 std::set<string> factories_sorted;
62 for (const auto& factory : GetFactories())
63 factories_sorted.insert(factory.first);
64 const char* comma = "";
65 for (const string& factory : factories_sorted) {
66 msg += comma + factory;
67 comma = ", ";
68 }
69 msg += ")";
70
71 return errors::InvalidArgument(msg.c_str());
72 }
73
CreateTracingExecutionContext(const char * fn_name,TF_Status * s)74 static TracingContext* CreateTracingExecutionContext(const char* fn_name,
75 TF_Status* s) {
76 if (default_factory) {
77 return default_factory(fn_name, s);
78 }
79 Set_TF_Status_from_Status(
80 s, errors::FailedPrecondition("default_factory is nullptr"));
81 return nullptr;
82 }
83
84 } // end namespace tracing
85 } // end namespace tensorflow
86
87 // =============================================================================
88 // Public C API entry points
89 //
90 // These are only the generic entry points for the C API. This file does not
91 // have any visibility into the graph/eager implementation and is only providing
92 // C bindings to the abstract classes defined in the
93 // c_api_unified_experimental_internal.h header.
94 //
95 // =============================================================================
96
97 using tensorflow::AbstractFunction;
98 using tensorflow::AbstractTensorHandle;
99 using tensorflow::DataType;
100 using tensorflow::dyn_cast;
101 using tensorflow::OutputList;
102 using tensorflow::Status;
103 using tensorflow::unwrap;
104 using tensorflow::wrap;
105 using tensorflow::tracing::CreateTracingExecutionContext;
106 using tensorflow::tracing::SetDefaultTracingEngine;
107 using tensorflow::tracing::TracingContext;
108 using tensorflow::tracing::TracingOperation;
109 using tensorflow::tracing::TracingTensorHandle;
110
TF_SetTracingImplementation(const char * name,TF_Status * s)111 void TF_SetTracingImplementation(const char* name, TF_Status* s) {
112 Set_TF_Status_from_Status(s, SetDefaultTracingEngine(name));
113 }
114
115 // Creates a new TensorFlow function, it is an execution context attached to a
116 // given tracing context.
TF_CreateFunction(const char * fn_name,TF_Status * s)117 TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* s) {
118 return wrap(CreateTracingExecutionContext(fn_name, s));
119 }
120
TF_FinalizeFunction(TF_ExecutionContext * ctx,TF_OutputList * outputs,TF_Status * s)121 TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
122 TF_OutputList* outputs, TF_Status* s) {
123 AbstractFunction* func;
124 TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(ctx));
125 if (!tracing_ctx) {
126 Set_TF_Status_from_Status(
127 s, tensorflow::errors::InvalidArgument(
128 "Only TracingContext can be converted into a function."));
129 return nullptr;
130 }
131 Set_TF_Status_from_Status(s, tracing_ctx->Finalize(unwrap(outputs), &func));
132 TF_DeleteExecutionContext(ctx);
133 return wrap(func);
134 }
135
TF_AddFunctionParameter(TF_ExecutionContext * func,TF_DataType dtype,TF_Shape shape,TF_Status * s)136 TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
137 TF_DataType dtype, TF_Shape shape,
138 TF_Status* s) {
139 DCHECK_GE(shape.num_dims, -1);
140 TracingTensorHandle* t;
141 TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(func));
142 if (!tracing_ctx) {
143 Set_TF_Status_from_Status(
144 s, tensorflow::errors::InvalidArgument(
145 "TF_AddFunctionParameter must be called on a TracingContext."));
146 return nullptr;
147 }
148 tensorflow::PartialTensorShape partial_shape;
149 if (shape.num_dims != -1) {
150 DCHECK(shape.dim_sizes != nullptr);
151 Status status = tensorflow::PartialTensorShape::MakePartialShape(
152 reinterpret_cast<tensorflow::int64*>(shape.dim_sizes), shape.num_dims,
153 &partial_shape);
154 if (!status.ok()) {
155 Set_TF_Status_from_Status(s, status);
156 return nullptr;
157 }
158 }
159 Set_TF_Status_from_Status(
160 s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), partial_shape,
161 &t));
162 return wrap(t);
163 }
164
TF_DeleteExecutionContext(TF_ExecutionContext * c)165 void TF_DeleteExecutionContext(TF_ExecutionContext* c) { unwrap(c)->Release(); }
166
TF_NewAbstractOp(TF_ExecutionContext * c)167 TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
168 return wrap((unwrap(c)->CreateOperation()));
169 }
170
TF_DeleteAbstractOp(TF_AbstractOp * op)171 void TF_DeleteAbstractOp(TF_AbstractOp* op) { unwrap(op)->Release(); }
172
TF_DeleteAbstractTensor(TF_AbstractTensor * t)173 void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Unref(); }
174
TF_NewOutputList()175 TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); }
TF_DeleteOutputList(TF_OutputList * o)176 void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); }
TF_OutputListSetNumOutputs(TF_OutputList * o,int num_outputs,TF_Status * s)177 void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
178 TF_Status* s) {
179 unwrap(o)->expected_num_outputs = num_outputs;
180 unwrap(o)->outputs.clear();
181 unwrap(o)->outputs.resize(num_outputs);
182 }
TF_OutputListNumOutputs(TF_OutputList * o)183 int TF_OutputListNumOutputs(TF_OutputList* o) {
184 return unwrap(o)->outputs.size();
185 }
TF_OutputListGet(TF_OutputList * o,int i)186 TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
187 return wrap(unwrap(o)->outputs[i]);
188 }
TF_OutputListPushBack(TF_OutputList * o,TF_AbstractTensor * tensor,TF_Status * s)189 void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
190 TF_Status* s) {
191 unwrap(o)->outputs.push_back(unwrap(tensor));
192 }
193
TF_AbstractOpSetOpType(TF_AbstractOp * op,const char * const op_type,TF_Status * s)194 void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
195 TF_Status* s) {
196 Set_TF_Status_from_Status(s, unwrap(op)->Reset(op_type,
197 /*raw_device_name=*/nullptr));
198 }
199
TF_AbstractOpSetOpName(TF_AbstractOp * op,const char * const op_name,TF_Status * s)200 void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
201 TF_Status* s) {
202 TracingOperation* tracing_op = dyn_cast<TracingOperation>(unwrap(op));
203 if (!tracing_op) {
204 Set_TF_Status_from_Status(
205 s, tensorflow::errors::InvalidArgument(
206 "TF_AbstractOpSetOpName must be called on a TracingOperation."));
207 return;
208 }
209 Set_TF_Status_from_Status(s, tracing_op->SetOpName(op_name));
210 }
211
TF_AbstractOpSetAttrType(TF_AbstractOp * op,const char * const attr_name,TF_DataType value,TF_Status * s)212 void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
213 TF_DataType value, TF_Status* s) {
214 Status status =
215 unwrap(op)->SetAttrType(attr_name, static_cast<DataType>(value));
216 TF_SetStatus(s, static_cast<TF_Code>(status.code()),
217 status.error_message().c_str());
218 }
219
TF_ExecuteOperation(TF_AbstractOp * op,int num_inputs,TF_AbstractTensor * const * inputs,TF_OutputList * o,TF_Status * s)220 void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
221 TF_AbstractTensor* const* inputs, TF_OutputList* o,
222 TF_Status* s) {
223 for (int i = 0; i < num_inputs; i++) {
224 Set_TF_Status_from_Status(s, unwrap(op)->AddInput(unwrap(inputs[i])));
225 if (TF_GetCode(s) != TF_OK) {
226 return;
227 }
228 }
229 int num_outputs = unwrap(o)->expected_num_outputs;
230 Set_TF_Status_from_Status(
231 s, unwrap(op)->Execute(
232 absl::MakeSpan(reinterpret_cast<AbstractTensorHandle**>(
233 unwrap(o)->outputs.data()),
234 unwrap(o)->outputs.size()),
235 &num_outputs));
236 }
237
TF_DeleteAbstractFunction(TF_AbstractFunction * func)238 void TF_DeleteAbstractFunction(TF_AbstractFunction* func) {
239 delete unwrap(func);
240 }
241
TF_ExecutionContextRegisterFunction(TF_ExecutionContext * ctx,TF_AbstractFunction * func,TF_Status * s)242 void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx,
243 TF_AbstractFunction* func,
244 TF_Status* s) {
245 Set_TF_Status_from_Status(s, unwrap(ctx)->RegisterFunction(unwrap(func)));
246 }
247