• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/c_api_experimental.h"
17 
18 #include "absl/strings/substitute.h"
19 #include "tensorflow/c/c_api.h"
20 #include "tensorflow/c/c_api_internal.h"
21 #include "tensorflow/c/checkpoint_reader.h"
22 #include "tensorflow/c/eager/c_api.h"
23 #include "tensorflow/c/eager/c_api_internal.h"
24 #include "tensorflow/compiler/jit/flags.h"
25 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
26 #include "tensorflow/core/common_runtime/eager/context.h"
27 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/framework/shape_inference.h"
30 #include "tensorflow/core/framework/tensor.pb.h"
31 #include "tensorflow/core/graph/graph.h"
32 #include "tensorflow/core/graph/node_builder.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 #include "tensorflow/core/platform/init_main.h"
35 #include "tensorflow/core/platform/net.h"
36 #include "tensorflow/core/platform/platform.h"
37 #include "tensorflow/core/protobuf/config.pb.h"
38 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
39 
40 using tensorflow::FunctionDef;
41 using tensorflow::Node;
42 using tensorflow::NodeBuilder;
43 using tensorflow::Status;
44 using tensorflow::errors::InvalidArgument;
45 
46 namespace {
47 typedef std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>
48     UniqueFuncPtr;
49 }
50 
51 // struct TF_Operation { tensorflow::Node node; };
ToTF_Operation(Node * node)52 static TF_Operation* ToTF_Operation(Node* node) {
53   return static_cast<TF_Operation*>(static_cast<void*>(node));
54 }
55 
TF_EnableXLACompilation(TF_SessionOptions * options,unsigned char enable)56 void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) {
57   tensorflow::ConfigProto& config = options->options.config;
58   auto* optimizer_options =
59       config.mutable_graph_options()->mutable_optimizer_options();
60   if (enable) {
61     optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);
62 
63     // These XLA flags are needed to trigger XLA properly from C (more generally
64     // non-Python) clients. If this API is called again with `enable` set to
65     // false, it is safe to keep these flag values as is.
66     tensorflow::MarkForCompilationPassFlags* flags =
67         tensorflow::GetMarkForCompilationPassFlags();
68     flags->tf_xla_cpu_global_jit = true;
69     flags->tf_xla_min_cluster_size = 1;
70   } else {
71     optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
72   }
73 }
74 
TF_SetXlaEnableLazyCompilation(unsigned char enable)75 unsigned char TF_SetXlaEnableLazyCompilation(unsigned char enable) {
76   tensorflow::BuildXlaOpsPassFlags* flags =
77       tensorflow::GetBuildXlaOpsPassFlags();
78   bool original = flags->tf_xla_enable_lazy_compilation;
79   flags->tf_xla_enable_lazy_compilation = enable;
80   return original;
81 }
82 
TF_SetTfXlaCpuGlobalJit(unsigned char enable)83 unsigned char TF_SetTfXlaCpuGlobalJit(unsigned char enable) {
84   tensorflow::MarkForCompilationPassFlags* flags =
85       tensorflow::GetMarkForCompilationPassFlags();
86   bool original = flags->tf_xla_cpu_global_jit;
87   flags->tf_xla_cpu_global_jit = static_cast<bool>(enable);
88   return static_cast<unsigned char>(original);
89 }
90 
TF_SetXlaAutoJitMode(const char * mode)91 void TF_SetXlaAutoJitMode(const char* mode) {
92   tensorflow::SetXlaAutoJitFlagFromFlagString(mode);
93 }
94 
TF_GetXlaConstantFoldingDisabled()95 unsigned char TF_GetXlaConstantFoldingDisabled() {
96   return static_cast<unsigned char>(
97       tensorflow::GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding);
98 }
99 
TF_SetXlaConstantFoldingDisabled(unsigned char should_enable)100 void TF_SetXlaConstantFoldingDisabled(unsigned char should_enable) {
101   tensorflow::GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding =
102       static_cast<bool>(should_enable);
103 }
104 
TF_SetXlaMinClusterSize(int size)105 void TF_SetXlaMinClusterSize(int size) {
106   tensorflow::MarkForCompilationPassFlags* flags =
107       tensorflow::GetMarkForCompilationPassFlags();
108   flags->tf_xla_min_cluster_size = size;
109 }
110 
TF_CreateConfig(unsigned char enable_xla_compilation,unsigned char gpu_memory_allow_growth,unsigned int num_cpu_devices)111 TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
112                            unsigned char gpu_memory_allow_growth,
113                            unsigned int num_cpu_devices) {
114   tensorflow::ConfigProto config;
115   auto* optimizer_options =
116       config.mutable_graph_options()->mutable_optimizer_options();
117   if (enable_xla_compilation) {
118     optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);
119 
120     // These XLA flags are needed to trigger XLA properly from C (more generally
121     // non-Python) clients. If this API is called again with `enable` set to
122     // false, it is safe to keep these flag values as is.
123     tensorflow::MarkForCompilationPassFlags* flags =
124         tensorflow::GetMarkForCompilationPassFlags();
125     flags->tf_xla_cpu_global_jit = true;
126     flags->tf_xla_min_cluster_size = 1;
127   } else {
128     optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
129   }
130 
131   auto* gpu_options = config.mutable_gpu_options();
132   gpu_options->set_allow_growth(gpu_memory_allow_growth);
133 
134   (*config.mutable_device_count())["CPU"] = num_cpu_devices;
135 
136   // TODO(b/113217601): This is needed for EagerContext::runner_ to use a
137   // threadpool, so that we avoid the possibility of running the runner_ in the
138   // threadpool of GPU event mgr, as that can trigger more callbacks to be
139   // scheduled on that same threadpool, causing a deadlock in cases where the
140   // caller of event_mgr->ThenExecute() blocks on the completion of the callback
141   // (as in the case of ConstOp kernel creation on GPU, which involves copying a
142   // CPU tensor to GPU).
143   // Setting a larger thread pool does not help with the Swift caller, as we use
144   // a different TFE context for each thread of execution (for running graph
145   // functions, and their send/recvs corountines).
146   config.set_inter_op_parallelism_threads(1);
147 
148   TF_Buffer* ret = TF_NewBuffer();
149   TF_CHECK_OK(MessageToBuffer(config, ret));
150   return ret;
151 }
152 
TF_CreateRunOptions(unsigned char enable_full_trace)153 TF_Buffer* TF_CreateRunOptions(unsigned char enable_full_trace) {
154   tensorflow::RunOptions options;
155   if (enable_full_trace) {
156     options.set_trace_level(tensorflow::RunOptions::FULL_TRACE);
157   } else {
158     options.set_trace_level(tensorflow::RunOptions::NO_TRACE);
159   }
160   TF_Buffer* ret = TF_NewBuffer();
161   TF_CHECK_OK(MessageToBuffer(options, ret));
162   return ret;
163 }
164 
TF_GraphDebugString(TF_Graph * graph,size_t * len)165 const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) {
166   tensorflow::mutex_lock c(graph->mu);
167   const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString();
168   *len = debug_str.size();
169   char* ret = static_cast<char*>(malloc(*len + 1));
170   memcpy(ret, debug_str.c_str(), *len + 1);
171   return ret;
172 }
173 
TF_FunctionDebugString(TF_Function * func,size_t * len)174 char* TF_FunctionDebugString(TF_Function* func, size_t* len) {
175   const auto& debug_str = DebugString(func->fdef);
176   *len = debug_str.size();
177   char* ret = static_cast<char*>(malloc(*len + 1));
178   memcpy(ret, debug_str.c_str(), *len + 1);
179   return ret;
180 }
181 
182 // On success, returns a set of TF_Function instances from `text_proto` of
183 // GraphDef type. These functions must be deleted by calling TF_DeleteFunction.
184 //
185 // If `mutate_proto_func` is non-NULL, run it over each FunctionDef proto,
186 // before creating a TF_Function out of the possibly mutated proto.
CreateFunctionsFromTextProto(const char * text_proto,std::function<void (FunctionDef *)> * mutate_proto_func,TF_Status * status)187 static std::vector<UniqueFuncPtr> CreateFunctionsFromTextProto(
188     const char* text_proto,
189     std::function<void(FunctionDef*)>* mutate_proto_func, TF_Status* status) {
190   tensorflow::GraphDef gdef;
191   if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto, &gdef)) {
192     status->status = tensorflow::errors::Internal(
193         "Invalid text proto for GraphDef: ", text_proto);
194     return {};
195   }
196   const auto& fdef_lib = gdef.library();
197   if (fdef_lib.gradient_size() > 0) {
198     status->status = tensorflow::errors::Internal(
199         "GradientDef is not supported in reading Dataset related functions: ",
200         text_proto);
201     return {};
202   }
203   std::vector<UniqueFuncPtr> ret;
204   for (const FunctionDef& fdef : fdef_lib.function()) {
205     // Make a copy so that we can mutate it.
206     FunctionDef fdef_to_load = fdef;
207     if (mutate_proto_func) {
208       (*mutate_proto_func)(&fdef_to_load);
209     }
210     VLOG(1) << "Adding func to graph: " << fdef_to_load.DebugString();
211     std::vector<char> binary_proto_buf(fdef_to_load.ByteSizeLong());
212     fdef_to_load.SerializeToArray(binary_proto_buf.data(),
213                                   binary_proto_buf.size());
214     TF_Function* func = TF_FunctionImportFunctionDef(
215         binary_proto_buf.data(), binary_proto_buf.size(), status);
216     if (!status->status.ok()) return {};
217     ret.push_back(UniqueFuncPtr(func, TF_DeleteFunction));
218   }
219   return ret;
220 }
221 
TF_DequeueNamedTensor(TF_Session * session,int tensor_id,TF_Status * status)222 TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id,
223                                  TF_Status* status) {
224   assert(session);
225   {
226     tensorflow::mutex_lock c(session->graph->mu);
227     VLOG(1) << "Dequeuing named tensor with id " << tensor_id
228             << ", with input graph: "
229             << session->graph->graph.ToGraphDefDebug().DebugString();
230   }
231 
232   TF_Operation* dequeue_op = TF_GraphOperationByName(
233       session->graph,
234       tensorflow::strings::StrCat("fifo_queue_dequeue_", tensor_id).c_str());
235   if (dequeue_op == nullptr) {
236     status->status = tensorflow::errors::Internal(
237         "Unable to find the dequeue node in the TF graph.");
238     return nullptr;
239   }
240 
241   VLOG(1) << "Running the dequeue op";
242   TF_Output output{dequeue_op, 0};
243   TF_Tensor* ret;
244   TF_SessionRun(session, /*run_options*/ nullptr,
245                 // input related parameters
246                 /*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0,
247                 // output related parameters
248                 /*outputs*/ &output, /*output_values*/ &ret,
249                 /*noutputs*/ 1,
250                 /*targets*/ nullptr, /*ntargets*/ 0,
251                 /*run_metadata*/ nullptr, status);
252   if (VLOG_IS_ON(1) && status->status.ok()) {
253     tensorflow::Tensor tensor;
254     if (tensorflow::TF_TensorToTensor(ret, &tensor).ok()) {
255       VLOG(1) << "Dequeued tensor content: " << tensor.DebugString();
256     }
257   }
258   return ret;
259 }
260 
TF_EnqueueNamedTensor(TF_Session * session,int tensor_id,TF_Tensor * tensor,TF_Status * status)261 void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
262                            TF_Tensor* tensor, TF_Status* status) {
263   assert(session);
264   {
265     tensorflow::mutex_lock c(session->graph->mu);
266     if (VLOG_IS_ON(1)) {
267       VLOG(1) << "Enqueuing named tensor with id " << tensor_id
268               << ", with input graph: "
269               << session->graph->graph.ToGraphDefDebug().DebugString();
270       tensorflow::Tensor internal_tensor;
271       if (tensorflow::TF_TensorToTensor(tensor, &internal_tensor).ok()) {
272         VLOG(1) << "Enqueu'ing tensor content: "
273                 << internal_tensor.DebugString();
274       }
275     }
276   }
277 
278   TF_Operation* enqueue_op = TF_GraphOperationByName(
279       session->graph,
280       tensorflow::strings::StrCat("fifo_queue_enqueue_", tensor_id).c_str());
281   if (enqueue_op == nullptr) {
282     status->status = tensorflow::errors::Internal(
283         "Unable to find the enqueue node in the TF graph.");
284     return;
285   }
286 
287   TF_Operation* placeholder_op = TF_GraphOperationByName(
288       session->graph,
289       tensorflow::strings::StrCat("arg_tensor_enqueue_", tensor_id).c_str());
290   if (placeholder_op == nullptr) {
291     status->status = tensorflow::errors::Internal(
292         "Unable to find the placeholder node as input to enqueue in the TF "
293         "graph.");
294     return;
295   }
296 
297   VLOG(1) << "Running the enqueue op";
298   TF_Output input{placeholder_op, 0};
299   TF_SessionRun(session, /*run_options*/ nullptr,
300                 // input related parameters
301                 /*inputs*/ &input, /*input_values*/ &tensor, /*ninputs*/ 1,
302                 // output related parameters
303                 /*outputs*/ nullptr, /*output_values*/ nullptr, /*noutputs*/ 0,
304                 /*targets*/ &enqueue_op, /*ntargets*/ 1,
305                 /*run_metadata*/ nullptr, status);
306   VLOG(1) << "Enqueuing is done.";
307 }
308 
TFE_GetServerDef(const char * text_proto,TF_Status * status)309 TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) {
310   tensorflow::ServerDef server_def;
311   if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto,
312                                                          &server_def)) {
313     status->status = tensorflow::errors::Internal(
314         "Invalid text proto for ServerDef: ", text_proto);
315     return nullptr;
316   }
317   status->status = tensorflow::Status();
318   TF_Buffer* ret = TF_NewBuffer();
319   TF_CHECK_OK(MessageToBuffer(server_def, ret));
320   return ret;
321 }
322 
TFE_CreateContextFromSession(TF_Session * session,TF_Status * status)323 TFE_Context* TFE_CreateContextFromSession(TF_Session* session,
324                                           TF_Status* status) {
325   auto* opts = TFE_NewContextOptions();
326 
327   // Reduce GPU memory allocation, and set appropriate config options for TFE
328   // context.
329   auto* config = TF_CreateConfig(
330       /*xla*/ false, /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
331       10);
332   TFE_ContextOptionsSetConfig(opts, config->data, config->length, status);
333   if (!status->status.ok()) {
334     CHECK(!config);
335     TFE_DeleteContextOptions(opts);
336     return nullptr;
337   }
338 
339   auto* ctx = TFE_NewContextFromSession(opts, session, status);
340   TF_DeleteBuffer(config);
341   TFE_DeleteContextOptions(opts);
342   return ctx;
343 }
344 
// TODO: retrieve the device string via TFE_ContextListDevices()
// Device on which the FIFO queue / enqueue / dequeue eager ops created in this
// file are placed.
static constexpr char DEFAULT_CPU_DEVICE[] =
    "/job:localhost/replica:0/task:0/device:CPU:0";
348 
createTFEQueue(TFE_Context * ctx,TF_DataType inputType,int tensor_id,TF_Status * status)349 static TFE_TensorHandle* createTFEQueue(TFE_Context* ctx, TF_DataType inputType,
350                                         int tensor_id, TF_Status* status) {
351   std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> queueOp(
352       TFE_NewOp(ctx, "FIFOQueueV2", status), TFE_DeleteOp);
353   TFE_OpSetDevice(queueOp.get(), DEFAULT_CPU_DEVICE, status);
354   if (!status->status.ok()) return nullptr;
355   // TODO: use NAMED_TENSOR_QUEUE_CAPACITY in S4TF compiler.
356   TFE_OpSetAttrInt(queueOp.get(), "capacity", 1);
357   TFE_OpSetAttrTypeList(queueOp.get(), "component_types", &inputType, 1);
358   auto shared_name = tensorflow::strings::StrCat("fifo_queue_", tensor_id);
359   TFE_OpSetAttrString(queueOp.get(), "shared_name", shared_name.data(),
360                       shared_name.size());
361   TFE_OpSetAttrString(queueOp.get(), "container", "", 0);
362 
363   // TODO: consider making this an unknown shape.
364   const int64_t* dims_ptr = nullptr;
365   int num_dims = 0;
366   TFE_OpSetAttrShapeList(queueOp.get(), "shapes", &dims_ptr, &num_dims,
367                          /*num_values*/ 0, status);
368   if (!status->status.ok()) return nullptr;
369 
370   int num_retvals = 1;
371   TFE_TensorHandle* queue = nullptr;
372   TFE_Execute(queueOp.get(), &queue, &num_retvals, status);
373   if (!status->status.ok()) return nullptr;
374   CHECK_EQ(num_retvals, 1);
375 
376   return queue;
377 }
378 
createTFEEnqueue(TFE_Context * ctx,TF_DataType inputType,TFE_TensorHandle * queue,TFE_TensorHandle * tensor,TF_Status * status)379 static void createTFEEnqueue(TFE_Context* ctx, TF_DataType inputType,
380                              TFE_TensorHandle* queue, TFE_TensorHandle* tensor,
381                              TF_Status* status) {
382   TFE_Op* op = TFE_NewOp(ctx, "QueueEnqueueV2", status);
383   if (!status->status.ok()) return;
384   std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
385   TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
386   if (!status->status.ok()) return;
387   TFE_OpAddInput(op, queue, status);
388   if (!status->status.ok()) return;
389   TFE_OpAddInput(op, tensor, status);
390   if (!status->status.ok()) return;
391   TFE_OpSetAttrTypeList(op, "Tcomponents", &inputType, 1);
392   TFE_OpSetAttrInt(op, "timeout_ms", -1);
393 
394   int num_retvals = 0;
395   TFE_Execute(op, nullptr /*retvals*/, &num_retvals, status);
396   if (!status->status.ok()) return;
397   CHECK_EQ(num_retvals, 0);
398 }
399 
createTFEDequeue(TFE_Context * ctx,TF_DataType inputType,TFE_TensorHandle * queue,TF_Status * status)400 static TFE_TensorHandle* createTFEDequeue(TFE_Context* ctx,
401                                           TF_DataType inputType,
402                                           TFE_TensorHandle* queue,
403                                           TF_Status* status) {
404   TFE_Op* op = TFE_NewOp(ctx, "QueueDequeueV2", status);
405   if (!status->status.ok()) return nullptr;
406   std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
407   TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
408   if (!status->status.ok()) return nullptr;
409 
410   TFE_OpAddInput(op, queue, status);
411   if (!status->status.ok()) return nullptr;
412   TFE_OpSetAttrTypeList(op, "component_types", &inputType, 1);
413   TFE_OpSetAttrInt(op, "timeout_ms", -1);
414   TFE_TensorHandle* ret;
415   int num_retvals = 1;
416   TFE_Execute(op, &ret, &num_retvals, status);
417   if (!status->status.ok()) return nullptr;
418   CHECK_EQ(num_retvals, 1);
419   return ret;
420 }
421 
TFE_DequeueNamedTensor(TF_Session * session,int tensor_id,TF_DataType inputType,TF_Status * status)422 TFE_TensorHandle* TFE_DequeueNamedTensor(TF_Session* session, int tensor_id,
423                                          TF_DataType inputType,
424                                          TF_Status* status) {
425   assert(session);
426   VLOG(1) << "Dequeuing data tensor with id " << tensor_id;
427 
428   auto ctx = TFE_CreateContextFromSession(session, status);
429   if (!status->status.ok()) return nullptr;
430   std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
431       ctx, TFE_DeleteContext);
432 
433   TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
434   if (!status->status.ok()) return nullptr;
435   std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
436       queue_deleter(queue, TFE_DeleteTensorHandle);
437 
438   auto* ret = createTFEDequeue(ctx, inputType, queue, status);
439   return ret;
440 }
441 
TFE_DequeueNamedTensorFromCtx(TFE_Context * ctx,int tensor_id,TF_DataType inputType,TF_Status * status)442 TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
443                                                 TF_DataType inputType,
444                                                 TF_Status* status) {
445   TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
446   if (!status->status.ok()) return nullptr;
447   std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
448       queue_deleter(queue, TFE_DeleteTensorHandle);
449 
450   auto* ret = createTFEDequeue(ctx, inputType, queue, status);
451 
452   return ret;
453 }
454 
TFE_EnqueueNamedTensor(TF_Session * session,int tensor_id,TFE_TensorHandle * tensor,TF_Status * status)455 void TFE_EnqueueNamedTensor(TF_Session* session, int tensor_id,
456                             TFE_TensorHandle* tensor, TF_Status* status) {
457   assert(session);
458   VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
459 
460   auto ctx = TFE_CreateContextFromSession(session, status);
461   if (!status->status.ok()) return;
462   std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
463       ctx, TFE_DeleteContext);
464 
465   TF_DataType inputType = TFE_TensorHandleDataType(tensor);
466   TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
467   if (!status->status.ok()) return;
468   std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
469       queue_deleter(queue, TFE_DeleteTensorHandle);
470 
471   createTFEEnqueue(ctx, inputType, queue, tensor, status);
472 }
473 
TFE_EnqueueNamedTensorFromCtx(TFE_Context * ctx,int tensor_id,TFE_TensorHandle * tensor,TF_Status * status)474 void TFE_EnqueueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
475                                    TFE_TensorHandle* tensor,
476                                    TF_Status* status) {
477   VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
478 
479   TF_DataType inputType = TFE_TensorHandleDataType(tensor);
480   TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
481   if (!status->status.ok()) return;
482   std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
483       queue_deleter(queue, TFE_DeleteTensorHandle);
484 
485   createTFEEnqueue(ctx, inputType, queue, tensor, status);
486 }
487 
TFE_EnqueueVariantTensor(TF_Session * session,int tensor_id,TFE_TensorHandle * tensor,TF_Status * status)488 void TFE_EnqueueVariantTensor(TF_Session* session, int tensor_id,
489                               TFE_TensorHandle* tensor, TF_Status* status) {
490   VLOG(1) << "Enqueuing variant tensor with id " << tensor_id;
491 
492   auto ctx = TFE_CreateContextFromSession(session, status);
493   if (!status->status.ok()) return;
494   std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
495       ctx, TFE_DeleteContext);
496 
497   TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
498   if (!status->status.ok()) return;
499   std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
500       queue_deleter(queue, TFE_DeleteTensorHandle);
501 
502   createTFEEnqueue(ctx, TF_VARIANT, queue, tensor, status);
503 }
504 
TFE_DequeueVariantTensor(TF_Session * session,int tensor_id,TF_Status * status)505 TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
506                                            TF_Status* status) {
507   VLOG(1) << "Dequeuing variant tensor with id " << tensor_id;
508 
509   auto ctx = TFE_CreateContextFromSession(session, status);
510   if (!status->status.ok()) return nullptr;
511   std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
512       ctx, TFE_DeleteContext);
513 
514   TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
515   if (!status->status.ok()) return nullptr;
516   std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
517       queue_deleter(queue, TFE_DeleteTensorHandle);
518 
519   return createTFEDequeue(ctx, TF_VARIANT, queue, status);
520 }
521 
TFE_TensorHandlePrintDebugString(TFE_TensorHandle * handle)522 void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) {
523   auto* status = TF_NewStatus();
524   TF_Tensor* t = TFE_TensorHandleResolve(handle, status);
525   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
526 
527   tensorflow::Tensor dst;
528   TF_CHECK_OK(TF_TensorToTensor(t, &dst));
529   LOG(INFO) << dst.DebugString();
530 
531   TF_DeleteTensor(t);
532   TF_DeleteStatus(status);
533 }
534 
TFE_OpPrintDebugString(TFE_Op * op)535 void TFE_OpPrintDebugString(TFE_Op* op) {
536   VLOG(1) << "TFE_OpPrintDebugString() over " << op;
537   LOG(INFO) << op->operation.DebugString();
538 }
539 
540 struct TFE_ExecuteOpNotification {
TFE_ExecuteOpNotificationTFE_ExecuteOpNotification541   TFE_ExecuteOpNotification() : status(TF_NewStatus(), TF_DeleteStatus) {}
542   tensorflow::Notification n;
543   std::unique_ptr<tensorflow::Thread> thread;
544   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status;
545 };
546 
TFE_ExecuteOpInNewThread(TFE_Op * op,TFE_TensorHandle ** retvals,int * num_retvals,TF_Status * status)547 TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op,
548                                                     TFE_TensorHandle** retvals,
549                                                     int* num_retvals,
550                                                     TF_Status* status) {
551   TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification;
552 
553   n->thread.reset(op->operation.EagerContext().TFEnv()->StartThread(
554       tensorflow::ThreadOptions(), "ExecuteOpThread",
555       [op, retvals, num_retvals, n]() {
556         TFE_Execute(op, retvals, num_retvals, n->status.get());
557         n->n.Notify();
558       }));
559 
560   return n;
561 }
562 
TFE_ExecuteOpNotificationWaitAndDelete(TFE_ExecuteOpNotification * notification,TF_Status * status)563 void TFE_ExecuteOpNotificationWaitAndDelete(
564     TFE_ExecuteOpNotification* notification, TF_Status* status) {
565   if (notification == nullptr) {
566     status->status = tensorflow::errors::InvalidArgument(
567         "Passed in notification is a nullptr.");
568 
569     return;
570   }
571   if (notification->thread == nullptr) {
572     status->status = tensorflow::errors::InvalidArgument(
573         "Passed in notification didn't start a thread correctly. Cleaning up "
574         "this notification. Please re-execute the operation to get a new "
575         "notification.");
576 
577     delete notification;
578     return;
579   }
580 
581   notification->n.WaitForNotification();
582 
583   status->status = notification->status->status;
584 
585   delete notification;
586 }
587 
TF_MakeInternalErrorStatus(TF_Status * status,const char * errMsg)588 void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
589   status->status = tensorflow::errors::Internal(errMsg);
590 }
591 
592 struct TF_CheckpointReader : public tensorflow::checkpoint::CheckpointReader {
593   using tensorflow::checkpoint::CheckpointReader::CheckpointReader;
594   std::vector<std::string> variable_list;
595 };
596 
TF_NewCheckpointReader(const char * filename,TF_Status * status)597 TF_CheckpointReader* TF_NewCheckpointReader(const char* filename,
598                                             TF_Status* status) {
599   TF_CheckpointReader* reader = new TF_CheckpointReader(filename, status);
600   if (!status->status.ok()) {
601     TF_DeleteCheckpointReader(reader);
602     return nullptr;
603   }
604   const auto& m = reader->GetVariableToDataTypeMap();
605   for (auto it = m.begin(); it != m.end(); ++it)
606     reader->variable_list.push_back(it->first);
607   std::sort(reader->variable_list.begin(), reader->variable_list.end());
608   return reader;
609 }
610 
TF_DeleteCheckpointReader(TF_CheckpointReader * reader)611 void TF_DeleteCheckpointReader(TF_CheckpointReader* reader) { delete reader; }
612 
TF_CheckpointReaderHasTensor(TF_CheckpointReader * reader,const char * name)613 int TF_CheckpointReaderHasTensor(TF_CheckpointReader* reader,
614                                  const char* name) {
615   return reader->HasTensor(name);
616 }
617 
TF_CheckpointReaderGetVariable(TF_CheckpointReader * reader,int index)618 const char* TF_CheckpointReaderGetVariable(TF_CheckpointReader* reader,
619                                            int index) {
620   return reader->variable_list[index].c_str();
621 }
622 
TF_CheckpointReaderSize(TF_CheckpointReader * reader)623 int TF_CheckpointReaderSize(TF_CheckpointReader* reader) {
624   return reader->variable_list.size();
625 }
626 
TF_CheckpointReaderGetVariableDataType(TF_CheckpointReader * reader,const char * name)627 TF_DataType TF_CheckpointReaderGetVariableDataType(TF_CheckpointReader* reader,
628                                                    const char* name) {
629   const auto& m = reader->GetVariableToDataTypeMap();
630   return static_cast<TF_DataType>(m.at(name));
631 }
632 
TF_CheckpointReaderGetTensor(TF_CheckpointReader * reader,const char * name,TF_Status * status)633 TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
634                                         const char* name, TF_Status* status) {
635   std::unique_ptr<tensorflow::Tensor> tensor;
636   reader->GetTensor(name, &tensor, status);
637   if (!status->status.ok()) return nullptr;
638   return tensorflow::TF_TensorFromTensor(*tensor, &status->status);
639 }
640 
TF_CheckpointReaderGetVariableShape(TF_CheckpointReader * reader,const char * name,int64_t * dims,int num_dims,TF_Status * status)641 void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
642                                          const char* name, int64_t* dims,
643                                          int num_dims, TF_Status* status) {
644   const auto& shape = reader->GetVariableToShapeMap().at(name);
645   int rank = shape.dims();
646   if (num_dims != rank) {
647     status->status = InvalidArgument("Expected rank is ", num_dims,
648                                      " but actual rank is ", rank);
649     return;
650   }
651   for (int i = 0; i < num_dims; i++) {
652     dims[i] = shape.dim_size(i);
653   }
654 }
655 
TF_CheckpointReaderGetVariableNumDims(TF_CheckpointReader * reader,const char * name)656 int TF_CheckpointReaderGetVariableNumDims(TF_CheckpointReader* reader,
657                                           const char* name) {
658   const auto& m = reader->GetVariableToShapeMap();
659   return m.at(name).dims();
660 }
661 
662 // This builder is used in the eager API to build a NodeDef.
663 struct TF_AttrBuilder : public tensorflow::AttrBuilder {
664   using tensorflow::AttrBuilder::AttrBuilder;
665   // The string buffers to make sure that any `attr_name` we pass into
666   // `builder->Set()` will outlive the subsequent
667   // `TF_AttrBuilderCheckCanRunOnDevice()` call(s) on the same `builder`.
668   std::set<std::string> attr_names;
669 };
670 
TF_NewAttrBuilder(const char * op_name)671 TF_AttrBuilder* TF_NewAttrBuilder(const char* op_name) {
672   return new TF_AttrBuilder(op_name);
673 }
674 
TF_DeleteAttrBuilder(TF_AttrBuilder * builder)675 void TF_DeleteAttrBuilder(TF_AttrBuilder* builder) { delete builder; }
676 
TF_AttrBuilderSetType(TF_AttrBuilder * builder,const char * attr_name,TF_DataType value)677 void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name,
678                            TF_DataType value) {
679   auto iter = builder->attr_names.insert(attr_name).first;
680   builder->Set(*iter, static_cast<tensorflow::DataType>(value));
681 }
682 
TF_AttrBuilderSetTypeList(TF_AttrBuilder * builder,const char * attr_name,const TF_DataType * values,int num_values)683 void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name,
684                                const TF_DataType* values, int num_values) {
685   auto iter = builder->attr_names.insert(attr_name).first;
686   builder->Set(
687       (*iter).c_str(),
688       tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
689           reinterpret_cast<const tensorflow::DataType*>(values), num_values));
690 }
691 
TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder * builder,const char * device_type,TF_Status * status)692 void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder,
693                                        const char* device_type,
694                                        TF_Status* status) {
695   status->status = tensorflow::FindKernelDef(
696       tensorflow::DeviceType(device_type), builder->BuildNodeDef(),
697       /* def = */ nullptr, /* kernel_class_name = */ nullptr);
698 }
699 
TF_GetNumberAttrForOpListInput(const char * op_name,int input_index,TF_Status * status)700 const char* TF_GetNumberAttrForOpListInput(const char* op_name, int input_index,
701                                            TF_Status* status) {
702   const tensorflow::OpDef* op_def = nullptr;
703   status->status =
704       tensorflow::OpRegistry::Global()->LookUpOpDef(op_name, &op_def);
705   if (!status->status.ok()) return nullptr;
706 
707   if (input_index >= op_def->input_arg_size() || input_index < 0) {
708     status->status = tensorflow::errors::InvalidArgument(
709         input_index, " out of range for ", op_name);
710     return nullptr;
711   }
712 
713   const tensorflow::OpDef_ArgDef& input_arg = op_def->input_arg()[input_index];
714 
715   if (input_arg.number_attr().empty()) {
716     status->status = tensorflow::errors::NotFound(
717         op_name, " does not have number_attr() defined.");
718     return nullptr;
719   }
720 
721   // The returned string is owned by OpRegistry, so liveness is not a concern.
722   return input_arg.number_attr().c_str();
723 }
724 
TF_OpIsStateful(const char * op_type,TF_Status * status)725 int TF_OpIsStateful(const char* op_type, TF_Status* status) {
726   const tensorflow::OpRegistrationData* op_reg_data;
727   status->status =
728       tensorflow::OpRegistry::Global()->LookUp(op_type, &op_reg_data);
729   if (!status->status.ok()) {
730     return 0;
731   }
732   return op_reg_data->op_def.is_stateful();
733 }
734 
TF_InitMain(const char * usage,int * argc,char *** argv)735 void TF_InitMain(const char* usage, int* argc, char*** argv) {
736   tensorflow::port::InitMain(usage, argc, argv);
737 }
738 
TF_PickUnusedPortOrDie()739 int TF_PickUnusedPortOrDie() {
740   return tensorflow::internal::PickUnusedPortOrDie();
741 }
742 
TFE_NewTensorHandleFromScalar(TF_DataType data_type,void * data,size_t len,TF_Status * status)743 TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
744                                                 void* data, size_t len,
745                                                 TF_Status* status) {
746   auto dtype = static_cast<tensorflow::DataType>(data_type);
747   DCHECK(tensorflow::DataTypeCanUseMemcpy(dtype));
748 
749   tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({}));
750   std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);
751   return TFE_TensorHandle::CreateLocalHandle(tensor, status);
752 }
753 
754 namespace {
EnableCollectiveOps(const tensorflow::ServerDef & server_def,TFE_Context * ctx)755 tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
756                                        TFE_Context* ctx) {
757   // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
758   // server object (which currently CHECK-fails) and we miss the error, instead,
759   // we log the error, and then return to allow the user to see the error
760   // message.
761 #define LOG_AND_RETURN_IF_ERROR(...)                    \
762   do {                                                  \
763     const ::tensorflow::Status _status = (__VA_ARGS__); \
764     if (TF_PREDICT_FALSE(!_status.ok())) {              \
765       LOG(ERROR) << _status.error_message();            \
766       return _status;                                   \
767     }                                                   \
768   } while (0);
769 
770   // New server created for new server_def. Unused if updating server_def.
771   tensorflow::EagerContext* context = ctx->context;
772   tensorflow::GrpcServer* grpc_server =
773       dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
774   if (grpc_server == nullptr) {
775     std::unique_ptr<tensorflow::ServerInterface> new_server;
776     LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
777     grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
778     if (grpc_server == nullptr) {
779       LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal(
780           "Currently, TFE_NewContext only supports tensorflow::GrpcServer."));
781     }
782     LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
783 
784     LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
785         std::move(new_server), grpc_server->worker_env()->device_mgr,
786         grpc_server->worker_env()->collective_executor_mgr));
787   } else {
788     LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
789     LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
790         /*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
791         grpc_server->worker_env()->collective_executor_mgr));
792   }
793   return tensorflow::Status::OK();
794 #undef LOG_AND_RETURN_IF_ERROR
795 }
796 }  // namespace
797 
798 // Set server_def on the context, possibly updating it.
TFE_EnableCollectiveOps(TFE_Context * ctx,const void * proto,size_t proto_len,TF_Status * status)799 TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
800                                                    const void* proto,
801                                                    size_t proto_len,
802                                                    TF_Status* status) {
803   tensorflow::ServerDef server_def;
804   if (!server_def.ParseFromArray(proto, proto_len)) {
805     status->status = tensorflow::errors::InvalidArgument(
806         "Invalid tensorflow.ServerDef protocol buffer");
807     return;
808   }
809   status->status = EnableCollectiveOps(server_def, ctx);
810 }
811 
TF_NewShapeAndTypeList(int num_items)812 TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
813   TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
814   result->num_items = num_items;
815   result->items = (num_items == 0) ? nullptr : new TF_ShapeAndType[num_items]();
816   return result;
817 }
818 
TF_ShapeAndTypeListSetShape(TF_ShapeAndTypeList * shape_list,int index,const int64_t * dims,int num_dims)819 void TF_ShapeAndTypeListSetShape(TF_ShapeAndTypeList* shape_list, int index,
820                                  const int64_t* dims, int num_dims) {
821   DCHECK(index >= 0 && index < shape_list->num_items);
822   TF_ShapeAndType& shape = shape_list->items[index];
823   DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
824   DCHECK(num_dims >= 0) << "Number of dimensions cannot be negative!";
825   shape.num_dims = num_dims;
826   shape.dims = new int64_t[num_dims];
827   memcpy(shape.dims, dims, sizeof(int64_t) * num_dims);
828 }
829 
TF_ShapeAndTypeListSetUnknownShape(TF_ShapeAndTypeList * shape_list,int index)830 void TF_ShapeAndTypeListSetUnknownShape(TF_ShapeAndTypeList* shape_list,
831                                         int index) {
832   DCHECK(index >= 0 && index < shape_list->num_items);
833   TF_ShapeAndType& shape = shape_list->items[index];
834   DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
835   shape.num_dims = -1;
836   shape.dims = nullptr;
837 }
838 
TF_ShapeAndTypeListSetDtype(TF_ShapeAndTypeList * shape_list,int index,TF_DataType dtype)839 void TF_ShapeAndTypeListSetDtype(TF_ShapeAndTypeList* shape_list, int index,
840                                  TF_DataType dtype) {
841   DCHECK(index >= 0 && index < shape_list->num_items);
842   TF_ShapeAndType& shape_and_type = shape_list->items[index];
843   shape_and_type.dtype = dtype;
844 }
845 
TF_DeleteShapeAndTypeList(TF_ShapeAndTypeList * shape_list)846 void TF_DeleteShapeAndTypeList(TF_ShapeAndTypeList* shape_list) {
847   if (shape_list == nullptr) return;
848   for (size_t i = 0; i < shape_list->num_items; ++i) {
849     delete[] shape_list->items[i].dims;
850   }
851   delete[] shape_list->items;
852   delete shape_list;
853 }
854 
TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList ** shape_list_array,int num_items)855 void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
856                                     int num_items) {
857   if (shape_list_array == nullptr) return;
858   for (int i = 0; i < num_items; ++i) {
859     TF_DeleteShapeAndTypeList(shape_list_array[i]);
860   }
861   delete[] shape_list_array;
862 }
863 
864 namespace tensorflow {
865 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
866 }  // namespace tensorflow
867 
TFE_InferShapes(TFE_Op * tfe_op,TF_ShapeAndTypeList * input_shapes,TF_Tensor ** input_tensors,TF_ShapeAndTypeList * input_tensors_as_shapes,TF_ShapeAndTypeList ** input_resource_shapes_and_types,TF_ShapeAndTypeList ** output_shapes,TF_ShapeAndTypeList *** output_resource_shapes_and_types,TF_Status * status)868 void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
869                      TF_Tensor** input_tensors,
870                      TF_ShapeAndTypeList* input_tensors_as_shapes,
871                      TF_ShapeAndTypeList** input_resource_shapes_and_types,
872                      TF_ShapeAndTypeList** output_shapes,
873                      TF_ShapeAndTypeList*** output_resource_shapes_and_types,
874                      TF_Status* status) {
875   using tensorflow::NodeDef;
876   using tensorflow::OpRegistrationData;
877   using tensorflow::Tensor;
878   using tensorflow::shape_inference::DimensionHandle;
879   using tensorflow::shape_inference::InferenceContext;
880   using tensorflow::shape_inference::ShapeAndType;
881   using tensorflow::shape_inference::ShapeHandle;
882 
883   const int num_inputs = input_shapes->num_items;
884   NodeDef node_def;
885   node_def.set_name(tfe_op->operation.Name());
886   node_def.set_op(tfe_op->operation.Name());
887   for (int i = 0; i < num_inputs; ++i) {
888     node_def.add_input("dummy_input");
889   }
890   tfe_op->operation.Attrs().FillAttrValueMap(node_def.mutable_attr());
891 
892   const tensorflow::OpRegistrationData* op_reg_data;
893   status->status =
894       tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data);
895   if (!status->status.ok()) return;
896 
897   // Initialize a input_tensor vector with `nullptr` values.
898   std::vector<const Tensor*> input_tensors_vector(num_inputs, nullptr);
899   // A vector to keep track of newly created `tf::Tensor` objects.
900   std::vector<Tensor> all_input_tensors;
901   // Update the vector with information from `input_tensors` if provided.
902   if (input_tensors != nullptr) {
903     // Note that we take the address of the elements in `all_input_tensors`
904     // below. Allocate enough space so that no reallocation happens, which will
905     // make the pointers invalid.
906     all_input_tensors.reserve(num_inputs);
907     for (int i = 0; i < num_inputs; ++i) {
908       if (input_tensors[i] == nullptr) continue;
909       all_input_tensors.emplace_back();
910       Tensor& input_tensor = all_input_tensors.back();
911       status->status = TF_TensorToTensor(input_tensors[i], &input_tensor);
912       if (!status->status.ok()) return;
913       input_tensors_vector[i] = &input_tensor;
914     }
915   }
916 
917   // Create an inference context with dummy values, which will be updated later.
918   InferenceContext c(TF_GRAPH_DEF_VERSION, node_def, op_reg_data->op_def,
919                      std::vector<ShapeHandle>(num_inputs), input_tensors_vector,
920                      {},
921                      std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());
922 
923   // Set input_shapes.
924   for (int i = 0; i < num_inputs; ++i) {
925     std::vector<DimensionHandle> dims;
926     const TF_ShapeAndType& input_shape = input_shapes->items[i];
927     if (input_shape.num_dims == InferenceContext::kUnknownRank) {
928       c.SetInput(i, c.UnknownShape());
929       continue;
930     }
931     for (int j = 0; j < input_shape.num_dims; ++j) {
932       dims.push_back(c.MakeDim(input_shape.dims[j]));
933     }
934     c.SetInput(i, c.MakeShape(dims));
935   }
936 
937   // TODO(bgogul): Handle input_tensors_as_shapes.
938   // TODO(bgogul): Handle input_resource_shapes_and_types.
939 
940   status->status = c.construction_status();
941   if (!status->status.ok()) return;
942 
943   if (op_reg_data->shape_inference_fn == nullptr) {
944     status->status =
945         InvalidArgument("No shape inference function exists for op '",
946                         node_def.op(), "', did you forget to define it?");
947     return;
948   }
949 
950   status->status = c.Run(op_reg_data->shape_inference_fn);
951   if (!status->status.ok()) return;
952 
953   // Set output_shapes.
954   TF_ShapeAndTypeList* output_shapes_result =
955       TF_NewShapeAndTypeList(c.num_outputs());
956   for (int i = 0; i < c.num_outputs(); ++i) {
957     ShapeHandle shape_handle = c.output(i);
958     TF_ShapeAndType& shape = output_shapes_result->items[i];
959     shape.num_dims = c.Rank(shape_handle);
960     if (shape.num_dims == InferenceContext::kUnknownRank) {
961       shape.dims = nullptr;
962       continue;
963     }
964     shape.dims = new int64_t[shape.num_dims];
965     for (size_t j = 0; j < shape.num_dims; ++j) {
966       shape.dims[j] = c.Value(c.Dim(shape_handle, j));
967     }
968   }
969   if (output_shapes != nullptr) *output_shapes = output_shapes_result;
970 
971   // TODO(bgogul): Set output_resource_shapes_and_types.
972 }
973 
TF_ImportGraphDefOptionsSetValidateColocationConstraints(TF_ImportGraphDefOptions * opts,unsigned char enable)974 void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
975     TF_ImportGraphDefOptions* opts, unsigned char enable) {
976   opts->opts.validate_colocation_constraints = enable;
977 }
978