• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/c_api_experimental.h"
17 
18 #include "absl/strings/substitute.h"
19 #include "tensorflow/c/c_api.h"
20 #include "tensorflow/c/c_api_internal.h"
21 #include "tensorflow/c/checkpoint_reader.h"
22 #include "tensorflow/c/eager/c_api.h"
23 #include "tensorflow/c/eager/c_api_internal.h"
24 #include "tensorflow/c/eager/tfe_context_internal.h"
25 #include "tensorflow/c/eager/tfe_op_internal.h"
26 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
27 #include "tensorflow/compiler/jit/flags.h"
28 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
29 #include "tensorflow/core/common_runtime/eager/context.h"
30 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
31 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
32 #include "tensorflow/core/framework/collective.h"
33 #include "tensorflow/core/framework/node_def.pb.h"
34 #include "tensorflow/core/framework/shape_inference.h"
35 #include "tensorflow/core/framework/tensor.pb.h"
36 #include "tensorflow/core/graph/graph.h"
37 #include "tensorflow/core/graph/node_builder.h"
38 #include "tensorflow/core/platform/blocking_counter.h"
39 #include "tensorflow/core/platform/casts.h"
40 #include "tensorflow/core/platform/env.h"
41 #include "tensorflow/core/platform/init_main.h"
42 #include "tensorflow/core/platform/mutex.h"
43 #include "tensorflow/core/platform/net.h"
44 #include "tensorflow/core/platform/platform.h"
45 #include "tensorflow/core/platform/strcat.h"
46 #include "tensorflow/core/protobuf/config.pb.h"
47 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
48 
49 using tensorflow::FunctionDef;
50 using tensorflow::Node;
51 using tensorflow::NodeBuilder;
52 using tensorflow::Status;
53 using tensorflow::errors::InvalidArgument;
54 
55 namespace {
56 typedef std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>
57     UniqueFuncPtr;
58 }
59 
60 // struct TF_Operation { tensorflow::Node node; };
ToTF_Operation(Node * node)61 static TF_Operation* ToTF_Operation(Node* node) {
62   return static_cast<TF_Operation*>(static_cast<void*>(node));
63 }
64 
TF_EnableXLACompilation(TF_SessionOptions * options,unsigned char enable)65 void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) {
66   tensorflow::ConfigProto& config = options->options.config;
67   auto* optimizer_options =
68       config.mutable_graph_options()->mutable_optimizer_options();
69   if (enable) {
70     optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);
71 
72     // These XLA flags are needed to trigger XLA properly from C (more generally
73     // non-Python) clients. If this API is called again with `enable` set to
74     // false, it is safe to keep these flag values as is.
75     tensorflow::MarkForCompilationPassFlags* flags =
76         tensorflow::GetMarkForCompilationPassFlags();
77     flags->tf_xla_cpu_global_jit = true;
78     flags->tf_xla_min_cluster_size = 1;
79   } else {
80     optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
81   }
82 }
83 
TF_SetXlaEnableLazyCompilation(unsigned char enable)84 unsigned char TF_SetXlaEnableLazyCompilation(unsigned char enable) {
85   tensorflow::BuildXlaOpsPassFlags* flags =
86       tensorflow::GetBuildXlaOpsPassFlags();
87   bool original = flags->tf_xla_enable_lazy_compilation;
88   flags->tf_xla_enable_lazy_compilation = enable;
89   return original;
90 }
91 
TF_SetTfXlaCpuGlobalJit(unsigned char enable)92 unsigned char TF_SetTfXlaCpuGlobalJit(unsigned char enable) {
93   tensorflow::MarkForCompilationPassFlags* flags =
94       tensorflow::GetMarkForCompilationPassFlags();
95   bool original = flags->tf_xla_cpu_global_jit;
96   flags->tf_xla_cpu_global_jit = static_cast<bool>(enable);
97   return static_cast<unsigned char>(original);
98 }
99 
TF_SetXlaAutoJitMode(const char * mode)100 void TF_SetXlaAutoJitMode(const char* mode) {
101   tensorflow::SetXlaAutoJitFlagFromFlagString(mode);
102 }
103 
TF_GetXlaConstantFoldingDisabled()104 unsigned char TF_GetXlaConstantFoldingDisabled() {
105   return static_cast<unsigned char>(
106       tensorflow::GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding);
107 }
108 
TF_SetXlaConstantFoldingDisabled(unsigned char should_enable)109 void TF_SetXlaConstantFoldingDisabled(unsigned char should_enable) {
110   tensorflow::GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding =
111       static_cast<bool>(should_enable);
112 }
113 
TF_SetXlaMinClusterSize(int size)114 void TF_SetXlaMinClusterSize(int size) {
115   tensorflow::MarkForCompilationPassFlags* flags =
116       tensorflow::GetMarkForCompilationPassFlags();
117   flags->tf_xla_min_cluster_size = size;
118 }
119 
TF_CreateConfig(unsigned char enable_xla_compilation,unsigned char gpu_memory_allow_growth,unsigned int num_cpu_devices)120 TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
121                            unsigned char gpu_memory_allow_growth,
122                            unsigned int num_cpu_devices) {
123   tensorflow::ConfigProto config;
124   auto* optimizer_options =
125       config.mutable_graph_options()->mutable_optimizer_options();
126   if (enable_xla_compilation) {
127     optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);
128 
129     // These XLA flags are needed to trigger XLA properly from C (more generally
130     // non-Python) clients. If this API is called again with `enable` set to
131     // false, it is safe to keep these flag values as is.
132     tensorflow::MarkForCompilationPassFlags* flags =
133         tensorflow::GetMarkForCompilationPassFlags();
134     flags->tf_xla_cpu_global_jit = true;
135     flags->tf_xla_min_cluster_size = 1;
136   } else {
137     optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
138   }
139 
140   auto* gpu_options = config.mutable_gpu_options();
141   gpu_options->set_allow_growth(gpu_memory_allow_growth);
142 
143   (*config.mutable_device_count())["CPU"] = num_cpu_devices;
144 
145   // TODO(b/113217601): This is needed for EagerContext::runner_ to use a
146   // threadpool, so that we avoid the possibility of running the runner_ in the
147   // threadpool of GPU event mgr, as that can trigger more callbacks to be
148   // scheduled on that same threadpool, causing a deadlock in cases where the
149   // caller of event_mgr->ThenExecute() blocks on the completion of the callback
150   // (as in the case of ConstOp kernel creation on GPU, which involves copying a
151   // CPU tensor to GPU).
152   // Setting a larger thread pool does not help with the Swift caller, as we use
153   // a different TFE context for each thread of execution (for running graph
154   // functions, and their send/recvs corountines).
155   config.set_inter_op_parallelism_threads(1);
156 
157   TF_Buffer* ret = TF_NewBuffer();
158   TF_CHECK_OK(MessageToBuffer(config, ret));
159   return ret;
160 }
161 
TF_CreateRunOptions(unsigned char enable_full_trace)162 TF_Buffer* TF_CreateRunOptions(unsigned char enable_full_trace) {
163   tensorflow::RunOptions options;
164   if (enable_full_trace) {
165     options.set_trace_level(tensorflow::RunOptions::FULL_TRACE);
166   } else {
167     options.set_trace_level(tensorflow::RunOptions::NO_TRACE);
168   }
169   TF_Buffer* ret = TF_NewBuffer();
170   TF_CHECK_OK(MessageToBuffer(options, ret));
171   return ret;
172 }
173 
TF_GraphDebugString(TF_Graph * graph,size_t * len)174 const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) {
175   tensorflow::mutex_lock c(graph->mu);
176   const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString();
177   *len = debug_str.size();
178   char* ret = static_cast<char*>(malloc(*len + 1));
179   memcpy(ret, debug_str.c_str(), *len + 1);
180   return ret;
181 }
182 
TF_FunctionDebugString(TF_Function * func,size_t * len)183 char* TF_FunctionDebugString(TF_Function* func, size_t* len) {
184   const auto& debug_str = DebugString(func->fdef);
185   *len = debug_str.size();
186   char* ret = static_cast<char*>(malloc(*len + 1));
187   memcpy(ret, debug_str.c_str(), *len + 1);
188   return ret;
189 }
190 
191 // On success, returns a set of TF_Function instances from `text_proto` of
192 // GraphDef type. These functions must be deleted by calling TF_DeleteFunction.
193 //
194 // If `mutate_proto_func` is non-NULL, run it over each FunctionDef proto,
195 // before creating a TF_Function out of the possibly mutated proto.
CreateFunctionsFromTextProto(const char * text_proto,std::function<void (FunctionDef *)> * mutate_proto_func,TF_Status * status)196 static std::vector<UniqueFuncPtr> CreateFunctionsFromTextProto(
197     const char* text_proto,
198     std::function<void(FunctionDef*)>* mutate_proto_func, TF_Status* status) {
199   tensorflow::GraphDef gdef;
200   if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto, &gdef)) {
201     status->status = tensorflow::errors::Internal(
202         "Invalid text proto for GraphDef: ", text_proto);
203     return {};
204   }
205   const auto& fdef_lib = gdef.library();
206   if (fdef_lib.gradient_size() > 0) {
207     status->status = tensorflow::errors::Internal(
208         "GradientDef is not supported in reading Dataset related functions: ",
209         text_proto);
210     return {};
211   }
212   std::vector<UniqueFuncPtr> ret;
213   for (const FunctionDef& fdef : fdef_lib.function()) {
214     // Make a copy so that we can mutate it.
215     FunctionDef fdef_to_load = fdef;
216     if (mutate_proto_func) {
217       (*mutate_proto_func)(&fdef_to_load);
218     }
219     VLOG(1) << "Adding func to graph: " << fdef_to_load.DebugString();
220     std::vector<char> binary_proto_buf(fdef_to_load.ByteSizeLong());
221     fdef_to_load.SerializeToArray(binary_proto_buf.data(),
222                                   binary_proto_buf.size());
223     TF_Function* func = TF_FunctionImportFunctionDef(
224         binary_proto_buf.data(), binary_proto_buf.size(), status);
225     if (!status->status.ok()) return {};
226     ret.push_back(UniqueFuncPtr(func, TF_DeleteFunction));
227   }
228   return ret;
229 }
230 
TF_DequeueNamedTensor(TF_Session * session,int tensor_id,TF_Status * status)231 TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id,
232                                  TF_Status* status) {
233   assert(session);
234   {
235     tensorflow::mutex_lock c(session->graph->mu);
236     VLOG(1) << "Dequeuing named tensor with id " << tensor_id
237             << ", with input graph: "
238             << session->graph->graph.ToGraphDefDebug().DebugString();
239   }
240 
241   TF_Operation* dequeue_op = TF_GraphOperationByName(
242       session->graph,
243       tensorflow::strings::StrCat("fifo_queue_dequeue_", tensor_id).c_str());
244   if (dequeue_op == nullptr) {
245     status->status = tensorflow::errors::Internal(
246         "Unable to find the dequeue node in the TF graph.");
247     return nullptr;
248   }
249 
250   VLOG(1) << "Running the dequeue op";
251   TF_Output output{dequeue_op, 0};
252   TF_Tensor* ret;
253   TF_SessionRun(session, /*run_options*/ nullptr,
254                 // input related parameters
255                 /*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0,
256                 // output related parameters
257                 /*outputs*/ &output, /*output_values*/ &ret,
258                 /*noutputs*/ 1,
259                 /*targets*/ nullptr, /*ntargets*/ 0,
260                 /*run_metadata*/ nullptr, status);
261   if (VLOG_IS_ON(1) && status->status.ok()) {
262     tensorflow::Tensor tensor;
263     if (tensorflow::TF_TensorToTensor(ret, &tensor).ok()) {
264       VLOG(1) << "Dequeued tensor content: " << tensor.DebugString();
265     }
266   }
267   return ret;
268 }
269 
TF_EnqueueNamedTensor(TF_Session * session,int tensor_id,TF_Tensor * tensor,TF_Status * status)270 void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
271                            TF_Tensor* tensor, TF_Status* status) {
272   assert(session);
273   {
274     tensorflow::mutex_lock c(session->graph->mu);
275     if (VLOG_IS_ON(1)) {
276       VLOG(1) << "Enqueuing named tensor with id " << tensor_id
277               << ", with input graph: "
278               << session->graph->graph.ToGraphDefDebug().DebugString();
279       tensorflow::Tensor internal_tensor;
280       if (tensorflow::TF_TensorToTensor(tensor, &internal_tensor).ok()) {
281         VLOG(1) << "Enqueu'ing tensor content: "
282                 << internal_tensor.DebugString();
283       }
284     }
285   }
286 
287   TF_Operation* enqueue_op = TF_GraphOperationByName(
288       session->graph,
289       tensorflow::strings::StrCat("fifo_queue_enqueue_", tensor_id).c_str());
290   if (enqueue_op == nullptr) {
291     status->status = tensorflow::errors::Internal(
292         "Unable to find the enqueue node in the TF graph.");
293     return;
294   }
295 
296   TF_Operation* placeholder_op = TF_GraphOperationByName(
297       session->graph,
298       tensorflow::strings::StrCat("arg_tensor_enqueue_", tensor_id).c_str());
299   if (placeholder_op == nullptr) {
300     status->status = tensorflow::errors::Internal(
301         "Unable to find the placeholder node as input to enqueue in the TF "
302         "graph.");
303     return;
304   }
305 
306   VLOG(1) << "Running the enqueue op";
307   TF_Output input{placeholder_op, 0};
308   TF_SessionRun(session, /*run_options*/ nullptr,
309                 // input related parameters
310                 /*inputs*/ &input, /*input_values*/ &tensor, /*ninputs*/ 1,
311                 // output related parameters
312                 /*outputs*/ nullptr, /*output_values*/ nullptr, /*noutputs*/ 0,
313                 /*targets*/ &enqueue_op, /*ntargets*/ 1,
314                 /*run_metadata*/ nullptr, status);
315   VLOG(1) << "Enqueuing is done.";
316 }
317 
TFE_GetServerDef(const char * text_proto,TF_Status * status)318 TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) {
319   tensorflow::ServerDef server_def;
320   if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto,
321                                                          &server_def)) {
322     status->status = tensorflow::errors::Internal(
323         "Invalid text proto for ServerDef: ", text_proto);
324     return nullptr;
325   }
326   status->status = tensorflow::Status();
327   TF_Buffer* ret = TF_NewBuffer();
328   TF_CHECK_OK(MessageToBuffer(server_def, ret));
329   return ret;
330 }
331 
TF_MakeInternalErrorStatus(TF_Status * status,const char * errMsg)332 void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
333   status->status = tensorflow::errors::Internal(errMsg);
334 }
335 
336 struct TF_CheckpointReader : public tensorflow::checkpoint::CheckpointReader {
337   using tensorflow::checkpoint::CheckpointReader::CheckpointReader;
338   std::vector<std::string> variable_list;
339 };
340 
TF_NewCheckpointReader(const char * filename,TF_Status * status)341 TF_CheckpointReader* TF_NewCheckpointReader(const char* filename,
342                                             TF_Status* status) {
343   TF_CheckpointReader* reader = new TF_CheckpointReader(filename, status);
344   if (!status->status.ok()) {
345     TF_DeleteCheckpointReader(reader);
346     return nullptr;
347   }
348   const auto& m = reader->GetVariableToDataTypeMap();
349   for (auto it = m.begin(); it != m.end(); ++it)
350     reader->variable_list.push_back(it->first);
351   std::sort(reader->variable_list.begin(), reader->variable_list.end());
352   return reader;
353 }
354 
TF_DeleteCheckpointReader(TF_CheckpointReader * reader)355 void TF_DeleteCheckpointReader(TF_CheckpointReader* reader) { delete reader; }
356 
TF_CheckpointReaderHasTensor(TF_CheckpointReader * reader,const char * name)357 int TF_CheckpointReaderHasTensor(TF_CheckpointReader* reader,
358                                  const char* name) {
359   return reader->HasTensor(name);
360 }
361 
TF_CheckpointReaderGetVariable(TF_CheckpointReader * reader,int index)362 const char* TF_CheckpointReaderGetVariable(TF_CheckpointReader* reader,
363                                            int index) {
364   return reader->variable_list[index].c_str();
365 }
366 
TF_CheckpointReaderSize(TF_CheckpointReader * reader)367 int TF_CheckpointReaderSize(TF_CheckpointReader* reader) {
368   return reader->variable_list.size();
369 }
370 
TF_CheckpointReaderGetVariableDataType(TF_CheckpointReader * reader,const char * name)371 TF_DataType TF_CheckpointReaderGetVariableDataType(TF_CheckpointReader* reader,
372                                                    const char* name) {
373   const auto& m = reader->GetVariableToDataTypeMap();
374   return static_cast<TF_DataType>(m.at(name));
375 }
376 
TF_CheckpointReaderGetTensor(TF_CheckpointReader * reader,const char * name,TF_Status * status)377 TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
378                                         const char* name, TF_Status* status) {
379   std::unique_ptr<tensorflow::Tensor> tensor;
380   reader->GetTensor(name, &tensor, status);
381   if (!status->status.ok()) return nullptr;
382   return tensorflow::TF_TensorFromTensor(*tensor, &status->status);
383 }
384 
TF_CheckpointReaderGetVariableShape(TF_CheckpointReader * reader,const char * name,int64_t * dims,int num_dims,TF_Status * status)385 void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
386                                          const char* name, int64_t* dims,
387                                          int num_dims, TF_Status* status) {
388   const auto& shape = reader->GetVariableToShapeMap().at(name);
389   int rank = shape.dims();
390   if (num_dims != rank) {
391     status->status = InvalidArgument("Expected rank is ", num_dims,
392                                      " but actual rank is ", rank);
393     return;
394   }
395   for (int i = 0; i < num_dims; i++) {
396     dims[i] = shape.dim_size(i);
397   }
398 }
399 
TF_CheckpointReaderGetVariableNumDims(TF_CheckpointReader * reader,const char * name)400 int TF_CheckpointReaderGetVariableNumDims(TF_CheckpointReader* reader,
401                                           const char* name) {
402   const auto& m = reader->GetVariableToShapeMap();
403   return m.at(name).dims();
404 }
405 
406 // This builder is used in the eager API to build a NodeDef.
407 struct TF_AttrBuilder : public tensorflow::AttrBuilder {
408   using tensorflow::AttrBuilder::AttrBuilder;
409   // The string buffers to make sure that any `attr_name` we pass into
410   // `builder->Set()` will outlive the subsequent
411   // `TF_AttrBuilderCheckCanRunOnDevice()` call(s) on the same `builder`.
412   std::set<std::string> attr_names;
413 };
414 
TF_NewAttrBuilder(const char * op_name)415 TF_AttrBuilder* TF_NewAttrBuilder(const char* op_name) {
416   return new TF_AttrBuilder(op_name);
417 }
418 
TF_DeleteAttrBuilder(TF_AttrBuilder * builder)419 void TF_DeleteAttrBuilder(TF_AttrBuilder* builder) { delete builder; }
420 
TF_AttrBuilderSetType(TF_AttrBuilder * builder,const char * attr_name,TF_DataType value)421 void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name,
422                            TF_DataType value) {
423   auto iter = builder->attr_names.insert(attr_name).first;
424   builder->Set(*iter, static_cast<tensorflow::DataType>(value));
425 }
426 
TF_AttrBuilderSetTypeList(TF_AttrBuilder * builder,const char * attr_name,const TF_DataType * values,int num_values)427 void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name,
428                                const TF_DataType* values, int num_values) {
429   auto iter = builder->attr_names.insert(attr_name).first;
430   builder->Set(*iter, tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
431                           reinterpret_cast<const tensorflow::DataType*>(values),
432                           num_values));
433 }
434 
TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder * builder,const char * device_type,TF_Status * status)435 void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder,
436                                        const char* device_type,
437                                        TF_Status* status) {
438   status->status = tensorflow::FindKernelDef(
439       tensorflow::DeviceType(device_type), builder->BuildNodeDef(),
440       /* def = */ nullptr, /* kernel_class_name = */ nullptr);
441 }
442 
TF_GetNumberAttrForOpListInput(const char * op_name,int input_index,TF_Status * status)443 const char* TF_GetNumberAttrForOpListInput(const char* op_name, int input_index,
444                                            TF_Status* status) {
445   const tensorflow::OpDef* op_def = nullptr;
446   status->status =
447       tensorflow::OpRegistry::Global()->LookUpOpDef(op_name, &op_def);
448   if (!status->status.ok()) return nullptr;
449 
450   if (input_index >= op_def->input_arg_size() || input_index < 0) {
451     status->status = tensorflow::errors::InvalidArgument(
452         input_index, " out of range for ", op_name);
453     return nullptr;
454   }
455 
456   const tensorflow::OpDef_ArgDef& input_arg = op_def->input_arg()[input_index];
457 
458   if (input_arg.number_attr().empty()) {
459     status->status = tensorflow::errors::NotFound(
460         op_name, " does not have number_attr() defined.");
461     return nullptr;
462   }
463 
464   // The returned string is owned by OpRegistry, so liveness is not a concern.
465   return input_arg.number_attr().c_str();
466 }
467 
TF_OpIsStateful(const char * op_type,TF_Status * status)468 int TF_OpIsStateful(const char* op_type, TF_Status* status) {
469   const tensorflow::OpRegistrationData* op_reg_data;
470   status->status =
471       tensorflow::OpRegistry::Global()->LookUp(op_type, &op_reg_data);
472   if (!status->status.ok()) {
473     return 0;
474   }
475   return op_reg_data->op_def.is_stateful();
476 }
477 
TF_InitMain(const char * usage,int * argc,char *** argv)478 void TF_InitMain(const char* usage, int* argc, char*** argv) {
479   tensorflow::port::InitMain(usage, argc, argv);
480 }
481 
TF_PickUnusedPortOrDie()482 int TF_PickUnusedPortOrDie() {
483   return tensorflow::internal::PickUnusedPortOrDie();
484 }
485 
TFE_NewTensorHandleFromScalar(TF_DataType data_type,void * data,size_t len,TF_Status * status)486 TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
487                                                 void* data, size_t len,
488                                                 TF_Status* status) {
489   auto dtype = static_cast<tensorflow::DataType>(data_type);
490   DCHECK(tensorflow::DataTypeCanUseMemcpy(dtype));
491 
492   tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({}));
493   std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);
494 
495   status->status = tensorflow::Status::OK();
496   return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
497 }
498 
499 // Set server_def on the context, possibly updating it.
TFE_EnableCollectiveOps(TFE_Context * ctx,const void * proto,size_t proto_len,TF_Status * status)500 TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
501                                                    const void* proto,
502                                                    size_t proto_len,
503                                                    TF_Status* status) {
504   tensorflow::ServerDef server_def;
505   if (!server_def.ParseFromArray(proto, proto_len)) {
506     status->status = tensorflow::errors::InvalidArgument(
507         "Invalid tensorflow.ServerDef protocol buffer");
508     return;
509   }
510   status->status =
511       tensorflow::unwrap(ctx)->GetDistributedManager()->EnableCollectiveOps(
512           server_def);
513 }
514 
TFE_AbortCollectiveOps(TFE_Context * ctx,TF_Status * status)515 TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
516                                                   TF_Status* status) {
517   tensorflow::EagerContext* context =
518       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
519   auto collective_executor_handle = context->GetCollectiveExecutorHandle();
520   collective_executor_handle->get()->StartAbort(status->status);
521 }
522 
TFE_CollectiveOpsCheckPeerHealth(TFE_Context * ctx,const char * task,int64_t timeout_in_ms,TF_Status * status)523 TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
524     TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
525     TF_Status* status) {
526   tensorflow::EagerContext* context =
527       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
528   auto collective_executor_handle = context->GetCollectiveExecutorHandle();
529   tensorflow::Notification done;
530   collective_executor_handle->get()->remote_access()->CheckPeerHealth(
531       task, timeout_in_ms, [&done, status](const Status& s) {
532         status->status = s;
533         done.Notify();
534       });
535   done.WaitForNotification();
536 }
537 
TF_NewShapeAndTypeList(int num_items)538 TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
539   TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
540   result->num_items = num_items;
541   result->items = (num_items == 0) ? nullptr : new TF_ShapeAndType[num_items]();
542   return result;
543 }
544 
TF_ShapeAndTypeListSetShape(TF_ShapeAndTypeList * shape_list,int index,const int64_t * dims,int num_dims)545 void TF_ShapeAndTypeListSetShape(TF_ShapeAndTypeList* shape_list, int index,
546                                  const int64_t* dims, int num_dims) {
547   DCHECK(index >= 0 && index < shape_list->num_items);
548   TF_ShapeAndType& shape = shape_list->items[index];
549   DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
550   DCHECK(num_dims >= 0) << "Number of dimensions cannot be negative!";
551   shape.num_dims = num_dims;
552   shape.dims = new int64_t[num_dims];
553   memcpy(shape.dims, dims, sizeof(int64_t) * num_dims);
554 }
555 
TF_ShapeAndTypeListSetUnknownShape(TF_ShapeAndTypeList * shape_list,int index)556 void TF_ShapeAndTypeListSetUnknownShape(TF_ShapeAndTypeList* shape_list,
557                                         int index) {
558   DCHECK(index >= 0 && index < shape_list->num_items);
559   TF_ShapeAndType& shape = shape_list->items[index];
560   DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
561   shape.num_dims = -1;
562   shape.dims = nullptr;
563 }
564 
TF_ShapeAndTypeListSetDtype(TF_ShapeAndTypeList * shape_list,int index,TF_DataType dtype)565 void TF_ShapeAndTypeListSetDtype(TF_ShapeAndTypeList* shape_list, int index,
566                                  TF_DataType dtype) {
567   DCHECK(index >= 0 && index < shape_list->num_items);
568   TF_ShapeAndType& shape_and_type = shape_list->items[index];
569   shape_and_type.dtype = dtype;
570 }
571 
TF_DeleteShapeAndTypeList(TF_ShapeAndTypeList * shape_list)572 void TF_DeleteShapeAndTypeList(TF_ShapeAndTypeList* shape_list) {
573   if (shape_list == nullptr) return;
574   for (size_t i = 0; i < shape_list->num_items; ++i) {
575     delete[] shape_list->items[i].dims;
576   }
577   delete[] shape_list->items;
578   delete shape_list;
579 }
580 
TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList ** shape_list_array,int num_items)581 void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
582                                     int num_items) {
583   if (shape_list_array == nullptr) return;
584   for (int i = 0; i < num_items; ++i) {
585     TF_DeleteShapeAndTypeList(shape_list_array[i]);
586   }
587   delete[] shape_list_array;
588 }
589 
590 namespace tensorflow {
591 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
592 
593 // Helpers for loadding a TensorFlow PluggableDevice plugin (a .so file).
594 Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);
595 }  // namespace tensorflow
596 
TFE_InferShapes(TFE_Op * tfe_op,TF_ShapeAndTypeList * input_shapes,TF_Tensor ** input_tensors,TF_ShapeAndTypeList * input_tensors_as_shapes,TF_ShapeAndTypeList ** input_resource_shapes_and_types,TF_ShapeAndTypeList ** output_shapes,TF_ShapeAndTypeList *** output_resource_shapes_and_types,TF_Status * status)597 void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
598                      TF_Tensor** input_tensors,
599                      TF_ShapeAndTypeList* input_tensors_as_shapes,
600                      TF_ShapeAndTypeList** input_resource_shapes_and_types,
601                      TF_ShapeAndTypeList** output_shapes,
602                      TF_ShapeAndTypeList*** output_resource_shapes_and_types,
603                      TF_Status* status) {
604   using tensorflow::NodeDef;
605   using tensorflow::OpRegistrationData;
606   using tensorflow::Tensor;
607   using tensorflow::shape_inference::DimensionHandle;
608   using tensorflow::shape_inference::InferenceContext;
609   using tensorflow::shape_inference::ShapeAndType;
610   using tensorflow::shape_inference::ShapeHandle;
611 
612   const int num_inputs = input_shapes->num_items;
613   NodeDef node_def;
614   tensorflow::ImmediateExecutionOperation* op = tensorflow::unwrap(tfe_op);
615   node_def.set_name(op->Name());
616   node_def.set_op(op->Name());
617   for (int i = 0; i < num_inputs; ++i) {
618     node_def.add_input("dummy_input");
619   }
620   OperationFromInterface(op)->Attrs().FillAttrValueMap(node_def.mutable_attr());
621 
622   const tensorflow::OpRegistrationData* op_reg_data;
623   status->status =
624       tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data);
625   if (!status->status.ok()) return;
626 
627   // Initialize a input_tensor vector with `nullptr` values.
628   std::vector<const Tensor*> input_tensors_vector(num_inputs, nullptr);
629   // A vector to keep track of newly created `tf::Tensor` objects.
630   std::vector<Tensor> all_input_tensors;
631   // Update the vector with information from `input_tensors` if provided.
632   if (input_tensors != nullptr) {
633     // Note that we take the address of the elements in `all_input_tensors`
634     // below. Allocate enough space so that no reallocation happens, which will
635     // make the pointers invalid.
636     all_input_tensors.reserve(num_inputs);
637     for (int i = 0; i < num_inputs; ++i) {
638       if (input_tensors[i] == nullptr) continue;
639       all_input_tensors.emplace_back();
640       Tensor& input_tensor = all_input_tensors.back();
641       status->status = TF_TensorToTensor(input_tensors[i], &input_tensor);
642       if (!status->status.ok()) return;
643       input_tensors_vector[i] = &input_tensor;
644     }
645   }
646 
647   // Create an inference context with dummy values, which will be updated later.
648   InferenceContext c(TF_GRAPH_DEF_VERSION, node_def, op_reg_data->op_def,
649                      std::vector<ShapeHandle>(num_inputs), input_tensors_vector,
650                      {},
651                      std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());
652 
653   // Set input_shapes.
654   for (int i = 0; i < num_inputs; ++i) {
655     std::vector<DimensionHandle> dims;
656     const TF_ShapeAndType& input_shape = input_shapes->items[i];
657     if (input_shape.num_dims == InferenceContext::kUnknownRank) {
658       c.SetInput(i, c.UnknownShape());
659       continue;
660     }
661     dims.reserve(input_shape.num_dims);
662     for (int j = 0; j < input_shape.num_dims; ++j) {
663       dims.push_back(c.MakeDim(input_shape.dims[j]));
664     }
665     c.SetInput(i, c.MakeShape(dims));
666   }
667 
668   // TODO(bgogul): Handle input_tensors_as_shapes.
669   // TODO(bgogul): Handle input_resource_shapes_and_types.
670 
671   status->status = c.construction_status();
672   if (!status->status.ok()) return;
673 
674   if (op_reg_data->shape_inference_fn == nullptr) {
675     status->status =
676         InvalidArgument("No shape inference function exists for op '",
677                         node_def.op(), "', did you forget to define it?");
678     return;
679   }
680 
681   status->status = c.Run(op_reg_data->shape_inference_fn);
682   if (!status->status.ok()) return;
683 
684   // Set output_shapes.
685   TF_ShapeAndTypeList* output_shapes_result =
686       TF_NewShapeAndTypeList(c.num_outputs());
687   for (int i = 0; i < c.num_outputs(); ++i) {
688     ShapeHandle shape_handle = c.output(i);
689     TF_ShapeAndType& shape = output_shapes_result->items[i];
690     shape.num_dims = c.Rank(shape_handle);
691     if (shape.num_dims == InferenceContext::kUnknownRank) {
692       shape.dims = nullptr;
693       continue;
694     }
695     shape.dims = new int64_t[shape.num_dims];
696     for (size_t j = 0; j < shape.num_dims; ++j) {
697       shape.dims[j] = c.Value(c.Dim(shape_handle, j));
698     }
699   }
700   if (output_shapes != nullptr) *output_shapes = output_shapes_result;
701 
702   // TODO(bgogul): Set output_resource_shapes_and_types.
703 }
704 
TF_ImportGraphDefOptionsSetValidateColocationConstraints(TF_ImportGraphDefOptions * opts,unsigned char enable)705 void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
706     TF_ImportGraphDefOptions* opts, unsigned char enable) {
707   opts->opts.validate_colocation_constraints = enable;
708 }
709 
710 // Load a Pluggable Device library.
711 // On success, returns the handle to library in result and return OK from the
712 // function. Otherwise return nullptr in result and error Status from the
713 // function.
714 //
715 // If `library_filename` has already been loaded, we return a cached handle.
716 // Device and Kernels/Ops are registered as globals when a library is loaded
717 // for the first time.
TF_LoadPluggableDeviceLibrary(const char * library_filename,TF_Status * status)718 TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
719                                           TF_Status* status) {
720 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
721   status->status = tensorflow::errors::Unimplemented(
722       "PluggableDevice plugin functionality is not supported on mobile");
723   return nullptr;
724 #else
725   TF_Library* lib_handle = new TF_Library;
726   static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
727   static std::unordered_map<std::string, void*>* loaded_libs =
728       new std::unordered_map<std::string, void*>();
729   tensorflow::Env* env = tensorflow::Env::Default();
730   {
731     tensorflow::mutex_lock lock(mu);
732     auto it = loaded_libs->find(library_filename);
733     if (it != loaded_libs->end()) {
734       lib_handle->lib_handle = it->second;
735     } else {
736       status->status =
737           env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
738       if (!status->status.ok()) {
739         delete lib_handle;
740         return nullptr;
741       }
742     }
743     return lib_handle;
744   }
745 #endif
746 }
747 
TF_DeletePluggableDeviceLibraryHandle(TF_Library * lib_handle)748 void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
749   delete lib_handle;
750 }
751