• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/c/eager/c_api_unified_experimental.h"

#include <cassert>
#include <set>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
30 
31 using tensorflow::string;
32 
33 namespace tensorflow {
34 namespace tracing {
35 typedef absl::flat_hash_map<std::string, tracing::FactoryFunction> FactoriesMap;
36 
GetFactories()37 static FactoriesMap& GetFactories() {
38   static FactoriesMap* factories = new FactoriesMap;
39   return *factories;
40 }
41 
42 static tracing::FactoryFunction default_factory;
43 
RegisterTracingEngineFactory(const string & name,FactoryFunction factory)44 void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
45   assert((!GetFactories().count(name)) ||
46          (GetFactories()[name] == factory) &&
47              "Duplicate tracing factory registration");
48   GetFactories()[name] = factory;
49 }
50 
SetDefaultTracingEngine(const char * name)51 Status SetDefaultTracingEngine(const char* name) {
52   auto entry = GetFactories().find(name);
53   if (entry != GetFactories().end()) {
54     default_factory = GetFactories().find(name)->second;
55     return Status::OK();
56   }
57   string msg = absl::StrCat(
58       "No tracing engine factory has been registered with the key '", name,
59       "' (available: ");
60   // Ensure deterministic (sorted) order in the error message
61   std::set<string> factories_sorted;
62   for (const auto& factory : GetFactories())
63     factories_sorted.insert(factory.first);
64   const char* comma = "";
65   for (const string& factory : factories_sorted) {
66     msg += comma + factory;
67     comma = ", ";
68   }
69   msg += ")";
70 
71   return errors::InvalidArgument(msg.c_str());
72 }
73 
CreateTracingExecutionContext(const char * fn_name,TF_Status * s)74 static TracingContext* CreateTracingExecutionContext(const char* fn_name,
75                                                      TF_Status* s) {
76   if (default_factory) {
77     return default_factory(fn_name, s);
78   }
79   Set_TF_Status_from_Status(
80       s, errors::FailedPrecondition("default_factory is nullptr"));
81   return nullptr;
82 }
83 
84 }  // end namespace tracing
85 }  // end namespace tensorflow
86 
87 // =============================================================================
88 // Public C API entry points
89 //
90 // These are only the generic entry points for the C API. This file does not
91 // have any visibility into the graph/eager implementation and is only providing
92 // C bindings to the abstract classes defined in the
93 // c_api_unified_experimental_internal.h header.
94 //
95 // =============================================================================
96 
97 using tensorflow::AbstractFunction;
98 using tensorflow::AbstractTensorHandle;
99 using tensorflow::DataType;
100 using tensorflow::dyn_cast;
101 using tensorflow::OutputList;
102 using tensorflow::Status;
103 using tensorflow::unwrap;
104 using tensorflow::wrap;
105 using tensorflow::tracing::CreateTracingExecutionContext;
106 using tensorflow::tracing::SetDefaultTracingEngine;
107 using tensorflow::tracing::TracingContext;
108 using tensorflow::tracing::TracingOperation;
109 using tensorflow::tracing::TracingTensorHandle;
110 
TF_SetTracingImplementation(const char * name,TF_Status * s)111 void TF_SetTracingImplementation(const char* name, TF_Status* s) {
112   Set_TF_Status_from_Status(s, SetDefaultTracingEngine(name));
113 }
114 
115 // Creates a new TensorFlow function, it is an execution context attached to a
116 // given tracing context.
TF_CreateFunction(const char * fn_name,TF_Status * s)117 TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* s) {
118   return wrap(CreateTracingExecutionContext(fn_name, s));
119 }
120 
TF_FinalizeFunction(TF_ExecutionContext * ctx,TF_OutputList * outputs,TF_Status * s)121 TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
122                                          TF_OutputList* outputs, TF_Status* s) {
123   AbstractFunction* func;
124   TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(ctx));
125   if (!tracing_ctx) {
126     Set_TF_Status_from_Status(
127         s, tensorflow::errors::InvalidArgument(
128                "Only TracingContext can be converted into a function."));
129     return nullptr;
130   }
131   Set_TF_Status_from_Status(s, tracing_ctx->Finalize(unwrap(outputs), &func));
132   TF_DeleteExecutionContext(ctx);
133   return wrap(func);
134 }
135 
TF_AddFunctionParameter(TF_ExecutionContext * func,TF_DataType dtype,TF_Shape shape,TF_Status * s)136 TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
137                                            TF_DataType dtype, TF_Shape shape,
138                                            TF_Status* s) {
139   DCHECK_GE(shape.num_dims, -1);
140   TracingTensorHandle* t;
141   TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(func));
142   if (!tracing_ctx) {
143     Set_TF_Status_from_Status(
144         s, tensorflow::errors::InvalidArgument(
145                "TF_AddFunctionParameter must be called on a TracingContext."));
146     return nullptr;
147   }
148   tensorflow::PartialTensorShape partial_shape;
149   if (shape.num_dims != -1) {
150     DCHECK(shape.dim_sizes != nullptr);
151     Status status = tensorflow::PartialTensorShape::MakePartialShape(
152         reinterpret_cast<tensorflow::int64*>(shape.dim_sizes), shape.num_dims,
153         &partial_shape);
154     if (!status.ok()) {
155       Set_TF_Status_from_Status(s, status);
156       return nullptr;
157     }
158   }
159   Set_TF_Status_from_Status(
160       s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), partial_shape,
161                                    &t));
162   return wrap(t);
163 }
164 
TF_DeleteExecutionContext(TF_ExecutionContext * c)165 void TF_DeleteExecutionContext(TF_ExecutionContext* c) { unwrap(c)->Release(); }
166 
TF_NewAbstractOp(TF_ExecutionContext * c)167 TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
168   return wrap((unwrap(c)->CreateOperation()));
169 }
170 
TF_DeleteAbstractOp(TF_AbstractOp * op)171 void TF_DeleteAbstractOp(TF_AbstractOp* op) { unwrap(op)->Release(); }
172 
TF_DeleteAbstractTensor(TF_AbstractTensor * t)173 void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Unref(); }
174 
TF_NewOutputList()175 TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); }
TF_DeleteOutputList(TF_OutputList * o)176 void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); }
TF_OutputListSetNumOutputs(TF_OutputList * o,int num_outputs,TF_Status * s)177 void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
178                                 TF_Status* s) {
179   unwrap(o)->expected_num_outputs = num_outputs;
180   unwrap(o)->outputs.clear();
181   unwrap(o)->outputs.resize(num_outputs);
182 }
TF_OutputListNumOutputs(TF_OutputList * o)183 int TF_OutputListNumOutputs(TF_OutputList* o) {
184   return unwrap(o)->outputs.size();
185 }
TF_OutputListGet(TF_OutputList * o,int i)186 TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
187   return wrap(unwrap(o)->outputs[i]);
188 }
TF_OutputListPushBack(TF_OutputList * o,TF_AbstractTensor * tensor,TF_Status * s)189 void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
190                            TF_Status* s) {
191   unwrap(o)->outputs.push_back(unwrap(tensor));
192 }
193 
TF_AbstractOpSetOpType(TF_AbstractOp * op,const char * const op_type,TF_Status * s)194 void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
195                             TF_Status* s) {
196   Set_TF_Status_from_Status(s, unwrap(op)->Reset(op_type,
197                                                  /*raw_device_name=*/nullptr));
198 }
199 
TF_AbstractOpSetOpName(TF_AbstractOp * op,const char * const op_name,TF_Status * s)200 void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
201                             TF_Status* s) {
202   TracingOperation* tracing_op = dyn_cast<TracingOperation>(unwrap(op));
203   if (!tracing_op) {
204     Set_TF_Status_from_Status(
205         s, tensorflow::errors::InvalidArgument(
206                "TF_AbstractOpSetOpName must be called on a TracingOperation."));
207     return;
208   }
209   Set_TF_Status_from_Status(s, tracing_op->SetOpName(op_name));
210 }
211 
TF_AbstractOpSetAttrType(TF_AbstractOp * op,const char * const attr_name,TF_DataType value,TF_Status * s)212 void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
213                               TF_DataType value, TF_Status* s) {
214   Status status =
215       unwrap(op)->SetAttrType(attr_name, static_cast<DataType>(value));
216   TF_SetStatus(s, static_cast<TF_Code>(status.code()),
217                status.error_message().c_str());
218 }
219 
TF_ExecuteOperation(TF_AbstractOp * op,int num_inputs,TF_AbstractTensor * const * inputs,TF_OutputList * o,TF_Status * s)220 void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
221                          TF_AbstractTensor* const* inputs, TF_OutputList* o,
222                          TF_Status* s) {
223   for (int i = 0; i < num_inputs; i++) {
224     Set_TF_Status_from_Status(s, unwrap(op)->AddInput(unwrap(inputs[i])));
225     if (TF_GetCode(s) != TF_OK) {
226       return;
227     }
228   }
229   int num_outputs = unwrap(o)->expected_num_outputs;
230   Set_TF_Status_from_Status(
231       s, unwrap(op)->Execute(
232              absl::MakeSpan(reinterpret_cast<AbstractTensorHandle**>(
233                                 unwrap(o)->outputs.data()),
234                             unwrap(o)->outputs.size()),
235              &num_outputs));
236 }
237 
TF_DeleteAbstractFunction(TF_AbstractFunction * func)238 void TF_DeleteAbstractFunction(TF_AbstractFunction* func) {
239   delete unwrap(func);
240 }
241 
TF_ExecutionContextRegisterFunction(TF_ExecutionContext * ctx,TF_AbstractFunction * func,TF_Status * s)242 void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx,
243                                          TF_AbstractFunction* func,
244                                          TF_Status* s) {
245   Set_TF_Status_from_Status(s, unwrap(ctx)->RegisterFunction(unwrap(func)));
246 }
247