• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/c_api_experimental.h"
17 
18 #include "absl/strings/substitute.h"
19 #include "tensorflow/c/c_api.h"
20 #include "tensorflow/c/c_api_internal.h"
21 #include "tensorflow/c/checkpoint_reader.h"
22 #include "tensorflow/c/eager/c_api.h"
23 #include "tensorflow/c/eager/c_api_internal.h"
24 #include "tensorflow/c/eager/tfe_context_internal.h"
25 #include "tensorflow/c/eager/tfe_op_internal.h"
26 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
27 #include "tensorflow/c/tf_buffer_internal.h"
28 #include "tensorflow/compiler/jit/flags.h"
29 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
30 #include "tensorflow/core/common_runtime/eager/context.h"
31 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
32 #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.h"
33 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
34 #include "tensorflow/core/framework/collective.h"
35 #include "tensorflow/core/framework/node_def.pb.h"
36 #include "tensorflow/core/framework/shape_inference.h"
37 #include "tensorflow/core/framework/tensor.pb.h"
38 #include "tensorflow/core/graph/graph.h"
39 #include "tensorflow/core/graph/node_builder.h"
40 #include "tensorflow/core/platform/blocking_counter.h"
41 #include "tensorflow/core/platform/casts.h"
42 #include "tensorflow/core/platform/env.h"
43 #include "tensorflow/core/platform/init_main.h"
44 #include "tensorflow/core/platform/mutex.h"
45 #include "tensorflow/core/platform/net.h"
46 #include "tensorflow/core/platform/platform.h"
47 #include "tensorflow/core/platform/strcat.h"
48 #include "tensorflow/core/protobuf/config.pb.h"
49 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
50 
51 using tensorflow::FunctionDef;
52 using tensorflow::Node;
53 using tensorflow::NodeBuilder;
54 using tensorflow::Status;
55 using tensorflow::errors::InvalidArgument;
56 
57 namespace {
58 typedef std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>
59     UniqueFuncPtr;
60 }
61 
62 // struct TF_Operation { tensorflow::Node node; };
ToTF_Operation(Node * node)63 static TF_Operation* ToTF_Operation(Node* node) {
64   return static_cast<TF_Operation*>(static_cast<void*>(node));
65 }
66 
TF_EnableXLACompilation(TF_SessionOptions * options,unsigned char enable)67 void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) {
68   tensorflow::ConfigProto& config = options->options.config;
69   auto* optimizer_options =
70       config.mutable_graph_options()->mutable_optimizer_options();
71   if (enable) {
72     optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);
73 
74     // These XLA flags are needed to trigger XLA properly from C (more generally
75     // non-Python) clients. If this API is called again with `enable` set to
76     // false, it is safe to keep these flag values as is.
77     tensorflow::MarkForCompilationPassFlags* flags =
78         tensorflow::GetMarkForCompilationPassFlags();
79     flags->tf_xla_cpu_global_jit = true;
80     flags->tf_xla_min_cluster_size = 1;
81   } else {
82     optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
83   }
84 }
85 
TF_SetXlaEnableLazyCompilation(unsigned char enable)86 unsigned char TF_SetXlaEnableLazyCompilation(unsigned char enable) {
87   tensorflow::BuildXlaOpsPassFlags* flags =
88       tensorflow::GetBuildXlaOpsPassFlags();
89   bool original = flags->tf_xla_enable_lazy_compilation;
90   flags->tf_xla_enable_lazy_compilation = enable;
91   return original;
92 }
93 
TF_SetTfXlaCpuGlobalJit(unsigned char enable)94 unsigned char TF_SetTfXlaCpuGlobalJit(unsigned char enable) {
95   tensorflow::MarkForCompilationPassFlags* flags =
96       tensorflow::GetMarkForCompilationPassFlags();
97   bool original = flags->tf_xla_cpu_global_jit;
98   flags->tf_xla_cpu_global_jit = static_cast<bool>(enable);
99   return static_cast<unsigned char>(original);
100 }
101 
TF_SetXlaAutoJitMode(const char * mode)102 void TF_SetXlaAutoJitMode(const char* mode) {
103   tensorflow::SetXlaAutoJitFlagFromFlagString(mode);
104 }
105 
TF_GetXlaAutoJitEnabled()106 unsigned char TF_GetXlaAutoJitEnabled() {
107   tensorflow::XlaAutoJitFlag flag =
108       tensorflow::GetMarkForCompilationPassFlags()->xla_auto_jit_flag;
109   return static_cast<unsigned char>(flag.optimization_level_single_gpu > 0 ||
110                                     flag.optimization_level_general > 0);
111 }
112 
TF_GetXlaConstantFoldingDisabled()113 unsigned char TF_GetXlaConstantFoldingDisabled() {
114   return static_cast<unsigned char>(
115       tensorflow::GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding);
116 }
117 
TF_SetXlaConstantFoldingDisabled(unsigned char should_enable)118 void TF_SetXlaConstantFoldingDisabled(unsigned char should_enable) {
119   tensorflow::GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding =
120       static_cast<bool>(should_enable);
121 }
122 
TF_SetXlaMinClusterSize(int size)123 void TF_SetXlaMinClusterSize(int size) {
124   tensorflow::MarkForCompilationPassFlags* flags =
125       tensorflow::GetMarkForCompilationPassFlags();
126   flags->tf_xla_min_cluster_size = size;
127 }
128 
TF_CreateConfig(unsigned char enable_xla_compilation,unsigned char gpu_memory_allow_growth,unsigned int num_cpu_devices)129 TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
130                            unsigned char gpu_memory_allow_growth,
131                            unsigned int num_cpu_devices) {
132   tensorflow::ConfigProto config;
133   auto* optimizer_options =
134       config.mutable_graph_options()->mutable_optimizer_options();
135   if (enable_xla_compilation) {
136     optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);
137 
138     // These XLA flags are needed to trigger XLA properly from C (more generally
139     // non-Python) clients. If this API is called again with `enable` set to
140     // false, it is safe to keep these flag values as is.
141     tensorflow::MarkForCompilationPassFlags* flags =
142         tensorflow::GetMarkForCompilationPassFlags();
143     flags->tf_xla_cpu_global_jit = true;
144     flags->tf_xla_min_cluster_size = 1;
145   } else {
146     optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
147   }
148 
149   auto* gpu_options = config.mutable_gpu_options();
150   gpu_options->set_allow_growth(gpu_memory_allow_growth);
151 
152   (*config.mutable_device_count())["CPU"] = num_cpu_devices;
153 
154   // TODO(b/113217601): This is needed for EagerContext::runner_ to use a
155   // threadpool, so that we avoid the possibility of running the runner_ in the
156   // threadpool of GPU event mgr, as that can trigger more callbacks to be
157   // scheduled on that same threadpool, causing a deadlock in cases where the
158   // caller of event_mgr->ThenExecute() blocks on the completion of the callback
159   // (as in the case of ConstOp kernel creation on GPU, which involves copying a
160   // CPU tensor to GPU).
161   // Setting a larger thread pool does not help with the Swift caller, as we use
162   // a different TFE context for each thread of execution (for running graph
163   // functions, and their send/recvs corountines).
164   config.set_inter_op_parallelism_threads(1);
165 
166   TF_Buffer* ret = TF_NewBuffer();
167   TF_CHECK_OK(MessageToBuffer(config, ret));
168   return ret;
169 }
170 
TF_CreateRunOptions(unsigned char enable_full_trace)171 TF_Buffer* TF_CreateRunOptions(unsigned char enable_full_trace) {
172   tensorflow::RunOptions options;
173   if (enable_full_trace) {
174     options.set_trace_level(tensorflow::RunOptions::FULL_TRACE);
175   } else {
176     options.set_trace_level(tensorflow::RunOptions::NO_TRACE);
177   }
178   TF_Buffer* ret = TF_NewBuffer();
179   TF_CHECK_OK(MessageToBuffer(options, ret));
180   return ret;
181 }
182 
TF_GraphDebugString(TF_Graph * graph,size_t * len)183 const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) {
184   tensorflow::mutex_lock c(graph->mu);
185   const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString();
186   *len = debug_str.size();
187   char* ret = static_cast<char*>(malloc(*len + 1));
188   memcpy(ret, debug_str.c_str(), *len + 1);
189   return ret;
190 }
191 
TF_FunctionDebugString(TF_Function * func,size_t * len)192 char* TF_FunctionDebugString(TF_Function* func, size_t* len) {
193   const auto& debug_str = DebugString(func->fdef);
194   *len = debug_str.size();
195   char* ret = static_cast<char*>(malloc(*len + 1));
196   memcpy(ret, debug_str.c_str(), *len + 1);
197   return ret;
198 }
199 
200 // On success, returns a set of TF_Function instances from `text_proto` of
201 // GraphDef type. These functions must be deleted by calling TF_DeleteFunction.
202 //
203 // If `mutate_proto_func` is non-NULL, run it over each FunctionDef proto,
204 // before creating a TF_Function out of the possibly mutated proto.
CreateFunctionsFromTextProto(const char * text_proto,std::function<void (FunctionDef *)> * mutate_proto_func,TF_Status * status)205 static std::vector<UniqueFuncPtr> CreateFunctionsFromTextProto(
206     const char* text_proto,
207     std::function<void(FunctionDef*)>* mutate_proto_func, TF_Status* status) {
208   tensorflow::GraphDef gdef;
209   if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto, &gdef)) {
210     status->status = tensorflow::errors::Internal(
211         "Invalid text proto for GraphDef: ", text_proto);
212     return {};
213   }
214   const auto& fdef_lib = gdef.library();
215   if (fdef_lib.gradient_size() > 0) {
216     status->status = tensorflow::errors::Internal(
217         "GradientDef is not supported in reading Dataset related functions: ",
218         text_proto);
219     return {};
220   }
221   std::vector<UniqueFuncPtr> ret;
222   for (const FunctionDef& fdef : fdef_lib.function()) {
223     // Make a copy so that we can mutate it.
224     FunctionDef fdef_to_load = fdef;
225     if (mutate_proto_func) {
226       (*mutate_proto_func)(&fdef_to_load);
227     }
228     VLOG(1) << "Adding func to graph: " << fdef_to_load.DebugString();
229     std::vector<char> binary_proto_buf(fdef_to_load.ByteSizeLong());
230     fdef_to_load.SerializeToArray(binary_proto_buf.data(),
231                                   binary_proto_buf.size());
232     TF_Function* func = TF_FunctionImportFunctionDef(
233         binary_proto_buf.data(), binary_proto_buf.size(), status);
234     if (!status->status.ok()) return {};
235     ret.push_back(UniqueFuncPtr(func, TF_DeleteFunction));
236   }
237   return ret;
238 }
239 
TF_DequeueNamedTensor(TF_Session * session,int tensor_id,TF_Status * status)240 TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id,
241                                  TF_Status* status) {
242   assert(session);
243   {
244     tensorflow::mutex_lock c(session->graph->mu);
245     VLOG(1) << "Dequeuing named tensor with id " << tensor_id
246             << ", with input graph: "
247             << session->graph->graph.ToGraphDefDebug().DebugString();
248   }
249 
250   TF_Operation* dequeue_op = TF_GraphOperationByName(
251       session->graph,
252       tensorflow::strings::StrCat("fifo_queue_dequeue_", tensor_id).c_str());
253   if (dequeue_op == nullptr) {
254     status->status = tensorflow::errors::Internal(
255         "Unable to find the dequeue node in the TF graph.");
256     return nullptr;
257   }
258 
259   VLOG(1) << "Running the dequeue op";
260   TF_Output output{dequeue_op, 0};
261   TF_Tensor* ret;
262   TF_SessionRun(session, /*run_options*/ nullptr,
263                 // input related parameters
264                 /*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0,
265                 // output related parameters
266                 /*outputs*/ &output, /*output_values*/ &ret,
267                 /*noutputs*/ 1,
268                 /*targets*/ nullptr, /*ntargets*/ 0,
269                 /*run_metadata*/ nullptr, status);
270   if (VLOG_IS_ON(1) && status->status.ok()) {
271     tensorflow::Tensor tensor;
272     if (tensorflow::TF_TensorToTensor(ret, &tensor).ok()) {
273       VLOG(1) << "Dequeued tensor content: " << tensor.DebugString();
274     }
275   }
276   return ret;
277 }
278 
TF_EnqueueNamedTensor(TF_Session * session,int tensor_id,TF_Tensor * tensor,TF_Status * status)279 void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
280                            TF_Tensor* tensor, TF_Status* status) {
281   assert(session);
282   {
283     tensorflow::mutex_lock c(session->graph->mu);
284     if (VLOG_IS_ON(1)) {
285       VLOG(1) << "Enqueuing named tensor with id " << tensor_id
286               << ", with input graph: "
287               << session->graph->graph.ToGraphDefDebug().DebugString();
288       tensorflow::Tensor internal_tensor;
289       if (tensorflow::TF_TensorToTensor(tensor, &internal_tensor).ok()) {
290         VLOG(1) << "Enqueu'ing tensor content: "
291                 << internal_tensor.DebugString();
292       }
293     }
294   }
295 
296   TF_Operation* enqueue_op = TF_GraphOperationByName(
297       session->graph,
298       tensorflow::strings::StrCat("fifo_queue_enqueue_", tensor_id).c_str());
299   if (enqueue_op == nullptr) {
300     status->status = tensorflow::errors::Internal(
301         "Unable to find the enqueue node in the TF graph.");
302     return;
303   }
304 
305   TF_Operation* placeholder_op = TF_GraphOperationByName(
306       session->graph,
307       tensorflow::strings::StrCat("arg_tensor_enqueue_", tensor_id).c_str());
308   if (placeholder_op == nullptr) {
309     status->status = tensorflow::errors::Internal(
310         "Unable to find the placeholder node as input to enqueue in the TF "
311         "graph.");
312     return;
313   }
314 
315   VLOG(1) << "Running the enqueue op";
316   TF_Output input{placeholder_op, 0};
317   TF_SessionRun(session, /*run_options*/ nullptr,
318                 // input related parameters
319                 /*inputs*/ &input, /*input_values*/ &tensor, /*ninputs*/ 1,
320                 // output related parameters
321                 /*outputs*/ nullptr, /*output_values*/ nullptr, /*noutputs*/ 0,
322                 /*targets*/ &enqueue_op, /*ntargets*/ 1,
323                 /*run_metadata*/ nullptr, status);
324   VLOG(1) << "Enqueuing is done.";
325 }
326 
TFE_GetServerDef(const char * text_proto,TF_Status * status)327 TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) {
328   tensorflow::ServerDef server_def;
329   if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto,
330                                                          &server_def)) {
331     status->status = tensorflow::errors::Internal(
332         "Invalid text proto for ServerDef: ", text_proto);
333     return nullptr;
334   }
335   status->status = tensorflow::Status();
336   TF_Buffer* ret = TF_NewBuffer();
337   TF_CHECK_OK(MessageToBuffer(server_def, ret));
338   return ret;
339 }
340 
TF_MakeInternalErrorStatus(TF_Status * status,const char * errMsg)341 void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
342   status->status = tensorflow::errors::Internal(errMsg);
343 }
344 
345 struct TF_CheckpointReader : public tensorflow::checkpoint::CheckpointReader {
346   using tensorflow::checkpoint::CheckpointReader::CheckpointReader;
347   std::vector<std::string> variable_list;
348 };
349 
TF_NewCheckpointReader(const char * filename,TF_Status * status)350 TF_CheckpointReader* TF_NewCheckpointReader(const char* filename,
351                                             TF_Status* status) {
352   TF_CheckpointReader* reader = new TF_CheckpointReader(filename, status);
353   if (!status->status.ok()) {
354     TF_DeleteCheckpointReader(reader);
355     return nullptr;
356   }
357   const auto& m = reader->GetVariableToDataTypeMap();
358   for (auto it = m.begin(); it != m.end(); ++it)
359     reader->variable_list.push_back(it->first);
360   std::sort(reader->variable_list.begin(), reader->variable_list.end());
361   return reader;
362 }
363 
TF_DeleteCheckpointReader(TF_CheckpointReader * reader)364 void TF_DeleteCheckpointReader(TF_CheckpointReader* reader) { delete reader; }
365 
TF_CheckpointReaderHasTensor(TF_CheckpointReader * reader,const char * name)366 int TF_CheckpointReaderHasTensor(TF_CheckpointReader* reader,
367                                  const char* name) {
368   return reader->HasTensor(name);
369 }
370 
TF_CheckpointReaderGetVariable(TF_CheckpointReader * reader,int index)371 const char* TF_CheckpointReaderGetVariable(TF_CheckpointReader* reader,
372                                            int index) {
373   return reader->variable_list[index].c_str();
374 }
375 
TF_CheckpointReaderSize(TF_CheckpointReader * reader)376 int TF_CheckpointReaderSize(TF_CheckpointReader* reader) {
377   return reader->variable_list.size();
378 }
379 
TF_CheckpointReaderGetVariableDataType(TF_CheckpointReader * reader,const char * name)380 TF_DataType TF_CheckpointReaderGetVariableDataType(TF_CheckpointReader* reader,
381                                                    const char* name) {
382   const auto& m = reader->GetVariableToDataTypeMap();
383   return static_cast<TF_DataType>(m.at(name));
384 }
385 
TF_CheckpointReaderGetTensor(TF_CheckpointReader * reader,const char * name,TF_Status * status)386 TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
387                                         const char* name, TF_Status* status) {
388   std::unique_ptr<tensorflow::Tensor> tensor;
389   reader->GetTensor(name, &tensor, status);
390   if (!status->status.ok()) return nullptr;
391   return tensorflow::TF_TensorFromTensor(*tensor, &status->status);
392 }
393 
TF_CheckpointReaderGetVariableShape(TF_CheckpointReader * reader,const char * name,int64_t * dims,int num_dims,TF_Status * status)394 void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
395                                          const char* name, int64_t* dims,
396                                          int num_dims, TF_Status* status) {
397   const auto& shape = reader->GetVariableToShapeMap().at(name);
398   int rank = shape.dims();
399   if (num_dims != rank) {
400     status->status = InvalidArgument("Expected rank is ", num_dims,
401                                      " but actual rank is ", rank);
402     return;
403   }
404   for (int i = 0; i < num_dims; i++) {
405     dims[i] = shape.dim_size(i);
406   }
407 }
408 
TF_CheckpointReaderGetVariableNumDims(TF_CheckpointReader * reader,const char * name)409 int TF_CheckpointReaderGetVariableNumDims(TF_CheckpointReader* reader,
410                                           const char* name) {
411   const auto& m = reader->GetVariableToShapeMap();
412   return m.at(name).dims();
413 }
414 
415 // This builder is used in the eager API to build a NodeDef.
416 struct TF_AttrBuilder : public tensorflow::AttrBuilder {
417   using tensorflow::AttrBuilder::AttrBuilder;
418   // The string buffers to make sure that any `attr_name` we pass into
419   // `builder->Set()` will outlive the subsequent
420   // `TF_AttrBuilderCheckCanRunOnDevice()` call(s) on the same `builder`.
421   std::set<std::string> attr_names;
422 };
423 
TF_NewAttrBuilder(const char * op_name)424 TF_AttrBuilder* TF_NewAttrBuilder(const char* op_name) {
425   return new TF_AttrBuilder(op_name);
426 }
427 
TF_DeleteAttrBuilder(TF_AttrBuilder * builder)428 void TF_DeleteAttrBuilder(TF_AttrBuilder* builder) { delete builder; }
429 
TF_AttrBuilderSetType(TF_AttrBuilder * builder,const char * attr_name,TF_DataType value)430 void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name,
431                            TF_DataType value) {
432   auto iter = builder->attr_names.insert(attr_name).first;
433   builder->Set(*iter, static_cast<tensorflow::DataType>(value));
434 }
435 
TF_AttrBuilderSetTypeList(TF_AttrBuilder * builder,const char * attr_name,const TF_DataType * values,int num_values)436 void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name,
437                                const TF_DataType* values, int num_values) {
438   auto iter = builder->attr_names.insert(attr_name).first;
439   builder->Set(*iter, tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
440                           reinterpret_cast<const tensorflow::DataType*>(values),
441                           num_values));
442 }
443 
TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder * builder,const char * device_type,TF_Status * status)444 void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder,
445                                        const char* device_type,
446                                        TF_Status* status) {
447   status->status = tensorflow::FindKernelDef(
448       tensorflow::DeviceType(device_type), builder->BuildNodeDef(),
449       /* def = */ nullptr, /* kernel_class_name = */ nullptr);
450 }
451 
TF_GetNumberAttrForOpListInput(const char * op_name,int input_index,TF_Status * status)452 const char* TF_GetNumberAttrForOpListInput(const char* op_name, int input_index,
453                                            TF_Status* status) {
454   const tensorflow::OpDef* op_def = nullptr;
455   status->status =
456       tensorflow::OpRegistry::Global()->LookUpOpDef(op_name, &op_def);
457   if (!status->status.ok()) return nullptr;
458 
459   if (input_index >= op_def->input_arg_size() || input_index < 0) {
460     status->status = tensorflow::errors::InvalidArgument(
461         input_index, " out of range for ", op_name);
462     return nullptr;
463   }
464 
465   const tensorflow::OpDef_ArgDef& input_arg = op_def->input_arg()[input_index];
466 
467   if (input_arg.number_attr().empty()) {
468     status->status = tensorflow::errors::NotFound(
469         op_name, " does not have number_attr() defined.");
470     return nullptr;
471   }
472 
473   // The returned string is owned by OpRegistry, so liveness is not a concern.
474   return input_arg.number_attr().c_str();
475 }
476 
TF_OpIsStateful(const char * op_type,TF_Status * status)477 int TF_OpIsStateful(const char* op_type, TF_Status* status) {
478   const tensorflow::OpRegistrationData* op_reg_data;
479   status->status =
480       tensorflow::OpRegistry::Global()->LookUp(op_type, &op_reg_data);
481   if (!status->status.ok()) {
482     return 0;
483   }
484   return op_reg_data->op_def.is_stateful();
485 }
486 
TF_InitMain(const char * usage,int * argc,char *** argv)487 void TF_InitMain(const char* usage, int* argc, char*** argv) {
488   tensorflow::port::InitMain(usage, argc, argv);
489 }
490 
TF_PickUnusedPortOrDie()491 int TF_PickUnusedPortOrDie() {
492   return tensorflow::internal::PickUnusedPortOrDie();
493 }
494 
TFE_NewTensorHandleFromScalar(TF_DataType data_type,void * data,size_t len,TF_Status * status)495 TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
496                                                 void* data, size_t len,
497                                                 TF_Status* status) {
498   auto dtype = static_cast<tensorflow::DataType>(data_type);
499   DCHECK(tensorflow::DataTypeCanUseMemcpy(dtype));
500 
501   tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({}));
502   std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);
503 
504   status->status = ::tensorflow::OkStatus();
505   return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
506 }
507 
508 // Set server_def on the context, possibly updating it.
TFE_EnableCollectiveOps(TFE_Context * ctx,const void * proto,size_t proto_len,TF_Status * status)509 TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
510                                                    const void* proto,
511                                                    size_t proto_len,
512                                                    TF_Status* status) {
513   tensorflow::ServerDef server_def;
514   if (!server_def.ParseFromArray(proto, proto_len)) {
515     status->status = tensorflow::errors::InvalidArgument(
516         "Invalid tensorflow.ServerDef protocol buffer");
517     return;
518   }
519   status->status = tensorflow::unwrap(ctx)->EnableCollectiveOps(server_def);
520 }
521 
TFE_AbortCollectiveOps(TFE_Context * ctx,TF_Status * status)522 TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
523                                                   TF_Status* status) {
524   tensorflow::EagerContext* context =
525       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
526   auto collective_executor_handle = context->GetCollectiveExecutorHandle();
527   collective_executor_handle->get()->StartAbort(status->status);
528 }
529 
TFE_CollectiveOpsCheckPeerHealth(TFE_Context * ctx,const char * task,int64_t timeout_in_ms,TF_Status * status)530 TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
531     TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
532     TF_Status* status) {
533   tensorflow::EagerContext* context =
534       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
535   auto collective_executor_handle = context->GetCollectiveExecutorHandle();
536   tensorflow::Notification done;
537   collective_executor_handle->get()->remote_access()->CheckPeerHealth(
538       task, timeout_in_ms, [&done, status](const Status& s) {
539         status->status = s;
540         done.Notify();
541       });
542   done.WaitForNotification();
543 }
544 
TF_NewShapeAndTypeList(int num_items)545 TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
546   TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
547   result->num_items = num_items;
548   result->items = (num_items == 0) ? nullptr : new TF_ShapeAndType[num_items]();
549   return result;
550 }
551 
TF_ShapeAndTypeListSetShape(TF_ShapeAndTypeList * shape_list,int index,const int64_t * dims,int num_dims)552 void TF_ShapeAndTypeListSetShape(TF_ShapeAndTypeList* shape_list, int index,
553                                  const int64_t* dims, int num_dims) {
554   DCHECK(index >= 0 && index < shape_list->num_items);
555   TF_ShapeAndType& shape = shape_list->items[index];
556   DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
557   DCHECK(num_dims >= 0) << "Number of dimensions cannot be negative!";
558   shape.num_dims = num_dims;
559   shape.dims = new int64_t[num_dims];
560   memcpy(shape.dims, dims, sizeof(int64_t) * num_dims);
561 }
562 
TF_ShapeAndTypeListSetUnknownShape(TF_ShapeAndTypeList * shape_list,int index)563 void TF_ShapeAndTypeListSetUnknownShape(TF_ShapeAndTypeList* shape_list,
564                                         int index) {
565   DCHECK(index >= 0 && index < shape_list->num_items);
566   TF_ShapeAndType& shape = shape_list->items[index];
567   DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
568   shape.num_dims = -1;
569   shape.dims = nullptr;
570 }
571 
TF_ShapeAndTypeListSetDtype(TF_ShapeAndTypeList * shape_list,int index,TF_DataType dtype)572 void TF_ShapeAndTypeListSetDtype(TF_ShapeAndTypeList* shape_list, int index,
573                                  TF_DataType dtype) {
574   DCHECK(index >= 0 && index < shape_list->num_items);
575   TF_ShapeAndType& shape_and_type = shape_list->items[index];
576   shape_and_type.dtype = dtype;
577 }
578 
TF_DeleteShapeAndTypeList(TF_ShapeAndTypeList * shape_list)579 void TF_DeleteShapeAndTypeList(TF_ShapeAndTypeList* shape_list) {
580   if (shape_list == nullptr) return;
581   for (size_t i = 0; i < shape_list->num_items; ++i) {
582     delete[] shape_list->items[i].dims;
583   }
584   delete[] shape_list->items;
585   delete shape_list;
586 }
587 
TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList ** shape_list_array,int num_items)588 void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
589                                     int num_items) {
590   if (shape_list_array == nullptr) return;
591   for (int i = 0; i < num_items; ++i) {
592     TF_DeleteShapeAndTypeList(shape_list_array[i]);
593   }
594   delete[] shape_list_array;
595 }
596 
597 namespace tensorflow {
598 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
599 
600 // Helpers for loadding a TensorFlow PluggableDevice plugin (a .so file).
601 Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);
602 }  // namespace tensorflow
603 
TFE_InferShapes(TFE_Op * tfe_op,TF_ShapeAndTypeList * input_shapes,TF_Tensor ** input_tensors,TF_ShapeAndTypeList * input_tensors_as_shapes,TF_ShapeAndTypeList ** input_resource_shapes_and_types,TF_ShapeAndTypeList ** output_shapes,TF_ShapeAndTypeList *** output_resource_shapes_and_types,TF_Status * status)604 void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
605                      TF_Tensor** input_tensors,
606                      TF_ShapeAndTypeList* input_tensors_as_shapes,
607                      TF_ShapeAndTypeList** input_resource_shapes_and_types,
608                      TF_ShapeAndTypeList** output_shapes,
609                      TF_ShapeAndTypeList*** output_resource_shapes_and_types,
610                      TF_Status* status) {
611   using tensorflow::NodeDef;
612   using tensorflow::OpRegistrationData;
613   using tensorflow::Tensor;
614   using tensorflow::shape_inference::DimensionHandle;
615   using tensorflow::shape_inference::InferenceContext;
616   using tensorflow::shape_inference::ShapeAndType;
617   using tensorflow::shape_inference::ShapeHandle;
618 
619   const int num_inputs = input_shapes->num_items;
620   NodeDef node_def;
621   tensorflow::ImmediateExecutionOperation* op = tensorflow::unwrap(tfe_op);
622   node_def.set_name(op->Name());
623   node_def.set_op(op->Name());
624   for (int i = 0; i < num_inputs; ++i) {
625     node_def.add_input("dummy_input");
626   }
627   OperationFromInterface(op)->Attrs().FillAttrValueMap(node_def.mutable_attr());
628 
629   const tensorflow::OpRegistrationData* op_reg_data;
630   status->status =
631       tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data);
632   if (!status->status.ok()) return;
633 
634   // Initialize a input_tensor vector with `nullptr` values.
635   std::vector<const Tensor*> input_tensors_vector(num_inputs, nullptr);
636   // A vector to keep track of newly created `tf::Tensor` objects.
637   std::vector<Tensor> all_input_tensors;
638   // Update the vector with information from `input_tensors` if provided.
639   if (input_tensors != nullptr) {
640     // Note that we take the address of the elements in `all_input_tensors`
641     // below. Allocate enough space so that no reallocation happens, which will
642     // make the pointers invalid.
643     all_input_tensors.reserve(num_inputs);
644     for (int i = 0; i < num_inputs; ++i) {
645       if (input_tensors[i] == nullptr) continue;
646       all_input_tensors.emplace_back();
647       Tensor& input_tensor = all_input_tensors.back();
648       status->status = TF_TensorToTensor(input_tensors[i], &input_tensor);
649       if (!status->status.ok()) return;
650       input_tensors_vector[i] = &input_tensor;
651     }
652   }
653 
654   // Create an inference context with dummy values, which will be updated later.
655   InferenceContext c(TF_GRAPH_DEF_VERSION, node_def, op_reg_data->op_def,
656                      std::vector<ShapeHandle>(num_inputs), input_tensors_vector,
657                      {},
658                      std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());
659 
660   // Set input_shapes.
661   for (int i = 0; i < num_inputs; ++i) {
662     std::vector<DimensionHandle> dims;
663     const TF_ShapeAndType& input_shape = input_shapes->items[i];
664     if (input_shape.num_dims == InferenceContext::kUnknownRank) {
665       c.SetInput(i, c.UnknownShape());
666       continue;
667     }
668     dims.reserve(input_shape.num_dims);
669     for (int j = 0; j < input_shape.num_dims; ++j) {
670       dims.push_back(c.MakeDim(input_shape.dims[j]));
671     }
672     c.SetInput(i, c.MakeShape(dims));
673   }
674 
675   // TODO(bgogul): Handle input_tensors_as_shapes.
676   // TODO(bgogul): Handle input_resource_shapes_and_types.
677 
678   status->status = c.construction_status();
679   if (!status->status.ok()) return;
680 
681   if (op_reg_data->shape_inference_fn == nullptr) {
682     status->status =
683         InvalidArgument("No shape inference function exists for op '",
684                         node_def.op(), "', did you forget to define it?");
685     return;
686   }
687 
688   status->status = c.Run(op_reg_data->shape_inference_fn);
689   if (!status->status.ok()) return;
690 
691   // Set output_shapes.
692   TF_ShapeAndTypeList* output_shapes_result =
693       TF_NewShapeAndTypeList(c.num_outputs());
694   for (int i = 0; i < c.num_outputs(); ++i) {
695     ShapeHandle shape_handle = c.output(i);
696     TF_ShapeAndType& shape = output_shapes_result->items[i];
697     shape.num_dims = c.Rank(shape_handle);
698     if (shape.num_dims == InferenceContext::kUnknownRank) {
699       shape.dims = nullptr;
700       continue;
701     }
702     shape.dims = new int64_t[shape.num_dims];
703     for (size_t j = 0; j < shape.num_dims; ++j) {
704       shape.dims[j] = c.Value(c.Dim(shape_handle, j));
705     }
706   }
707   if (output_shapes != nullptr) *output_shapes = output_shapes_result;
708 
709   // TODO(bgogul): Set output_resource_shapes_and_types.
710 }
711 
TF_ImportGraphDefOptionsSetValidateColocationConstraints(TF_ImportGraphDefOptions * opts,unsigned char enable)712 void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
713     TF_ImportGraphDefOptions* opts, unsigned char enable) {
714   opts->opts.validate_colocation_constraints = enable;
715 }
716 
717 // Load a Pluggable Device library.
718 // On success, returns the handle to library in result and return OK from the
719 // function. Otherwise return nullptr in result and error Status from the
720 // function.
721 //
722 // If `library_filename` has already been loaded, we return a cached handle.
723 // Device and Kernels/Ops are registered as globals when a library is loaded
724 // for the first time.
TF_LoadPluggableDeviceLibrary(const char * library_filename,TF_Status * status)725 TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
726                                           TF_Status* status) {
727 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
728   status->status = tensorflow::errors::Unimplemented(
729       "PluggableDevice plugin functionality is not supported on mobile");
730   return nullptr;
731 #else
732   TF_Library* lib_handle = new TF_Library;
733   static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
734   static std::unordered_map<std::string, void*>* loaded_libs =
735       new std::unordered_map<std::string, void*>();
736   tensorflow::Env* env = tensorflow::Env::Default();
737   {
738     tensorflow::mutex_lock lock(mu);
739     auto it = loaded_libs->find(library_filename);
740     if (it != loaded_libs->end()) {
741       lib_handle->lib_handle = it->second;
742     } else {
743       status->status =
744           env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
745       if (status->status.ok()) {
746         TF_CHECK_OK(
747             tensorflow::RegisterPluggableDevicePlugin(lib_handle->lib_handle));
748       } else {
749         delete lib_handle;
750         return nullptr;
751       }
752     }
753     return lib_handle;
754   }
755 #endif
756 }
757 
TF_DeletePluggableDeviceLibraryHandle(TF_Library * lib_handle)758 void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
759   delete lib_handle;
760 }
761