/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/c_api_experimental.h"

#include "absl/strings/substitute.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

using tensorflow::FunctionDef;
using tensorflow::Node;
using tensorflow::NodeBuilder;
using tensorflow::Status;
using tensorflow::errors::InvalidArgument;

namespace {
typedef std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>
    UniqueFuncPtr;
}

// struct TF_Operation { tensorflow::Node node; };
static TF_Operation* ToTF_Operation(Node* node) {
  return static_cast<TF_Operation*>(static_cast<void*>(node));
}

void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) {
  tensorflow::ConfigProto& config = options->options.config;
  auto* optimizer_options =
      config.mutable_graph_options()->mutable_optimizer_options();
  if (enable) {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);

    // These XLA flags are needed to trigger XLA properly from C (more generally
    // non-Python) clients. If this API is called again with `enable` set to
    // false, it is safe to keep these flag values as is.
    tensorflow::MarkForCompilationPassFlags* flags =
        tensorflow::GetMarkForCompilationPassFlags();
    flags->tf_xla_cpu_global_jit = true;
    flags->tf_xla_min_cluster_size = 1;
  } else {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
  }
}
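
// Illustrative usage sketch only (not part of this file's API surface): a C
// client would typically flip XLA on while building its session options. The
// surrounding graph/status setup comes from the public C API in c_api.h.
//
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   TF_EnableXLACompilation(opts, /*enable=*/1);
//   TF_Status* s = TF_NewStatus();
//   TF_Session* sess = TF_NewSession(graph, opts, s);
//   TF_DeleteSessionOptions(opts);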

unsigned char TF_SetXlaEnableLazyCompilation(unsigned char enable) {
  tensorflow::BuildXlaOpsPassFlags* flags =
      tensorflow::GetBuildXlaOpsPassFlags();
  bool original = flags->tf_xla_enable_lazy_compilation;
  flags->tf_xla_enable_lazy_compilation = enable;
  return original;
}

unsigned char TF_SetTfXlaCpuGlobalJit(unsigned char enable) {
  tensorflow::MarkForCompilationPassFlags* flags =
      tensorflow::GetMarkForCompilationPassFlags();
  bool original = flags->tf_xla_cpu_global_jit;
  flags->tf_xla_cpu_global_jit = static_cast<bool>(enable);
  return static_cast<unsigned char>(original);
}

void TF_SetXlaAutoJitMode(const char* mode) {
  tensorflow::SetXlaAutoJitFlagFromFlagString(mode);
}

unsigned char TF_GetXlaConstantFoldingDisabled() {
  return static_cast<unsigned char>(
      tensorflow::GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding);
}

void TF_SetXlaConstantFoldingDisabled(unsigned char should_enable) {
  tensorflow::GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding =
      static_cast<bool>(should_enable);
}

void TF_SetXlaMinClusterSize(int size) {
  tensorflow::MarkForCompilationPassFlags* flags =
      tensorflow::GetMarkForCompilationPassFlags();
  flags->tf_xla_min_cluster_size = size;
}

TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
                           unsigned char gpu_memory_allow_growth,
                           unsigned int num_cpu_devices) {
  tensorflow::ConfigProto config;
  auto* optimizer_options =
      config.mutable_graph_options()->mutable_optimizer_options();
  if (enable_xla_compilation) {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);

    // These XLA flags are needed to trigger XLA properly from C (more generally
    // non-Python) clients. If this API is called again with `enable` set to
    // false, it is safe to keep these flag values as is.
    tensorflow::MarkForCompilationPassFlags* flags =
        tensorflow::GetMarkForCompilationPassFlags();
    flags->tf_xla_cpu_global_jit = true;
    flags->tf_xla_min_cluster_size = 1;
  } else {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
  }

  auto* gpu_options = config.mutable_gpu_options();
  gpu_options->set_allow_growth(gpu_memory_allow_growth);

  (*config.mutable_device_count())["CPU"] = num_cpu_devices;

  // TODO(b/113217601): This is needed for EagerContext::runner_ to use a
  // threadpool, so that we avoid the possibility of running the runner_ in the
  // threadpool of GPU event mgr, as that can trigger more callbacks to be
  // scheduled on that same threadpool, causing a deadlock in cases where the
  // caller of event_mgr->ThenExecute() blocks on the completion of the callback
  // (as in the case of ConstOp kernel creation on GPU, which involves copying a
  // CPU tensor to GPU).
  // Setting a larger thread pool does not help with the Swift caller, as we use
  // a different TFE context for each thread of execution (for running graph
  // functions, and their send/recv coroutines).
  config.set_inter_op_parallelism_threads(1);

  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(config, ret));
  return ret;
}
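
// A rough usage sketch (illustrative only; TF_SetConfig and the buffer/session
// helpers are declared in c_api.h, and error handling is elided):
//
//   TF_Buffer* config = TF_CreateConfig(/*enable_xla_compilation=*/0,
//                                       /*gpu_memory_allow_growth=*/1,
//                                       /*num_cpu_devices=*/1);
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   TF_SetConfig(opts, config->data, config->length, status);
//   TF_DeleteBuffer(config);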

TF_Buffer* TF_CreateRunOptions(unsigned char enable_full_trace) {
  tensorflow::RunOptions options;
  if (enable_full_trace) {
    options.set_trace_level(tensorflow::RunOptions::FULL_TRACE);
  } else {
    options.set_trace_level(tensorflow::RunOptions::NO_TRACE);
  }
  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(options, ret));
  return ret;
}
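
// The returned buffer holds a serialized RunOptions proto and can be passed
// as the `run_options` argument of TF_SessionRun. A hedged sketch, assuming
// the usual session/feeds/fetches setup from c_api.h:
//
//   TF_Buffer* run_opts = TF_CreateRunOptions(/*enable_full_trace=*/1);
//   TF_SessionRun(sess, run_opts, inputs, input_values, ninputs,
//                 outputs, output_values, noutputs,
//                 /*targets=*/nullptr, /*ntargets=*/0,
//                 /*run_metadata=*/nullptr, status);
//   TF_DeleteBuffer(run_opts);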

const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) {
  tensorflow::mutex_lock c(graph->mu);
  const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString();
  *len = debug_str.size();
  char* ret = static_cast<char*>(malloc(*len + 1));
  memcpy(ret, debug_str.c_str(), *len + 1);
  return ret;
}

char* TF_FunctionDebugString(TF_Function* func, size_t* len) {
  const auto& debug_str = DebugString(func->fdef);
  *len = debug_str.size();
  char* ret = static_cast<char*>(malloc(*len + 1));
  memcpy(ret, debug_str.c_str(), *len + 1);
  return ret;
}
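
// Both debug-string helpers above return a NUL-terminated buffer allocated
// with malloc(), so the caller is expected to free() it. A minimal sketch:
//
//   size_t len = 0;
//   const char* dump = TF_GraphDebugString(graph, &len);
//   fwrite(dump, 1, len, stderr);
//   free((void*)dump);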

// On success, returns a set of TF_Function instances from `text_proto` of
// GraphDef type. These functions must be deleted by calling TF_DeleteFunction.
//
// If `mutate_proto_func` is non-NULL, run it over each FunctionDef proto,
// before creating a TF_Function out of the possibly mutated proto.
static std::vector<UniqueFuncPtr> CreateFunctionsFromTextProto(
    const char* text_proto,
    std::function<void(FunctionDef*)>* mutate_proto_func, TF_Status* status) {
  tensorflow::GraphDef gdef;
  if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto, &gdef)) {
    status->status = tensorflow::errors::Internal(
        "Invalid text proto for GraphDef: ", text_proto);
    return {};
  }
  const auto& fdef_lib = gdef.library();
  if (fdef_lib.gradient_size() > 0) {
    status->status = tensorflow::errors::Internal(
        "GradientDef is not supported in reading Dataset related functions: ",
        text_proto);
    return {};
  }
  std::vector<UniqueFuncPtr> ret;
  for (const FunctionDef& fdef : fdef_lib.function()) {
    // Make a copy so that we can mutate it.
    FunctionDef fdef_to_load = fdef;
    if (mutate_proto_func) {
      (*mutate_proto_func)(&fdef_to_load);
    }
    VLOG(1) << "Adding func to graph: " << fdef_to_load.DebugString();
    std::vector<char> binary_proto_buf(fdef_to_load.ByteSizeLong());
    fdef_to_load.SerializeToArray(binary_proto_buf.data(),
                                  binary_proto_buf.size());
    TF_Function* func = TF_FunctionImportFunctionDef(
        binary_proto_buf.data(), binary_proto_buf.size(), status);
    if (!status->status.ok()) return {};
    ret.push_back(UniqueFuncPtr(func, TF_DeleteFunction));
  }
  return ret;
}

TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id,
                                 TF_Status* status) {
  assert(session);
  {
    tensorflow::mutex_lock c(session->graph->mu);
    VLOG(1) << "Dequeuing named tensor with id " << tensor_id
            << ", with input graph: "
            << session->graph->graph.ToGraphDefDebug().DebugString();
  }

  TF_Operation* dequeue_op = TF_GraphOperationByName(
      session->graph,
      tensorflow::strings::StrCat("fifo_queue_dequeue_", tensor_id).c_str());
  if (dequeue_op == nullptr) {
    status->status = tensorflow::errors::Internal(
        "Unable to find the dequeue node in the TF graph.");
    return nullptr;
  }

  VLOG(1) << "Running the dequeue op";
  TF_Output output{dequeue_op, 0};
  TF_Tensor* ret;
  TF_SessionRun(session, /*run_options*/ nullptr,
                // input related parameters
                /*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0,
                // output related parameters
                /*outputs*/ &output, /*output_values*/ &ret,
                /*noutputs*/ 1,
                /*targets*/ nullptr, /*ntargets*/ 0,
                /*run_metadata*/ nullptr, status);
  if (VLOG_IS_ON(1) && status->status.ok()) {
    tensorflow::Tensor tensor;
    if (tensorflow::TF_TensorToTensor(ret, &tensor).ok()) {
      VLOG(1) << "Dequeued tensor content: " << tensor.DebugString();
    }
  }
  return ret;
}

void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
                           TF_Tensor* tensor, TF_Status* status) {
  assert(session);
  {
    tensorflow::mutex_lock c(session->graph->mu);
    if (VLOG_IS_ON(1)) {
      VLOG(1) << "Enqueuing named tensor with id " << tensor_id
              << ", with input graph: "
              << session->graph->graph.ToGraphDefDebug().DebugString();
      tensorflow::Tensor internal_tensor;
      if (tensorflow::TF_TensorToTensor(tensor, &internal_tensor).ok()) {
        VLOG(1) << "Enqueuing tensor content: "
                << internal_tensor.DebugString();
      }
    }
  }

  TF_Operation* enqueue_op = TF_GraphOperationByName(
      session->graph,
      tensorflow::strings::StrCat("fifo_queue_enqueue_", tensor_id).c_str());
  if (enqueue_op == nullptr) {
    status->status = tensorflow::errors::Internal(
        "Unable to find the enqueue node in the TF graph.");
    return;
  }

  TF_Operation* placeholder_op = TF_GraphOperationByName(
      session->graph,
      tensorflow::strings::StrCat("arg_tensor_enqueue_", tensor_id).c_str());
  if (placeholder_op == nullptr) {
    status->status = tensorflow::errors::Internal(
        "Unable to find the placeholder node as input to enqueue in the TF "
        "graph.");
    return;
  }

  VLOG(1) << "Running the enqueue op";
  TF_Output input{placeholder_op, 0};
  TF_SessionRun(session, /*run_options*/ nullptr,
                // input related parameters
                /*inputs*/ &input, /*input_values*/ &tensor, /*ninputs*/ 1,
                // output related parameters
                /*outputs*/ nullptr, /*output_values*/ nullptr, /*noutputs*/ 0,
                /*targets*/ &enqueue_op, /*ntargets*/ 1,
                /*run_metadata*/ nullptr, status);
  VLOG(1) << "Enqueuing is done.";
}
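
// These helpers locate graph nodes purely by name, so the graph is expected
// to contain ops named "fifo_queue_enqueue_<id>", "arg_tensor_enqueue_<id>"
// and "fifo_queue_dequeue_<id>" for each tensor id. A rough round-trip sketch
// (graph construction and status checks elided):
//
//   TF_EnqueueNamedTensor(sess, /*tensor_id=*/0, input_tensor, status);
//   TF_Tensor* out = TF_DequeueNamedTensor(sess, /*tensor_id=*/0, status);
//   if (TF_GetCode(status) == TF_OK) TF_DeleteTensor(out);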

TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) {
  tensorflow::ServerDef server_def;
  if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto,
                                                         &server_def)) {
    status->status = tensorflow::errors::Internal(
        "Invalid text proto for ServerDef: ", text_proto);
    return nullptr;
  }
  status->status = tensorflow::Status();
  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(server_def, ret));
  return ret;
}

void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
  status->status = tensorflow::errors::Internal(errMsg);
}

struct TF_CheckpointReader : public tensorflow::checkpoint::CheckpointReader {
  using tensorflow::checkpoint::CheckpointReader::CheckpointReader;
  std::vector<std::string> variable_list;
};

TF_CheckpointReader* TF_NewCheckpointReader(const char* filename,
                                            TF_Status* status) {
  TF_CheckpointReader* reader = new TF_CheckpointReader(filename, status);
  if (!status->status.ok()) {
    TF_DeleteCheckpointReader(reader);
    return nullptr;
  }
  const auto& m = reader->GetVariableToDataTypeMap();
  for (auto it = m.begin(); it != m.end(); ++it)
    reader->variable_list.push_back(it->first);
  std::sort(reader->variable_list.begin(), reader->variable_list.end());
  return reader;
}

void TF_DeleteCheckpointReader(TF_CheckpointReader* reader) { delete reader; }

int TF_CheckpointReaderHasTensor(TF_CheckpointReader* reader,
                                 const char* name) {
  return reader->HasTensor(name);
}

const char* TF_CheckpointReaderGetVariable(TF_CheckpointReader* reader,
                                           int index) {
  return reader->variable_list[index].c_str();
}

int TF_CheckpointReaderSize(TF_CheckpointReader* reader) {
  return reader->variable_list.size();
}

TF_DataType TF_CheckpointReaderGetVariableDataType(TF_CheckpointReader* reader,
                                                   const char* name) {
  const auto& m = reader->GetVariableToDataTypeMap();
  return static_cast<TF_DataType>(m.at(name));
}

TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
                                        const char* name, TF_Status* status) {
  std::unique_ptr<tensorflow::Tensor> tensor;
  reader->GetTensor(name, &tensor, status);
  if (!status->status.ok()) return nullptr;
  return tensorflow::TF_TensorFromTensor(*tensor, &status->status);
}

void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
                                         const char* name, int64_t* dims,
                                         int num_dims, TF_Status* status) {
  const auto& shape = reader->GetVariableToShapeMap().at(name);
  int rank = shape.dims();
  if (num_dims != rank) {
    status->status = InvalidArgument("Expected rank is ", num_dims,
                                     " but actual rank is ", rank);
    return;
  }
  for (int i = 0; i < num_dims; i++) {
    dims[i] = shape.dim_size(i);
  }
}

int TF_CheckpointReaderGetVariableNumDims(TF_CheckpointReader* reader,
                                          const char* name) {
  const auto& m = reader->GetVariableToShapeMap();
  return m.at(name).dims();
}
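
// Illustrative walk over a checkpoint (the path is a placeholder; only the
// initial status check is shown):
//
//   TF_Status* status = TF_NewStatus();
//   TF_CheckpointReader* reader = TF_NewCheckpointReader("/tmp/ckpt", status);
//   if (TF_GetCode(status) == TF_OK) {
//     int n = TF_CheckpointReaderSize(reader);
//     for (int i = 0; i < n; ++i) {
//       const char* name = TF_CheckpointReaderGetVariable(reader, i);
//       int rank = TF_CheckpointReaderGetVariableNumDims(reader, name);
//       ...
//     }
//     TF_DeleteCheckpointReader(reader);
//   }
//   TF_DeleteStatus(status);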

// This builder is used in the eager API to build a NodeDef.
struct TF_AttrBuilder : public tensorflow::AttrBuilder {
  using tensorflow::AttrBuilder::AttrBuilder;
  // String buffers that make sure any `attr_name` we pass into
  // `builder->Set()` will outlive the subsequent
  // `TF_AttrBuilderCheckCanRunOnDevice()` call(s) on the same `builder`.
  std::set<std::string> attr_names;
};

TF_AttrBuilder* TF_NewAttrBuilder(const char* op_name) {
  return new TF_AttrBuilder(op_name);
}

void TF_DeleteAttrBuilder(TF_AttrBuilder* builder) { delete builder; }

void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name,
                           TF_DataType value) {
  auto iter = builder->attr_names.insert(attr_name).first;
  builder->Set(*iter, static_cast<tensorflow::DataType>(value));
}

void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name,
                               const TF_DataType* values, int num_values) {
  auto iter = builder->attr_names.insert(attr_name).first;
  builder->Set(*iter, tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
                          reinterpret_cast<const tensorflow::DataType*>(values),
                          num_values));
}

void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder,
                                       const char* device_type,
                                       TF_Status* status) {
  status->status = tensorflow::FindKernelDef(
      tensorflow::DeviceType(device_type), builder->BuildNodeDef(),
      /* def = */ nullptr, /* kernel_class_name = */ nullptr);
}
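
// Sketch of querying kernel availability for an op with a given attribute set
// (the op name and attributes are just examples):
//
//   TF_AttrBuilder* b = TF_NewAttrBuilder("Cast");
//   TF_AttrBuilderSetType(b, "SrcT", TF_FLOAT);
//   TF_AttrBuilderSetType(b, "DstT", TF_INT32);
//   TF_AttrBuilderCheckCanRunOnDevice(b, "GPU", status);
//   bool has_gpu_kernel = (TF_GetCode(status) == TF_OK);
//   TF_DeleteAttrBuilder(b);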

const char* TF_GetNumberAttrForOpListInput(const char* op_name, int input_index,
                                           TF_Status* status) {
  const tensorflow::OpDef* op_def = nullptr;
  status->status =
      tensorflow::OpRegistry::Global()->LookUpOpDef(op_name, &op_def);
  if (!status->status.ok()) return nullptr;

  if (input_index >= op_def->input_arg_size() || input_index < 0) {
    status->status = tensorflow::errors::InvalidArgument(
        input_index, " out of range for ", op_name);
    return nullptr;
  }

  const tensorflow::OpDef_ArgDef& input_arg = op_def->input_arg()[input_index];

  if (input_arg.number_attr().empty()) {
    status->status = tensorflow::errors::NotFound(
        op_name, " does not have number_attr() defined.");
    return nullptr;
  }

  // The returned string is owned by OpRegistry, so liveness is not a concern.
  return input_arg.number_attr().c_str();
}

int TF_OpIsStateful(const char* op_type, TF_Status* status) {
  const tensorflow::OpRegistrationData* op_reg_data;
  status->status =
      tensorflow::OpRegistry::Global()->LookUp(op_type, &op_reg_data);
  if (!status->status.ok()) {
    return 0;
  }
  return op_reg_data->op_def.is_stateful();
}

void TF_InitMain(const char* usage, int* argc, char*** argv) {
  tensorflow::port::InitMain(usage, argc, argv);
}

int TF_PickUnusedPortOrDie() {
  return tensorflow::internal::PickUnusedPortOrDie();
}

TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
                                                void* data, size_t len,
                                                TF_Status* status) {
  auto dtype = static_cast<tensorflow::DataType>(data_type);
  DCHECK(tensorflow::DataTypeCanUseMemcpy(dtype));

  tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({}));
  std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);

  status->status = tensorflow::Status::OK();
  return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
}
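
// Minimal sketch of building an eager handle for a scalar float (the value
// and subsequent use are placeholders):
//
//   float v = 1.5f;
//   TFE_TensorHandle* h =
//       TFE_NewTensorHandleFromScalar(TF_FLOAT, &v, sizeof(v), status);
//   ...
//   TFE_DeleteTensorHandle(h);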

// Set server_def on the context, possibly updating it.
TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
                                                   const void* proto,
                                                   size_t proto_len,
                                                   TF_Status* status) {
  tensorflow::ServerDef server_def;
  if (!server_def.ParseFromArray(proto, proto_len)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Invalid tensorflow.ServerDef protocol buffer");
    return;
  }
  status->status = tensorflow::unwrap(ctx)->EnableCollectiveOps(server_def);
}

TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
                                                  TF_Status* status) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  auto collective_executor_handle = context->GetCollectiveExecutorHandle();
  collective_executor_handle->get()->StartAbort(status->status);
}

TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
    TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
    TF_Status* status) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  auto collective_executor_handle = context->GetCollectiveExecutorHandle();
  tensorflow::Notification done;
  collective_executor_handle->get()->remote_access()->CheckPeerHealth(
      task, timeout_in_ms, [&done, status](const Status& s) {
        status->status = s;
        done.Notify();
      });
  done.WaitForNotification();
}

TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
  TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
  result->num_items = num_items;
  result->items = (num_items == 0) ? nullptr : new TF_ShapeAndType[num_items]();
  return result;
}

void TF_ShapeAndTypeListSetShape(TF_ShapeAndTypeList* shape_list, int index,
                                 const int64_t* dims, int num_dims) {
  DCHECK(index >= 0 && index < shape_list->num_items);
  TF_ShapeAndType& shape = shape_list->items[index];
  DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
  DCHECK(num_dims >= 0) << "Number of dimensions cannot be negative!";
  shape.num_dims = num_dims;
  shape.dims = new int64_t[num_dims];
  memcpy(shape.dims, dims, sizeof(int64_t) * num_dims);
}

void TF_ShapeAndTypeListSetUnknownShape(TF_ShapeAndTypeList* shape_list,
                                        int index) {
  DCHECK(index >= 0 && index < shape_list->num_items);
  TF_ShapeAndType& shape = shape_list->items[index];
  DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
  shape.num_dims = -1;
  shape.dims = nullptr;
}

void TF_ShapeAndTypeListSetDtype(TF_ShapeAndTypeList* shape_list, int index,
                                 TF_DataType dtype) {
  DCHECK(index >= 0 && index < shape_list->num_items);
  TF_ShapeAndType& shape_and_type = shape_list->items[index];
  shape_and_type.dtype = dtype;
}

void TF_DeleteShapeAndTypeList(TF_ShapeAndTypeList* shape_list) {
  if (shape_list == nullptr) return;
  for (size_t i = 0; i < shape_list->num_items; ++i) {
    delete[] shape_list->items[i].dims;
  }
  delete[] shape_list->items;
  delete shape_list;
}

void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
                                    int num_items) {
  if (shape_list_array == nullptr) return;
  for (int i = 0; i < num_items; ++i) {
    TF_DeleteShapeAndTypeList(shape_list_array[i]);
  }
  delete[] shape_list_array;
}

namespace tensorflow {
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);

// Helpers for loading a TensorFlow PluggableDevice plugin (a .so file).
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);
}  // namespace tensorflow

void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
                     TF_Tensor** input_tensors,
                     TF_ShapeAndTypeList* input_tensors_as_shapes,
                     TF_ShapeAndTypeList** input_resource_shapes_and_types,
                     TF_ShapeAndTypeList** output_shapes,
                     TF_ShapeAndTypeList*** output_resource_shapes_and_types,
                     TF_Status* status) {
  using tensorflow::NodeDef;
  using tensorflow::OpRegistrationData;
  using tensorflow::Tensor;
  using tensorflow::shape_inference::DimensionHandle;
  using tensorflow::shape_inference::InferenceContext;
  using tensorflow::shape_inference::ShapeAndType;
  using tensorflow::shape_inference::ShapeHandle;

  const int num_inputs = input_shapes->num_items;
  NodeDef node_def;
  tensorflow::ImmediateExecutionOperation* op = tensorflow::unwrap(tfe_op);
  node_def.set_name(op->Name());
  node_def.set_op(op->Name());
  for (int i = 0; i < num_inputs; ++i) {
    node_def.add_input("dummy_input");
  }
  OperationFromInterface(op)->Attrs().FillAttrValueMap(node_def.mutable_attr());

  const tensorflow::OpRegistrationData* op_reg_data;
  status->status =
      tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data);
  if (!status->status.ok()) return;

  // Initialize an input_tensor vector with `nullptr` values.
  std::vector<const Tensor*> input_tensors_vector(num_inputs, nullptr);
  // A vector to keep track of newly created `tf::Tensor` objects.
  std::vector<Tensor> all_input_tensors;
  // Update the vector with information from `input_tensors` if provided.
  if (input_tensors != nullptr) {
    // Note that we take the address of the elements in `all_input_tensors`
    // below. Allocate enough space so that no reallocation happens, which will
    // make the pointers invalid.
    all_input_tensors.reserve(num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      if (input_tensors[i] == nullptr) continue;
      all_input_tensors.emplace_back();
      Tensor& input_tensor = all_input_tensors.back();
      status->status = TF_TensorToTensor(input_tensors[i], &input_tensor);
      if (!status->status.ok()) return;
      input_tensors_vector[i] = &input_tensor;
    }
  }

  // Create an inference context with dummy values, which will be updated later.
  InferenceContext c(TF_GRAPH_DEF_VERSION, node_def, op_reg_data->op_def,
                     std::vector<ShapeHandle>(num_inputs), input_tensors_vector,
                     {},
                     std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());

  // Set input_shapes.
  for (int i = 0; i < num_inputs; ++i) {
    std::vector<DimensionHandle> dims;
    const TF_ShapeAndType& input_shape = input_shapes->items[i];
    if (input_shape.num_dims == InferenceContext::kUnknownRank) {
      c.SetInput(i, c.UnknownShape());
      continue;
    }
    dims.reserve(input_shape.num_dims);
    for (int j = 0; j < input_shape.num_dims; ++j) {
      dims.push_back(c.MakeDim(input_shape.dims[j]));
    }
    c.SetInput(i, c.MakeShape(dims));
  }

  // TODO(bgogul): Handle input_tensors_as_shapes.
  // TODO(bgogul): Handle input_resource_shapes_and_types.

  status->status = c.construction_status();
  if (!status->status.ok()) return;

  if (op_reg_data->shape_inference_fn == nullptr) {
    status->status =
        InvalidArgument("No shape inference function exists for op '",
                        node_def.op(), "', did you forget to define it?");
    return;
  }

  status->status = c.Run(op_reg_data->shape_inference_fn);
  if (!status->status.ok()) return;

  // Set output_shapes.
  TF_ShapeAndTypeList* output_shapes_result =
      TF_NewShapeAndTypeList(c.num_outputs());
  for (int i = 0; i < c.num_outputs(); ++i) {
    ShapeHandle shape_handle = c.output(i);
    TF_ShapeAndType& shape = output_shapes_result->items[i];
    shape.num_dims = c.Rank(shape_handle);
    if (shape.num_dims == InferenceContext::kUnknownRank) {
      shape.dims = nullptr;
      continue;
    }
    shape.dims = new int64_t[shape.num_dims];
    for (size_t j = 0; j < shape.num_dims; ++j) {
      shape.dims[j] = c.Value(c.Dim(shape_handle, j));
    }
  }
  if (output_shapes != nullptr) *output_shapes = output_shapes_result;

  // TODO(bgogul): Set output_resource_shapes_and_types.
}
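
// A rough end-to-end sketch of inferring output shapes for a simple op
// (assumes an eager context `ctx`; the op, attrs, and shapes are examples,
// and the optional tensor/resource arguments are left null):
//
//   TFE_Op* op = TFE_NewOp(ctx, "Identity", status);
//   TFE_OpSetAttrType(op, "T", TF_FLOAT);
//   TF_ShapeAndTypeList* in = TF_NewShapeAndTypeList(1);
//   int64_t dims[] = {2, 3};
//   TF_ShapeAndTypeListSetShape(in, 0, dims, 2);
//   TF_ShapeAndTypeList* out = nullptr;
//   TFE_InferShapes(op, in, /*input_tensors=*/nullptr,
//                   /*input_tensors_as_shapes=*/nullptr,
//                   /*input_resource_shapes_and_types=*/nullptr, &out,
//                   /*output_resource_shapes_and_types=*/nullptr, status);
//   TF_DeleteShapeAndTypeList(in);
//   TF_DeleteShapeAndTypeList(out);
//   TFE_DeleteOp(op);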

void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
    TF_ImportGraphDefOptions* opts, unsigned char enable) {
  opts->opts.validate_colocation_constraints = enable;
}

// Load a Pluggable Device library.
// On success, returns a handle to the library and sets an OK status.
// Otherwise returns nullptr and sets an error status.
//
// If `library_filename` has already been loaded, we return a cached handle.
// Device and Kernels/Ops are registered as globals when a library is loaded
// for the first time.
TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
                                          TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "PluggableDevice plugin functionality is not supported on mobile");
  return nullptr;
#else
  TF_Library* lib_handle = new TF_Library;
  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
  static std::unordered_map<std::string, void*>* loaded_libs =
      new std::unordered_map<std::string, void*>();
  tensorflow::Env* env = tensorflow::Env::Default();
  {
    tensorflow::mutex_lock lock(mu);
    auto it = loaded_libs->find(library_filename);
    if (it != loaded_libs->end()) {
      lib_handle->lib_handle = it->second;
    } else {
      status->status =
          env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
      if (status->status.ok()) {
        TF_CHECK_OK(
            tensorflow::RegisterPluggableDevicePlugin(lib_handle->lib_handle));
      } else {
        delete lib_handle;
        return nullptr;
      }
    }
    return lib_handle;
  }
#endif
}
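
// Sketch of loading a PluggableDevice plugin at startup (the .so path is a
// placeholder):
//
//   TF_Status* status = TF_NewStatus();
//   TF_Library* lib = TF_LoadPluggableDeviceLibrary(
//       "/path/to/libmy_device_plugin.so", status);
//   if (TF_GetCode(status) != TF_OK) { /* inspect TF_Message(status) */ }
//   ...
//   TF_DeletePluggableDeviceLibraryHandle(lib);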

void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
  delete lib_handle;
}