/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/c_api_experimental.h"

#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/substitute.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

using tensorflow::FunctionDef;
using tensorflow::Node;
using tensorflow::NodeBuilder;
using tensorflow::Status;
using tensorflow::errors::InvalidArgument;

namespace {
typedef std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>
    UniqueFuncPtr;
}  // namespace

// struct TF_Operation { tensorflow::Node node; };
static TF_Operation* ToTF_Operation(Node* node) {
  return static_cast<TF_Operation*>(static_cast<void*>(node));
}

void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) {
  tensorflow::ConfigProto& config = options->options.config;
  auto* optimizer_options =
      config.mutable_graph_options()->mutable_optimizer_options();
  if (enable) {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);

    // These XLA flags are needed to trigger XLA properly from C (more
    // generally, non-Python) clients. If this API is called again with
    // `enable` set to false, it is safe to keep these flag values as is.
    tensorflow::MarkForCompilationPassFlags* flags =
        tensorflow::GetMarkForCompilationPassFlags();
    flags->tf_xla_cpu_global_jit = true;
    flags->tf_xla_min_cluster_size = 1;
  } else {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
  }
}
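
// Usage sketch for TF_EnableXLACompilation (illustrative only; creating and
// using the session with these options is outside this sketch, and is an
// assumption made for the example):
//
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   TF_EnableXLACompilation(opts, /*enable=*/1);
//   // ... create a TF_Session with `opts` ...
//   TF_DeleteSessionOptions(opts);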

unsigned char TF_SetXlaEnableLazyCompilation(unsigned char enable) {
  tensorflow::BuildXlaOpsPassFlags* flags =
      tensorflow::GetBuildXlaOpsPassFlags();
  bool original = flags->tf_xla_enable_lazy_compilation;
  flags->tf_xla_enable_lazy_compilation = enable;
  return original;
}

unsigned char TF_SetTfXlaCpuGlobalJit(unsigned char enable) {
  tensorflow::MarkForCompilationPassFlags* flags =
      tensorflow::GetMarkForCompilationPassFlags();
  bool original = flags->tf_xla_cpu_global_jit;
  flags->tf_xla_cpu_global_jit = static_cast<bool>(enable);
  return static_cast<unsigned char>(original);
}

void TF_SetXlaAutoJitMode(const char* mode) {
  tensorflow::SetXlaAutoJitFlagFromFlagString(mode);
}

unsigned char TF_GetXlaConstantFoldingDisabled() {
  return static_cast<unsigned char>(
      tensorflow::GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding);
}

void TF_SetXlaConstantFoldingDisabled(unsigned char should_enable) {
  tensorflow::GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding =
      static_cast<bool>(should_enable);
}

void TF_SetXlaMinClusterSize(int size) {
  tensorflow::MarkForCompilationPassFlags* flags =
      tensorflow::GetMarkForCompilationPassFlags();
  flags->tf_xla_min_cluster_size = size;
}

TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
                           unsigned char gpu_memory_allow_growth,
                           unsigned int num_cpu_devices) {
  tensorflow::ConfigProto config;
  auto* optimizer_options =
      config.mutable_graph_options()->mutable_optimizer_options();
  if (enable_xla_compilation) {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);

    // These XLA flags are needed to trigger XLA properly from C (more
    // generally, non-Python) clients. If this API is called again with
    // `enable_xla_compilation` set to false, it is safe to keep these flag
    // values as is.
    tensorflow::MarkForCompilationPassFlags* flags =
        tensorflow::GetMarkForCompilationPassFlags();
    flags->tf_xla_cpu_global_jit = true;
    flags->tf_xla_min_cluster_size = 1;
  } else {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
  }

  auto* gpu_options = config.mutable_gpu_options();
  gpu_options->set_allow_growth(gpu_memory_allow_growth);

  (*config.mutable_device_count())["CPU"] = num_cpu_devices;

  // TODO(b/113217601): This is needed for EagerContext::runner_ to use a
  // threadpool, so that we avoid the possibility of running the runner_ in the
  // threadpool of the GPU event mgr, as that can trigger more callbacks to be
  // scheduled on that same threadpool, causing a deadlock in cases where the
  // caller of event_mgr->ThenExecute() blocks on the completion of the
  // callback (as in the case of ConstOp kernel creation on GPU, which involves
  // copying a CPU tensor to GPU).
  // Setting a larger thread pool does not help with the Swift caller, as we
  // use a different TFE context for each thread of execution (for running
  // graph functions, and their send/recv coroutines).
  config.set_inter_op_parallelism_threads(1);

  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(config, ret));
  return ret;
}
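
// Usage sketch for TF_CreateConfig (illustrative only; consuming the buffer
// through TFE_ContextOptions, and the status handling, are assumptions made
// for this example):
//
//   TF_Status* s = TF_NewStatus();
//   TF_Buffer* config = TF_CreateConfig(/*enable_xla_compilation=*/0,
//                                       /*gpu_memory_allow_growth=*/1,
//                                       /*num_cpu_devices=*/1);
//   TFE_ContextOptions* opts = TFE_NewContextOptions();
//   TFE_ContextOptionsSetConfig(opts, config->data, config->length, s);
//   // ... create a TFE context with `opts` after checking `s` ...
//   TF_DeleteBuffer(config);
//   TFE_DeleteContextOptions(opts);
//   TF_DeleteStatus(s);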

TF_Buffer* TF_CreateRunOptions(unsigned char enable_full_trace) {
  tensorflow::RunOptions options;
  if (enable_full_trace) {
    options.set_trace_level(tensorflow::RunOptions::FULL_TRACE);
  } else {
    options.set_trace_level(tensorflow::RunOptions::NO_TRACE);
  }
  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(options, ret));
  return ret;
}
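
// Usage sketch for TF_CreateRunOptions (illustrative only; `session`, the
// empty run arguments, and `s` are assumptions made for this example):
//
//   TF_Buffer* run_opts = TF_CreateRunOptions(/*enable_full_trace=*/1);
//   TF_SessionRun(session, run_opts, /*inputs=*/nullptr, nullptr, 0,
//                 /*outputs=*/nullptr, nullptr, 0, /*targets=*/nullptr, 0,
//                 /*run_metadata=*/nullptr, s);
//   TF_DeleteBuffer(run_opts);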

const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) {
  tensorflow::mutex_lock c(graph->mu);
  const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString();
  *len = debug_str.size();
  char* ret = static_cast<char*>(malloc(*len + 1));
  memcpy(ret, debug_str.c_str(), *len + 1);
  return ret;
}

char* TF_FunctionDebugString(TF_Function* func, size_t* len) {
  const auto& debug_str = DebugString(func->fdef);
  *len = debug_str.size();
  char* ret = static_cast<char*>(malloc(*len + 1));
  memcpy(ret, debug_str.c_str(), *len + 1);
  return ret;
}
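
// Usage sketch for the debug-string helpers above (illustrative only; `graph`
// is an assumption made for this example). Both functions return a heap
// buffer allocated with malloc() that the caller must free():
//
//   size_t len = 0;
//   const char* debug_str = TF_GraphDebugString(graph, &len);
//   LOG(INFO) << std::string(debug_str, len);
//   free(const_cast<char*>(debug_str));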

// On success, returns a set of TF_Function instances parsed from `text_proto`,
// which must be a text-format GraphDef. These functions must be deleted by
// calling TF_DeleteFunction.
//
// If `mutate_proto_func` is non-NULL, run it over each FunctionDef proto
// before creating a TF_Function out of the possibly mutated proto.
static std::vector<UniqueFuncPtr> CreateFunctionsFromTextProto(
    const char* text_proto,
    std::function<void(FunctionDef*)>* mutate_proto_func, TF_Status* status) {
  tensorflow::GraphDef gdef;
  if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto, &gdef)) {
    status->status = tensorflow::errors::Internal(
        "Invalid text proto for GraphDef: ", text_proto);
    return {};
  }
  const auto& fdef_lib = gdef.library();
  if (fdef_lib.gradient_size() > 0) {
    status->status = tensorflow::errors::Internal(
        "GradientDef is not supported in reading Dataset related functions: ",
        text_proto);
    return {};
  }
  std::vector<UniqueFuncPtr> ret;
  for (const FunctionDef& fdef : fdef_lib.function()) {
    // Make a copy so that we can mutate it.
    FunctionDef fdef_to_load = fdef;
    if (mutate_proto_func) {
      (*mutate_proto_func)(&fdef_to_load);
    }
    VLOG(1) << "Adding func to graph: " << fdef_to_load.DebugString();
    std::vector<char> binary_proto_buf(fdef_to_load.ByteSizeLong());
    fdef_to_load.SerializeToArray(binary_proto_buf.data(),
                                  binary_proto_buf.size());
    TF_Function* func = TF_FunctionImportFunctionDef(
        binary_proto_buf.data(), binary_proto_buf.size(), status);
    if (!status->status.ok()) return {};
    ret.push_back(UniqueFuncPtr(func, TF_DeleteFunction));
  }
  return ret;
}

TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id,
                                 TF_Status* status) {
  assert(session);
  {
    tensorflow::mutex_lock c(session->graph->mu);
    VLOG(1) << "Dequeuing named tensor with id " << tensor_id
            << ", with input graph: "
            << session->graph->graph.ToGraphDefDebug().DebugString();
  }

  TF_Operation* dequeue_op = TF_GraphOperationByName(
      session->graph,
      tensorflow::strings::StrCat("fifo_queue_dequeue_", tensor_id).c_str());
  if (dequeue_op == nullptr) {
    status->status = tensorflow::errors::Internal(
        "Unable to find the dequeue node in the TF graph.");
    return nullptr;
  }

  VLOG(1) << "Running the dequeue op";
  TF_Output output{dequeue_op, 0};
  TF_Tensor* ret;
  TF_SessionRun(session, /*run_options*/ nullptr,
                // input related parameters
                /*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0,
                // output related parameters
                /*outputs*/ &output, /*output_values*/ &ret,
                /*noutputs*/ 1,
                /*targets*/ nullptr, /*ntargets*/ 0,
                /*run_metadata*/ nullptr, status);
  if (VLOG_IS_ON(1) && status->status.ok()) {
    tensorflow::Tensor tensor;
    if (tensorflow::TF_TensorToTensor(ret, &tensor).ok()) {
      VLOG(1) << "Dequeued tensor content: " << tensor.DebugString();
    }
  }
  return ret;
}

void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
                           TF_Tensor* tensor, TF_Status* status) {
  assert(session);
  {
    tensorflow::mutex_lock c(session->graph->mu);
    if (VLOG_IS_ON(1)) {
      VLOG(1) << "Enqueuing named tensor with id " << tensor_id
              << ", with input graph: "
              << session->graph->graph.ToGraphDefDebug().DebugString();
      tensorflow::Tensor internal_tensor;
      if (tensorflow::TF_TensorToTensor(tensor, &internal_tensor).ok()) {
        VLOG(1) << "Enqueuing tensor content: "
                << internal_tensor.DebugString();
      }
    }
  }

  TF_Operation* enqueue_op = TF_GraphOperationByName(
      session->graph,
      tensorflow::strings::StrCat("fifo_queue_enqueue_", tensor_id).c_str());
  if (enqueue_op == nullptr) {
    status->status = tensorflow::errors::Internal(
        "Unable to find the enqueue node in the TF graph.");
    return;
  }

  TF_Operation* placeholder_op = TF_GraphOperationByName(
      session->graph,
      tensorflow::strings::StrCat("arg_tensor_enqueue_", tensor_id).c_str());
  if (placeholder_op == nullptr) {
    status->status = tensorflow::errors::Internal(
        "Unable to find the placeholder node as input to enqueue in the TF "
        "graph.");
    return;
  }

  VLOG(1) << "Running the enqueue op";
  TF_Output input{placeholder_op, 0};
  TF_SessionRun(session, /*run_options*/ nullptr,
                // input related parameters
                /*inputs*/ &input, /*input_values*/ &tensor, /*ninputs*/ 1,
                // output related parameters
                /*outputs*/ nullptr, /*output_values*/ nullptr, /*noutputs*/ 0,
                /*targets*/ &enqueue_op, /*ntargets*/ 1,
                /*run_metadata*/ nullptr, status);
  VLOG(1) << "Enqueuing is done.";
}
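
// Usage sketch for the two session-based named-tensor calls above
// (illustrative only; `session`, `t`, and `s` are assumptions made for this
// example, and the graph is assumed to already contain nodes named
// "fifo_queue_enqueue_0", "arg_tensor_enqueue_0", and "fifo_queue_dequeue_0",
// as required by the lookups above):
//
//   TF_EnqueueNamedTensor(session, /*tensor_id=*/0, t, s);
//   if (TF_GetCode(s) == TF_OK) {
//     TF_Tensor* out = TF_DequeueNamedTensor(session, /*tensor_id=*/0, s);
//     if (TF_GetCode(s) == TF_OK) TF_DeleteTensor(out);
//   }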

TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) {
  tensorflow::ServerDef server_def;
  if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto,
                                                         &server_def)) {
    status->status = tensorflow::errors::Internal(
        "Invalid text proto for ServerDef: ", text_proto);
    return nullptr;
  }
  status->status = tensorflow::Status();
  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(server_def, ret));
  return ret;
}
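
// Usage sketch for TFE_GetServerDef (illustrative only; the single-worker
// cluster layout in the text proto is an assumption made for this example):
//
//   const char* text_proto =
//       "cluster { job { name: 'localhost' tasks { key: 0 value: "
//       "'localhost:2222' } } } job_name: 'localhost' task_index: 0 "
//       "protocol: 'grpc'";
//   TF_Buffer* server_def = TFE_GetServerDef(text_proto, s);
//   // ... use server_def->data / server_def->length, then:
//   TF_DeleteBuffer(server_def);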

TFE_Context* TFE_CreateContextFromSession(TF_Session* session,
                                          TF_Status* status) {
  auto* opts = TFE_NewContextOptions();

  // Reduce GPU memory allocation, and set appropriate config options for TFE
  // context.
  auto* config = TF_CreateConfig(
      /*xla*/ false, /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
      10);
  TFE_ContextOptionsSetConfig(opts, config->data, config->length, status);
  if (!status->status.ok()) {
    TF_DeleteBuffer(config);
    TFE_DeleteContextOptions(opts);
    return nullptr;
  }

  auto* ctx = TFE_NewContextFromSession(opts, session, status);
  TF_DeleteBuffer(config);
  TFE_DeleteContextOptions(opts);
  return ctx;
}

// TODO: retrieve the device string via TFE_ContextListDevices()
static const char DEFAULT_CPU_DEVICE[] =
    "/job:localhost/replica:0/task:0/device:CPU:0";

static TFE_TensorHandle* createTFEQueue(TFE_Context* ctx, TF_DataType inputType,
                                        int tensor_id, TF_Status* status) {
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> queueOp(
      TFE_NewOp(ctx, "FIFOQueueV2", status), TFE_DeleteOp);
  TFE_OpSetDevice(queueOp.get(), DEFAULT_CPU_DEVICE, status);
  if (!status->status.ok()) return nullptr;
  // TODO: use NAMED_TENSOR_QUEUE_CAPACITY in S4TF compiler.
  TFE_OpSetAttrInt(queueOp.get(), "capacity", 1);
  TFE_OpSetAttrTypeList(queueOp.get(), "component_types", &inputType, 1);
  auto shared_name = tensorflow::strings::StrCat("fifo_queue_", tensor_id);
  TFE_OpSetAttrString(queueOp.get(), "shared_name", shared_name.data(),
                      shared_name.size());
  TFE_OpSetAttrString(queueOp.get(), "container", "", 0);

  // TODO: consider making this an unknown shape.
  const int64_t* dims_ptr = nullptr;
  int num_dims = 0;
  TFE_OpSetAttrShapeList(queueOp.get(), "shapes", &dims_ptr, &num_dims,
                         /*num_values*/ 0, status);
  if (!status->status.ok()) return nullptr;

  int num_retvals = 1;
  TFE_TensorHandle* queue = nullptr;
  TFE_Execute(queueOp.get(), &queue, &num_retvals, status);
  if (!status->status.ok()) return nullptr;
  CHECK_EQ(num_retvals, 1);

  return queue;
}

static void createTFEEnqueue(TFE_Context* ctx, TF_DataType inputType,
                             TFE_TensorHandle* queue, TFE_TensorHandle* tensor,
                             TF_Status* status) {
  TFE_Op* op = TFE_NewOp(ctx, "QueueEnqueueV2", status);
  if (!status->status.ok()) return;
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
  TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
  if (!status->status.ok()) return;
  TFE_OpAddInput(op, queue, status);
  if (!status->status.ok()) return;
  TFE_OpAddInput(op, tensor, status);
  if (!status->status.ok()) return;
  TFE_OpSetAttrTypeList(op, "Tcomponents", &inputType, 1);
  TFE_OpSetAttrInt(op, "timeout_ms", -1);

  int num_retvals = 0;
  TFE_Execute(op, nullptr /*retvals*/, &num_retvals, status);
  if (!status->status.ok()) return;
  CHECK_EQ(num_retvals, 0);
}

static TFE_TensorHandle* createTFEDequeue(TFE_Context* ctx,
                                          TF_DataType inputType,
                                          TFE_TensorHandle* queue,
                                          TF_Status* status) {
  TFE_Op* op = TFE_NewOp(ctx, "QueueDequeueV2", status);
  if (!status->status.ok()) return nullptr;
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
  TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
  if (!status->status.ok()) return nullptr;

  TFE_OpAddInput(op, queue, status);
  if (!status->status.ok()) return nullptr;
  TFE_OpSetAttrTypeList(op, "component_types", &inputType, 1);
  TFE_OpSetAttrInt(op, "timeout_ms", -1);
  TFE_TensorHandle* ret;
  int num_retvals = 1;
  TFE_Execute(op, &ret, &num_retvals, status);
  if (!status->status.ok()) return nullptr;
  CHECK_EQ(num_retvals, 1);
  return ret;
}

TFE_TensorHandle* TFE_DequeueNamedTensor(TF_Session* session, int tensor_id,
                                         TF_DataType inputType,
                                         TF_Status* status) {
  assert(session);
  VLOG(1) << "Dequeuing data tensor with id " << tensor_id;

  auto ctx = TFE_CreateContextFromSession(session, status);
  if (!status->status.ok()) return nullptr;
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
      ctx, TFE_DeleteContext);

  TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
  if (!status->status.ok()) return nullptr;
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
      queue_deleter(queue, TFE_DeleteTensorHandle);

  auto* ret = createTFEDequeue(ctx, inputType, queue, status);
  return ret;
}

TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
                                                TF_DataType inputType,
                                                TF_Status* status) {
  TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
  if (!status->status.ok()) return nullptr;
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
      queue_deleter(queue, TFE_DeleteTensorHandle);

  auto* ret = createTFEDequeue(ctx, inputType, queue, status);

  return ret;
}

void TFE_EnqueueNamedTensor(TF_Session* session, int tensor_id,
                            TFE_TensorHandle* tensor, TF_Status* status) {
  assert(session);
  VLOG(1) << "Enqueuing data tensor with id " << tensor_id;

  auto ctx = TFE_CreateContextFromSession(session, status);
  if (!status->status.ok()) return;
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
      ctx, TFE_DeleteContext);

  TF_DataType inputType = TFE_TensorHandleDataType(tensor);
  TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
  if (!status->status.ok()) return;
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
      queue_deleter(queue, TFE_DeleteTensorHandle);

  createTFEEnqueue(ctx, inputType, queue, tensor, status);
}

void TFE_EnqueueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
                                   TFE_TensorHandle* tensor,
                                   TF_Status* status) {
  VLOG(1) << "Enqueuing data tensor with id " << tensor_id;

  TF_DataType inputType = TFE_TensorHandleDataType(tensor);
  TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
  if (!status->status.ok()) return;
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
      queue_deleter(queue, TFE_DeleteTensorHandle);

  createTFEEnqueue(ctx, inputType, queue, tensor, status);
}

void TFE_EnqueueVariantTensor(TF_Session* session, int tensor_id,
                              TFE_TensorHandle* tensor, TF_Status* status) {
  VLOG(1) << "Enqueuing variant tensor with id " << tensor_id;

  auto ctx = TFE_CreateContextFromSession(session, status);
  if (!status->status.ok()) return;
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
      ctx, TFE_DeleteContext);

  TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
  if (!status->status.ok()) return;
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
      queue_deleter(queue, TFE_DeleteTensorHandle);

  createTFEEnqueue(ctx, TF_VARIANT, queue, tensor, status);
}

TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
                                           TF_Status* status) {
  VLOG(1) << "Dequeuing variant tensor with id " << tensor_id;

  auto ctx = TFE_CreateContextFromSession(session, status);
  if (!status->status.ok()) return nullptr;
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
      ctx, TFE_DeleteContext);

  TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
  if (!status->status.ok()) return nullptr;
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
      queue_deleter(queue, TFE_DeleteTensorHandle);

  return createTFEDequeue(ctx, TF_VARIANT, queue, status);
}

void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) {
  auto* status = TF_NewStatus();
  TF_Tensor* t = TFE_TensorHandleResolve(handle, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  tensorflow::Tensor dst;
  TF_CHECK_OK(TF_TensorToTensor(t, &dst));
  LOG(INFO) << dst.DebugString();

  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
}

void TFE_OpPrintDebugString(TFE_Op* op) {
  VLOG(1) << "TFE_OpPrintDebugString() over " << op;
  LOG(INFO) << op->operation.DebugString();
}

struct TFE_ExecuteOpNotification {
  TFE_ExecuteOpNotification() : status(TF_NewStatus(), TF_DeleteStatus) {}
  tensorflow::Notification n;
  std::unique_ptr<tensorflow::Thread> thread;
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status;
};

TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op,
                                                    TFE_TensorHandle** retvals,
                                                    int* num_retvals,
                                                    TF_Status* status) {
  TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification;

  n->thread.reset(op->operation.EagerContext().TFEnv()->StartThread(
      tensorflow::ThreadOptions(), "ExecuteOpThread",
      [op, retvals, num_retvals, n]() {
        TFE_Execute(op, retvals, num_retvals, n->status.get());
        n->n.Notify();
      }));

  return n;
}

void TFE_ExecuteOpNotificationWaitAndDelete(
    TFE_ExecuteOpNotification* notification, TF_Status* status) {
  if (notification == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "Passed in notification is a nullptr.");

    return;
  }
  if (notification->thread == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "Passed in notification didn't start a thread correctly. Cleaning up "
        "this notification. Please re-execute the operation to get a new "
        "notification.");

    delete notification;
    return;
  }

  notification->n.WaitForNotification();

  status->status = notification->status->status;

  delete notification;
}
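
// Usage sketch for the asynchronous-execution pair above (illustrative only;
// `op` is assumed to be a ready-to-run TFE_Op with one output, and `s` a
// TF_Status owned by the caller):
//
//   TFE_TensorHandle* retvals[1] = {nullptr};
//   int num_retvals = 1;
//   TFE_ExecuteOpNotification* note =
//       TFE_ExecuteOpInNewThread(op, retvals, &num_retvals, s);
//   // ... do other work while the op runs ...
//   TFE_ExecuteOpNotificationWaitAndDelete(note, s);
//   if (TF_GetCode(s) == TF_OK && num_retvals == 1) {
//     TFE_DeleteTensorHandle(retvals[0]);
//   }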

void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
  status->status = tensorflow::errors::Internal(errMsg);
}

struct TF_CheckpointReader : public tensorflow::checkpoint::CheckpointReader {
  using tensorflow::checkpoint::CheckpointReader::CheckpointReader;
  std::vector<std::string> variable_list;
};

TF_CheckpointReader* TF_NewCheckpointReader(const char* filename,
                                            TF_Status* status) {
  TF_CheckpointReader* reader = new TF_CheckpointReader(filename, status);
  if (!status->status.ok()) {
    TF_DeleteCheckpointReader(reader);
    return nullptr;
  }
  const auto& m = reader->GetVariableToDataTypeMap();
  for (auto it = m.begin(); it != m.end(); ++it)
    reader->variable_list.push_back(it->first);
  std::sort(reader->variable_list.begin(), reader->variable_list.end());
  return reader;
}

void TF_DeleteCheckpointReader(TF_CheckpointReader* reader) { delete reader; }

int TF_CheckpointReaderHasTensor(TF_CheckpointReader* reader,
                                 const char* name) {
  return reader->HasTensor(name);
}

const char* TF_CheckpointReaderGetVariable(TF_CheckpointReader* reader,
                                           int index) {
  return reader->variable_list[index].c_str();
}

int TF_CheckpointReaderSize(TF_CheckpointReader* reader) {
  return reader->variable_list.size();
}

TF_DataType TF_CheckpointReaderGetVariableDataType(TF_CheckpointReader* reader,
                                                   const char* name) {
  const auto& m = reader->GetVariableToDataTypeMap();
  return static_cast<TF_DataType>(m.at(name));
}

TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
                                        const char* name, TF_Status* status) {
  std::unique_ptr<tensorflow::Tensor> tensor;
  reader->GetTensor(name, &tensor, status);
  if (!status->status.ok()) return nullptr;
  return tensorflow::TF_TensorFromTensor(*tensor, &status->status);
}

void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
                                         const char* name, int64_t* dims,
                                         int num_dims, TF_Status* status) {
  const auto& shape = reader->GetVariableToShapeMap().at(name);
  int rank = shape.dims();
  if (num_dims != rank) {
    status->status = InvalidArgument("Expected rank is ", num_dims,
                                     " but actual rank is ", rank);
    return;
  }
  for (int i = 0; i < num_dims; i++) {
    dims[i] = shape.dim_size(i);
  }
}

int TF_CheckpointReaderGetVariableNumDims(TF_CheckpointReader* reader,
                                          const char* name) {
  const auto& m = reader->GetVariableToShapeMap();
  return m.at(name).dims();
}
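
// Usage sketch for the checkpoint-reader API above (illustrative only; the
// checkpoint prefix path is a hypothetical placeholder for this example):
//
//   TF_Status* s = TF_NewStatus();
//   TF_CheckpointReader* reader =
//       TF_NewCheckpointReader("/tmp/model.ckpt", s);
//   if (TF_GetCode(s) == TF_OK) {
//     for (int i = 0; i < TF_CheckpointReaderSize(reader); ++i) {
//       const char* name = TF_CheckpointReaderGetVariable(reader, i);
//       TF_Tensor* value = TF_CheckpointReaderGetTensor(reader, name, s);
//       if (TF_GetCode(s) == TF_OK) TF_DeleteTensor(value);
//     }
//     TF_DeleteCheckpointReader(reader);
//   }
//   TF_DeleteStatus(s);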

// This builder is used in the eager API to build a NodeDef.
struct TF_AttrBuilder : public tensorflow::AttrBuilder {
  using tensorflow::AttrBuilder::AttrBuilder;
  // String storage that ensures any `attr_name` we pass into `builder->Set()`
  // outlives the subsequent `TF_AttrBuilderCheckCanRunOnDevice()` call(s) on
  // the same `builder`.
  std::set<std::string> attr_names;
};

TF_AttrBuilder* TF_NewAttrBuilder(const char* op_name) {
  return new TF_AttrBuilder(op_name);
}

void TF_DeleteAttrBuilder(TF_AttrBuilder* builder) { delete builder; }

void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name,
                           TF_DataType value) {
  auto iter = builder->attr_names.insert(attr_name).first;
  builder->Set(*iter, static_cast<tensorflow::DataType>(value));
}

void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name,
                               const TF_DataType* values, int num_values) {
  auto iter = builder->attr_names.insert(attr_name).first;
  builder->Set(
      (*iter).c_str(),
      tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
          reinterpret_cast<const tensorflow::DataType*>(values), num_values));
}

void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder,
                                       const char* device_type,
                                       TF_Status* status) {
  status->status = tensorflow::FindKernelDef(
      tensorflow::DeviceType(device_type), builder->BuildNodeDef(),
      /* def = */ nullptr, /* kernel_class_name = */ nullptr);
}
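
// Usage sketch for the attr-builder API above (illustrative only; the op and
// attribute chosen here are assumptions made for this example):
//
//   TF_Status* s = TF_NewStatus();
//   TF_AttrBuilder* b = TF_NewAttrBuilder("MatMul");
//   TF_AttrBuilderSetType(b, "T", TF_FLOAT);
//   TF_AttrBuilderCheckCanRunOnDevice(b, "GPU", s);
//   // TF_GetCode(s) == TF_OK when a GPU kernel is registered for this
//   // NodeDef.
//   TF_DeleteAttrBuilder(b);
//   TF_DeleteStatus(s);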

const char* TF_GetNumberAttrForOpListInput(const char* op_name, int input_index,
                                           TF_Status* status) {
  const tensorflow::OpDef* op_def = nullptr;
  status->status =
      tensorflow::OpRegistry::Global()->LookUpOpDef(op_name, &op_def);
  if (!status->status.ok()) return nullptr;

  if (input_index >= op_def->input_arg_size() || input_index < 0) {
    status->status = tensorflow::errors::InvalidArgument(
        input_index, " out of range for ", op_name);
    return nullptr;
  }

  const tensorflow::OpDef_ArgDef& input_arg = op_def->input_arg()[input_index];

  if (input_arg.number_attr().empty()) {
    status->status = tensorflow::errors::NotFound(
        op_name, " does not have number_attr() defined.");
    return nullptr;
  }

  // The returned string is owned by OpRegistry, so liveness is not a concern.
  return input_arg.number_attr().c_str();
}

int TF_OpIsStateful(const char* op_type, TF_Status* status) {
  const tensorflow::OpRegistrationData* op_reg_data;
  status->status =
      tensorflow::OpRegistry::Global()->LookUp(op_type, &op_reg_data);
  if (!status->status.ok()) {
    return 0;
  }
  return op_reg_data->op_def.is_stateful();
}

void TF_InitMain(const char* usage, int* argc, char*** argv) {
  tensorflow::port::InitMain(usage, argc, argv);
}

int TF_PickUnusedPortOrDie() {
  return tensorflow::internal::PickUnusedPortOrDie();
}

TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
                                                void* data, size_t len,
                                                TF_Status* status) {
  auto dtype = static_cast<tensorflow::DataType>(data_type);
  DCHECK(tensorflow::DataTypeCanUseMemcpy(dtype));

  tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({}));
  std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);
  return TFE_TensorHandle::CreateLocalHandle(tensor, status);
}
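
// Usage sketch for TFE_NewTensorHandleFromScalar (illustrative only; `s` is
// assumed to be a TF_Status owned by the caller):
//
//   float value = 1.0f;
//   TFE_TensorHandle* h =
//       TFE_NewTensorHandleFromScalar(TF_FLOAT, &value, sizeof(value), s);
//   if (TF_GetCode(s) == TF_OK) TFE_DeleteTensorHandle(h);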

namespace {
tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
                                       TFE_Context* ctx) {
  // We don't use the TF_RETURN_IF_ERROR macro directly, since that destroys
  // the server object (which currently CHECK-fails) and the error would be
  // lost. Instead, we log the error and then return it, so the user can see
  // the message.
#define LOG_AND_RETURN_IF_ERROR(...)                    \
  do {                                                  \
    const ::tensorflow::Status _status = (__VA_ARGS__); \
    if (TF_PREDICT_FALSE(!_status.ok())) {              \
      LOG(ERROR) << _status.error_message();            \
      return _status;                                   \
    }                                                   \
  } while (0);

  // New server created for new server_def. Unused if updating server_def.
  tensorflow::EagerContext* context = ctx->context;
  tensorflow::GrpcServer* grpc_server =
      dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
  if (grpc_server == nullptr) {
    std::unique_ptr<tensorflow::ServerInterface> new_server;
    LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
    grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
    if (grpc_server == nullptr) {
      LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal(
          "Currently, TFE_NewContext only supports tensorflow::GrpcServer."));
    }
    LOG_AND_RETURN_IF_ERROR(grpc_server->Start());

    LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
        std::move(new_server), grpc_server->worker_env()->device_mgr,
        grpc_server->worker_env()->collective_executor_mgr));
  } else {
    LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
    LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
        /*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
        grpc_server->worker_env()->collective_executor_mgr));
  }
  return tensorflow::Status::OK();
#undef LOG_AND_RETURN_IF_ERROR
}
}  // namespace

// Set server_def on the context, possibly updating it.
TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
                                                   const void* proto,
                                                   size_t proto_len,
                                                   TF_Status* status) {
  tensorflow::ServerDef server_def;
  if (!server_def.ParseFromArray(proto, proto_len)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Invalid tensorflow.ServerDef protocol buffer");
    return;
  }
  status->status = EnableCollectiveOps(server_def, ctx);
}
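
// Usage sketch for TFE_EnableCollectiveOps (illustrative only; it reuses
// TFE_GetServerDef from above to produce a serialized ServerDef, and `ctx`,
// `text_proto`, and `s` are assumptions made for this example):
//
//   TF_Buffer* server_def = TFE_GetServerDef(text_proto, s);
//   if (TF_GetCode(s) == TF_OK) {
//     TFE_EnableCollectiveOps(ctx, server_def->data, server_def->length, s);
//     TF_DeleteBuffer(server_def);
//   }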

TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
  TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
  result->num_items = num_items;
  result->items = (num_items == 0) ? nullptr : new TF_ShapeAndType[num_items]();
  return result;
}

void TF_ShapeAndTypeListSetShape(TF_ShapeAndTypeList* shape_list, int index,
                                 const int64_t* dims, int num_dims) {
  DCHECK(index >= 0 && index < shape_list->num_items);
  TF_ShapeAndType& shape = shape_list->items[index];
  DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
  DCHECK(num_dims >= 0) << "Number of dimensions cannot be negative!";
  shape.num_dims = num_dims;
  shape.dims = new int64_t[num_dims];
  memcpy(shape.dims, dims, sizeof(int64_t) * num_dims);
}

void TF_ShapeAndTypeListSetUnknownShape(TF_ShapeAndTypeList* shape_list,
                                        int index) {
  DCHECK(index >= 0 && index < shape_list->num_items);
  TF_ShapeAndType& shape = shape_list->items[index];
  DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
  shape.num_dims = -1;
  shape.dims = nullptr;
}

void TF_ShapeAndTypeListSetDtype(TF_ShapeAndTypeList* shape_list, int index,
                                 TF_DataType dtype) {
  DCHECK(index >= 0 && index < shape_list->num_items);
  TF_ShapeAndType& shape_and_type = shape_list->items[index];
  shape_and_type.dtype = dtype;
}

void TF_DeleteShapeAndTypeList(TF_ShapeAndTypeList* shape_list) {
  if (shape_list == nullptr) return;
  for (size_t i = 0; i < shape_list->num_items; ++i) {
    delete[] shape_list->items[i].dims;
  }
  delete[] shape_list->items;
  delete shape_list;
}

void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
                                    int num_items) {
  if (shape_list_array == nullptr) return;
  for (int i = 0; i < num_items; ++i) {
    TF_DeleteShapeAndTypeList(shape_list_array[i]);
  }
  delete[] shape_list_array;
}

namespace tensorflow {
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
}  // namespace tensorflow

void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
                     TF_Tensor** input_tensors,
                     TF_ShapeAndTypeList* input_tensors_as_shapes,
                     TF_ShapeAndTypeList** input_resource_shapes_and_types,
                     TF_ShapeAndTypeList** output_shapes,
                     TF_ShapeAndTypeList*** output_resource_shapes_and_types,
                     TF_Status* status) {
  using tensorflow::NodeDef;
  using tensorflow::OpRegistrationData;
  using tensorflow::Tensor;
  using tensorflow::shape_inference::DimensionHandle;
  using tensorflow::shape_inference::InferenceContext;
  using tensorflow::shape_inference::ShapeAndType;
  using tensorflow::shape_inference::ShapeHandle;

  const int num_inputs = input_shapes->num_items;
  NodeDef node_def;
  node_def.set_name(tfe_op->operation.Name());
  node_def.set_op(tfe_op->operation.Name());
  for (int i = 0; i < num_inputs; ++i) {
    node_def.add_input("dummy_input");
  }
  tfe_op->operation.Attrs().FillAttrValueMap(node_def.mutable_attr());

  const tensorflow::OpRegistrationData* op_reg_data;
  status->status =
      tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data);
  if (!status->status.ok()) return;

  // Initialize an input_tensor vector with `nullptr` values.
  std::vector<const Tensor*> input_tensors_vector(num_inputs, nullptr);
  // A vector to keep track of newly created `tf::Tensor` objects.
  std::vector<Tensor> all_input_tensors;
  // Update the vector with information from `input_tensors` if provided.
  if (input_tensors != nullptr) {
    // Note that we take the address of the elements in `all_input_tensors`
    // below. Allocate enough space so that no reallocation happens, which
    // would invalidate the pointers.
    all_input_tensors.reserve(num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      if (input_tensors[i] == nullptr) continue;
      all_input_tensors.emplace_back();
      Tensor& input_tensor = all_input_tensors.back();
      status->status = TF_TensorToTensor(input_tensors[i], &input_tensor);
      if (!status->status.ok()) return;
      input_tensors_vector[i] = &input_tensor;
    }
  }

  // Create an inference context with dummy values, which will be updated
  // later.
  InferenceContext c(TF_GRAPH_DEF_VERSION, node_def, op_reg_data->op_def,
                     std::vector<ShapeHandle>(num_inputs), input_tensors_vector,
                     {},
                     std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());

  // Set input_shapes.
  for (int i = 0; i < num_inputs; ++i) {
    std::vector<DimensionHandle> dims;
    const TF_ShapeAndType& input_shape = input_shapes->items[i];
    if (input_shape.num_dims == InferenceContext::kUnknownRank) {
      c.SetInput(i, c.UnknownShape());
      continue;
    }
    for (int j = 0; j < input_shape.num_dims; ++j) {
      dims.push_back(c.MakeDim(input_shape.dims[j]));
    }
    c.SetInput(i, c.MakeShape(dims));
  }

  // TODO(bgogul): Handle input_tensors_as_shapes.
  // TODO(bgogul): Handle input_resource_shapes_and_types.

  status->status = c.construction_status();
  if (!status->status.ok()) return;

  if (op_reg_data->shape_inference_fn == nullptr) {
    status->status =
        InvalidArgument("No shape inference function exists for op '",
                        node_def.op(), "', did you forget to define it?");
    return;
  }

  status->status = c.Run(op_reg_data->shape_inference_fn);
  if (!status->status.ok()) return;

  // Set output_shapes.
  TF_ShapeAndTypeList* output_shapes_result =
      TF_NewShapeAndTypeList(c.num_outputs());
  for (int i = 0; i < c.num_outputs(); ++i) {
    ShapeHandle shape_handle = c.output(i);
    TF_ShapeAndType& shape = output_shapes_result->items[i];
    shape.num_dims = c.Rank(shape_handle);
    if (shape.num_dims == InferenceContext::kUnknownRank) {
      shape.dims = nullptr;
      continue;
    }
    shape.dims = new int64_t[shape.num_dims];
    for (size_t j = 0; j < shape.num_dims; ++j) {
      shape.dims[j] = c.Value(c.Dim(shape_handle, j));
    }
  }
  if (output_shapes != nullptr) *output_shapes = output_shapes_result;

  // TODO(bgogul): Set output_resource_shapes_and_types.
}
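
// Usage sketch for TFE_InferShapes (illustrative only; `ctx` and `s` are
// assumptions made for this example, and the op and shapes were chosen
// arbitrarily): infer the output shape of an "Identity" op given a [2, 3]
// float input.
//
//   TFE_Op* op = TFE_NewOp(ctx, "Identity", s);
//   TFE_OpSetAttrType(op, "T", TF_FLOAT);
//   TF_ShapeAndTypeList* input_shapes = TF_NewShapeAndTypeList(1);
//   const int64_t dims[2] = {2, 3};
//   TF_ShapeAndTypeListSetShape(input_shapes, 0, dims, 2);
//   TF_ShapeAndTypeList* output_shapes = nullptr;
//   TFE_InferShapes(op, input_shapes, /*input_tensors=*/nullptr,
//                   /*input_tensors_as_shapes=*/nullptr,
//                   /*input_resource_shapes_and_types=*/nullptr,
//                   &output_shapes,
//                   /*output_resource_shapes_and_types=*/nullptr, s);
//   // On success, output_shapes->items[0] describes a rank-2 [2, 3] shape.
//   TF_DeleteShapeAndTypeList(input_shapes);
//   TF_DeleteShapeAndTypeList(output_shapes);
//   TFE_DeleteOp(op);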

void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
    TF_ImportGraphDefOptions* opts, unsigned char enable) {
  opts->opts.validate_colocation_constraints = enable;
}