1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/c_api.h"
17
18 #include <algorithm>
19 #include <limits>
20 #include <memory>
21 #include <vector>
22
23 #include "absl/strings/match.h"
24 // Required for IS_MOBILE_PLATFORM
25 #include "tensorflow/core/platform/platform.h" // NOLINT
26
27 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
28 #include "tensorflow/cc/framework/gradients.h"
29 #include "tensorflow/cc/framework/ops.h"
30 #include "tensorflow/cc/framework/scope_internal.h"
31 #include "tensorflow/cc/ops/while_loop.h"
32 #include "tensorflow/cc/saved_model/loader.h"
33 #include "tensorflow/core/distributed_runtime/server_lib.h"
34 #include "tensorflow/core/framework/logging.h"
35 #include "tensorflow/core/framework/op_gen_lib.h"
36 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
37 #include "tensorflow/c/c_api_internal.h"
38 #include "tensorflow/c/tf_status_internal.h"
39 #include "tensorflow/c/tf_tensor.h"
40 #include "tensorflow/core/common_runtime/device_mgr.h"
41 #include "tensorflow/core/common_runtime/eval_const_tensor.h"
42 #include "tensorflow/core/common_runtime/shape_refiner.h"
43 #include "tensorflow/core/framework/allocation_description.pb.h"
44 #include "tensorflow/core/framework/kernel_def.pb.h"
45 #include "tensorflow/core/framework/log_memory.h"
46 #include "tensorflow/core/framework/node_def_util.h"
47 #include "tensorflow/core/framework/op_kernel.h"
48 #include "tensorflow/core/framework/partial_tensor_shape.h"
49 #include "tensorflow/core/framework/tensor.h"
50 #include "tensorflow/core/framework/tensor.pb.h" // NOLINT
51 #include "tensorflow/core/framework/tensor_shape.h"
52 #include "tensorflow/core/framework/tensor_shape.pb.h"
53 #include "tensorflow/core/framework/types.h"
54 #include "tensorflow/core/framework/versions.pb.h"
55 #include "tensorflow/core/graph/graph.h"
56 #include "tensorflow/core/graph/graph_constructor.h"
57 #include "tensorflow/core/graph/node_builder.h"
58 #include "tensorflow/core/graph/validate.h"
59 #include "tensorflow/core/lib/core/coding.h"
60 #include "tensorflow/core/lib/core/errors.h"
61 #include "tensorflow/core/lib/core/status.h"
62 #include "tensorflow/core/lib/core/stringpiece.h"
63 #include "tensorflow/core/lib/gtl/array_slice.h"
64 #include "tensorflow/core/lib/strings/str_util.h"
65 #include "tensorflow/core/lib/strings/strcat.h"
66 #include "tensorflow/core/platform/mem.h"
67 #include "tensorflow/core/platform/mutex.h"
68 #include "tensorflow/core/platform/protobuf.h"
69 #include "tensorflow/core/platform/thread_annotations.h"
70 #include "tensorflow/core/platform/types.h"
71 #include "tensorflow/core/public/session.h"
72 #include "tensorflow/core/public/version.h"
73
74 // The implementation below is at the top level instead of the
75 // brain namespace because we are defining 'extern "C"' functions.
76 using tensorflow::AllocationDescription;
77 using tensorflow::DataType;
78 using tensorflow::ExtendSessionGraphHelper;
79 using tensorflow::Graph;
80 using tensorflow::GraphDef;
81 using tensorflow::mutex_lock;
82 using tensorflow::NameRangeMap;
83 using tensorflow::NameRangesForNode;
84 using tensorflow::NewSession;
85 using tensorflow::Node;
86 using tensorflow::NodeBuilder;
87 using tensorflow::NodeDef;
88 using tensorflow::OpDef;
89 using tensorflow::OpRegistry;
90 using tensorflow::OutputTensor;
91 using tensorflow::PartialTensorShape;
92 using tensorflow::RunMetadata;
93 using tensorflow::RunOptions;
94 using tensorflow::Session;
95 using tensorflow::Status;
96 using tensorflow::string;
97 using tensorflow::Tensor;
98 using tensorflow::TensorBuffer;
99 using tensorflow::TensorId;
100 using tensorflow::TensorShape;
101 using tensorflow::TensorShapeProto;
102 using tensorflow::VersionDef;
103 using tensorflow::errors::FailedPrecondition;
104 using tensorflow::errors::InvalidArgument;
105 using tensorflow::gtl::ArraySlice;
106 using tensorflow::strings::StrCat;
107
108 extern "C" {
109
110 // --------------------------------------------------------------------------
TF_Version()111 const char* TF_Version() { return TF_VERSION_STRING; }
112
113 // --------------------------------------------------------------------------
114
115 // --------------------------------------------------------------------------
TF_NewSessionOptions()116 TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; }
TF_DeleteSessionOptions(TF_SessionOptions * opt)117 void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; }
118
TF_SetTarget(TF_SessionOptions * options,const char * target)119 void TF_SetTarget(TF_SessionOptions* options, const char* target) {
120 options->options.target = target;
121 }
122
TF_SetConfig(TF_SessionOptions * options,const void * proto,size_t proto_len,TF_Status * status)123 void TF_SetConfig(TF_SessionOptions* options, const void* proto,
124 size_t proto_len, TF_Status* status) {
125 if (!options->options.config.ParseFromArray(proto, proto_len)) {
126 status->status = InvalidArgument("Unparseable ConfigProto");
127 }
128 }
129 // --------------------------------------------------------------------------
TF_NewBuffer()130 TF_Buffer* TF_NewBuffer() { return new TF_Buffer{nullptr, 0, nullptr}; }
131
TF_NewBufferFromString(const void * proto,size_t proto_len)132 TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) {
133 void* copy = tensorflow::port::Malloc(proto_len);
134 memcpy(copy, proto, proto_len);
135
136 TF_Buffer* buf = new TF_Buffer;
137 buf->data = copy;
138 buf->length = proto_len;
139 buf->data_deallocator = [](void* data, size_t length) {
140 tensorflow::port::Free(data);
141 };
142 return buf;
143 }
144
TF_DeleteBuffer(TF_Buffer * buffer)145 void TF_DeleteBuffer(TF_Buffer* buffer) {
146 if (buffer == nullptr) return;
147 if (buffer->data_deallocator != nullptr) {
148 (*buffer->data_deallocator)(const_cast<void*>(buffer->data),
149 buffer->length);
150 }
151 delete buffer;
152 }
153
TF_GetBuffer(TF_Buffer * buffer)154 TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; }
155
156 // --------------------------------------------------------------------------
157
TF_NewDeprecatedSession(const TF_SessionOptions * opt,TF_Status * status)158 TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
159 TF_Status* status) {
160 Session* session;
161 status->status = NewSession(opt->options, &session);
162 if (status->status.ok()) {
163 return new TF_DeprecatedSession({session});
164 } else {
165 DCHECK_EQ(nullptr, session);
166 return nullptr;
167 }
168 }
169
TF_CloseDeprecatedSession(TF_DeprecatedSession * s,TF_Status * status)170 void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
171 status->status = s->session->Close();
172 }
173
TF_DeleteDeprecatedSession(TF_DeprecatedSession * s,TF_Status * status)174 void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
175 status->status = Status::OK();
176 if (s == nullptr) return;
177 delete s->session;
178 delete s;
179 }
180
TF_ExtendGraph(TF_DeprecatedSession * s,const void * proto,size_t proto_len,TF_Status * status)181 void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,
182 size_t proto_len, TF_Status* status) {
183 GraphDef g;
184 if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {
185 status->status = InvalidArgument("Invalid GraphDef");
186 return;
187 }
188 status->status = s->session->Extend(g);
189 }
190
191 } // end extern "C"
192
193 // Reset helper for converting character arrays to string vectors.
TF_Reset_Helper(const TF_SessionOptions * opt,const char ** containers,int ncontainers,TF_Status * status)194 static void TF_Reset_Helper(const TF_SessionOptions* opt,
195 const char** containers, int ncontainers,
196 TF_Status* status) {
197 std::vector<string> container_names(ncontainers);
198 for (int i = 0; i < ncontainers; ++i) {
199 container_names[i] = containers[i];
200 }
201
202 status->status = Reset(opt->options, container_names);
203 }
204
205 extern "C" {
206
TF_Reset(const TF_SessionOptions * opt,const char ** containers,int ncontainers,TF_Status * status)207 void TF_Reset(const TF_SessionOptions* opt, const char** containers,
208 int ncontainers, TF_Status* status) {
209 TF_Reset_Helper(opt, containers, ncontainers, status);
210 }
211
212 } // end extern "C"
213
214 namespace tensorflow {
215
216
MessageToBuffer(const tensorflow::protobuf::MessageLite & in,TF_Buffer * out)217 Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
218 TF_Buffer* out) {
219 if (out->data != nullptr) {
220 return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
221 }
222 const size_t proto_size = in.ByteSizeLong();
223 void* buf = port::Malloc(proto_size);
224 if (buf == nullptr) {
225 return tensorflow::errors::ResourceExhausted(
226 "Failed to allocate memory to serialize message of type '",
227 in.GetTypeName(), "' and size ", proto_size);
228 }
229 if (!in.SerializeWithCachedSizesToArray(static_cast<uint8*>(buf))) {
230 port::Free(buf);
231 return InvalidArgument("Unable to serialize ", in.GetTypeName(),
232 " protocol buffer, perhaps the serialized size (",
233 proto_size, " bytes) is too large?");
234 }
235 out->data = buf;
236 out->length = proto_size;
237 out->data_deallocator = [](void* data, size_t length) { port::Free(data); };
238 return Status::OK();
239 }
240
RecordMutation(TF_Graph * graph,const TF_Operation & op,const char * mutation_type)241 void RecordMutation(TF_Graph* graph, const TF_Operation& op,
242 const char* mutation_type) {
243 // If any session has already run this node_id, mark this session as
244 // unrunnable.
245 for (auto it : graph->sessions) {
246 mutex_lock session_lock(it.first->mu);
247 if (it.first->last_num_graph_nodes > op.node.id()) {
248 it.second = strings::StrCat(
249 "Operation '", op.node.DebugString(), "' was changed by ",
250 mutation_type,
251 " after it was run by a session. This mutation will have no effect, "
252 "and will trigger an error in the future. Either don't modify "
253 "nodes after running them or create a new session.");
254 }
255 }
256 }
257
258 namespace {
259
260 // Helper method that creates a shape handle for a shape described by dims.
ShapeHandleFromDims(tensorflow::shape_inference::InferenceContext * ic,int num_dims,const int64_t * dims)261 tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims(
262 tensorflow::shape_inference::InferenceContext* ic, int num_dims,
263 const int64_t* dims) {
264 if (num_dims != -1) {
265 std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
266 dim_vec.reserve(num_dims);
267 for (int i = 0; i < num_dims; ++i) {
268 dim_vec.push_back(ic->MakeDim(dims[i]));
269 }
270 return ic->MakeShape(dim_vec);
271 } else {
272 return ic->UnknownShape();
273 }
274 }
275
276 } // namespace
277
TF_GraphSetOutputHandleShapesAndTypes(TF_Graph * graph,TF_Output output,int num_shapes_and_types,const int64_t ** shapes,const int * ranks,const TF_DataType * types,TF_Status * status)278 void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
279 int num_shapes_and_types,
280 const int64_t** shapes,
281 const int* ranks,
282 const TF_DataType* types,
283 TF_Status* status) {
284 Node* node = &output.oper->node;
285
286 mutex_lock l(graph->mu);
287 tensorflow::shape_inference::InferenceContext* ic =
288 graph->refiner.GetContext(node);
289 if (ic == nullptr) {
290 status->status =
291 InvalidArgument("Node ", node->name(), " was not found in the graph");
292 return;
293 }
294
295 auto shape_and_type_vec =
296 std::vector<tensorflow::shape_inference::ShapeAndType>(
297 num_shapes_and_types);
298 for (int i = 0; i < num_shapes_and_types; ++i) {
299 tensorflow::shape_inference::ShapeHandle shape_handle =
300 ShapeHandleFromDims(ic, ranks[i], shapes[i]);
301 shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType(
302 shape_handle, static_cast<DataType>(types[i]));
303 }
304
305 ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec);
306 }
307
308 // Helpers for loading a TensorFlow plugin (a .so file).
309 Status LoadLibrary(const char* library_filename, void** result,
310 const void** buf, size_t* len);
311
312 // TODO(josh11b,mrry): Change Session to be able to use a Graph*
313 // directly, instead of requiring us to serialize to a GraphDef and
314 // call Session::Extend().
ExtendSessionGraphHelper(TF_Session * session,TF_Status * status)315 bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
316 if (session->graph != nullptr) {
317 // Take the graph lock before the session lock to avoid deadlock. This is
318 // safe since session->graph does not change.
319 session->graph->mu.lock();
320 mutex_lock session_lock(session->mu);
321 const Graph& graph = session->graph->graph;
322
323 const string& mutation_warning = session->graph->sessions[session];
324 if (!mutation_warning.empty()) {
325 // TODO(b/74949947): turn this back into an error status
326 LOG(WARNING) << mutation_warning;
327 session->graph->sessions[session].clear();
328 }
329
330 const auto num_nodes = graph.num_node_ids();
331 if (session->last_num_graph_nodes < num_nodes) {
332 // TODO(nolivia): check this on a subset of the graph instead of all of
333 // it.
334 status->status = graph::ValidateGraphHasNoCycle(session->graph->graph);
335 if (!status->status.ok()) {
336 session->graph->mu.unlock();
337 return false;
338 }
339
340 GraphDef graph_def;
341 *graph_def.mutable_versions() = graph.versions();
342 // Fill graph_def with nodes with ids in the range
343 // [session->last_num_graph_nodes, num_nodes), that is the nodes
344 // added since the last TF_SessionRun() call.
345 for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) {
346 Node* const node = graph.FindNodeId(id);
347 if (node != nullptr && node->IsOp()) {
348 NodeDef* const node_def = graph_def.add_node();
349 *node_def = node->def();
350 }
351 }
352 *graph_def.mutable_library() = graph.flib_def().ToProto();
353 session->graph->mu.unlock();
354 status->status = session->session->Extend(std::move(graph_def));
355 if (!status->status.ok()) {
356 // Contract is we always delete input_values[i].
357 return false;
358 }
359 // Note: session->session is not modified if Extend() fails, so
360 // we only set last_num_graph_nodes if it succeeds.
361 session->last_num_graph_nodes = num_nodes;
362 } else {
363 session->graph->mu.unlock();
364 }
365 }
366 return true;
367 }
368
369 } // namespace tensorflow
370
TF_Run_Setup(int noutputs,TF_Tensor ** c_outputs,TF_Status * status)371 static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs,
372 TF_Status* status) {
373 status->status = Status::OK();
374 for (int i = 0; i < noutputs; ++i) {
375 c_outputs[i] = nullptr;
376 }
377 }
378
TF_Run_Inputs(TF_Tensor * const * c_inputs,std::vector<std::pair<string,Tensor>> * input_pairs,TF_Status * status)379 static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
380 std::vector<std::pair<string, Tensor>>* input_pairs,
381 TF_Status* status) {
382 const int ninputs = input_pairs->size();
383 for (int i = 0; i < ninputs; ++i) {
384 status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second);
385 if (!status->status.ok()) return false;
386 }
387 return true;
388 }
389
390 // Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to
391 // result in a zero-sized tensor.
EmptyTensor(TF_DataType dtype,const tensorflow::TensorShape & shape)392 static TF_Tensor* EmptyTensor(TF_DataType dtype,
393 const tensorflow::TensorShape& shape) {
394 static char empty;
395 tensorflow::int64 nelems = 1;
396 std::vector<tensorflow::int64> dims;
397 for (int i = 0; i < shape.dims(); ++i) {
398 dims.push_back(shape.dim_size(i));
399 nelems *= shape.dim_size(i);
400 }
401 CHECK_EQ(nelems, 0);
402 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
403 "64-bit int types should match in size");
404 return TF_NewTensor(
405 dtype, reinterpret_cast<const int64_t*>(dims.data()), shape.dims(),
406 reinterpret_cast<void*>(&empty), 0, [](void*, size_t, void*) {}, nullptr);
407 }
408
TF_Run_Helper(Session * session,const char * handle,const TF_Buffer * run_options,const std::vector<std::pair<string,Tensor>> & input_pairs,const std::vector<string> & output_tensor_names,TF_Tensor ** c_outputs,const std::vector<string> & target_oper_names,TF_Buffer * run_metadata,TF_Status * status)409 static void TF_Run_Helper(
410 Session* session, const char* handle, const TF_Buffer* run_options,
411 // Input tensors
412 const std::vector<std::pair<string, Tensor>>& input_pairs,
413 // Output tensors
414 const std::vector<string>& output_tensor_names, TF_Tensor** c_outputs,
415 // Target nodes
416 const std::vector<string>& target_oper_names, TF_Buffer* run_metadata,
417 TF_Status* status) {
418 const int noutputs = output_tensor_names.size();
419 std::vector<Tensor> outputs(noutputs);
420 Status result;
421
422 if (handle == nullptr) {
423 RunOptions run_options_proto;
424 if (run_options != nullptr && !run_options_proto.ParseFromArray(
425 run_options->data, run_options->length)) {
426 status->status = InvalidArgument("Unparseable RunOptions proto");
427 return;
428 }
429 if (run_metadata != nullptr && run_metadata->data != nullptr) {
430 status->status =
431 InvalidArgument("Passing non-empty run_metadata is invalid.");
432 return;
433 }
434
435 RunMetadata run_metadata_proto;
436 result = session->Run(run_options_proto, input_pairs, output_tensor_names,
437 target_oper_names, &outputs, &run_metadata_proto);
438
439 // Serialize back to upstream client, who now owns the new buffer
440 if (run_metadata != nullptr) {
441 status->status = MessageToBuffer(run_metadata_proto, run_metadata);
442 if (!status->status.ok()) return;
443 }
444 } else {
445 // NOTE(zongheng): PRun does not support RunOptions yet.
446 result = session->PRun(handle, input_pairs, output_tensor_names, &outputs);
447 }
448 if (!result.ok()) {
449 status->status = result;
450 return;
451 }
452
453 // Store results in c_outputs[]
454 for (int i = 0; i < noutputs; ++i) {
455 const Tensor& src = outputs[i];
456 if (!src.IsInitialized() || src.NumElements() == 0) {
457 c_outputs[i] =
458 EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
459 continue;
460 }
461 c_outputs[i] = TF_TensorFromTensor(src, &status->status);
462 if (!status->status.ok()) return;
463 }
464 }
465
466 extern "C" {
467
TF_Run(TF_DeprecatedSession * s,const TF_Buffer * run_options,const char ** c_input_names,TF_Tensor ** c_inputs,int ninputs,const char ** c_output_names,TF_Tensor ** c_outputs,int noutputs,const char ** c_target_oper_names,int ntargets,TF_Buffer * run_metadata,TF_Status * status)468 void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options,
469 // Input tensors
470 const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
471 // Output tensors
472 const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
473 // Target nodes
474 const char** c_target_oper_names, int ntargets,
475 TF_Buffer* run_metadata, TF_Status* status) {
476 TF_Run_Setup(noutputs, c_outputs, status);
477 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
478 if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
479 for (int i = 0; i < ninputs; ++i) {
480 input_pairs[i].first = c_input_names[i];
481 }
482 std::vector<string> output_names(noutputs);
483 for (int i = 0; i < noutputs; ++i) {
484 output_names[i] = c_output_names[i];
485 }
486 std::vector<string> target_oper_names(ntargets);
487 for (int i = 0; i < ntargets; ++i) {
488 target_oper_names[i] = c_target_oper_names[i];
489 }
490 TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names,
491 c_outputs, target_oper_names, run_metadata, status);
492 }
493
TF_PRunSetup(TF_DeprecatedSession * s,const char ** c_input_names,int ninputs,const char ** c_output_names,int noutputs,const char ** c_target_oper_names,int ntargets,const char ** handle,TF_Status * status)494 void TF_PRunSetup(TF_DeprecatedSession* s,
495 // Input names
496 const char** c_input_names, int ninputs,
497 // Output names
498 const char** c_output_names, int noutputs,
499 // Target nodes
500 const char** c_target_oper_names, int ntargets,
501 const char** handle, TF_Status* status) {
502 *handle = nullptr;
503
504 std::vector<string> input_names(ninputs);
505 std::vector<string> output_names(noutputs);
506 std::vector<string> target_oper_names(ntargets);
507 for (int i = 0; i < ninputs; ++i) {
508 input_names[i] = c_input_names[i];
509 }
510 for (int i = 0; i < noutputs; ++i) {
511 output_names[i] = c_output_names[i];
512 }
513 for (int i = 0; i < ntargets; ++i) {
514 target_oper_names[i] = c_target_oper_names[i];
515 }
516 string new_handle;
517 status->status = s->session->PRunSetup(input_names, output_names,
518 target_oper_names, &new_handle);
519 if (status->status.ok()) {
520 char* buf = new char[new_handle.size() + 1];
521 memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
522 *handle = buf;
523 }
524 }
525
TF_PRun(TF_DeprecatedSession * s,const char * handle,const char ** c_input_names,TF_Tensor ** c_inputs,int ninputs,const char ** c_output_names,TF_Tensor ** c_outputs,int noutputs,const char ** c_target_oper_names,int ntargets,TF_Status * status)526 void TF_PRun(TF_DeprecatedSession* s, const char* handle,
527 // Input tensors
528 const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
529 // Output tensors
530 const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
531 // Target nodes
532 const char** c_target_oper_names, int ntargets,
533 TF_Status* status) {
534 TF_Run_Setup(noutputs, c_outputs, status);
535 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
536 if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
537 for (int i = 0; i < ninputs; ++i) {
538 input_pairs[i].first = c_input_names[i];
539 }
540
541 std::vector<string> output_names(noutputs);
542 for (int i = 0; i < noutputs; ++i) {
543 output_names[i] = c_output_names[i];
544 }
545 std::vector<string> target_oper_names(ntargets);
546 for (int i = 0; i < ntargets; ++i) {
547 target_oper_names[i] = c_target_oper_names[i];
548 }
549 TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names,
550 c_outputs, target_oper_names, nullptr, status);
551 }
552
TF_LoadLibrary(const char * library_filename,TF_Status * status)553 TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
554 TF_Library* lib_handle = new TF_Library;
555 status->status = tensorflow::LoadLibrary(
556 library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
557 &lib_handle->op_list.length);
558 if (!status->status.ok()) {
559 delete lib_handle;
560 return nullptr;
561 }
562 return lib_handle;
563 }
564
TF_GetOpList(TF_Library * lib_handle)565 TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; }
566
TF_DeleteLibraryHandle(TF_Library * lib_handle)567 void TF_DeleteLibraryHandle(TF_Library* lib_handle) {
568 if (lib_handle == nullptr) return;
569 tensorflow::port::Free(const_cast<void*>(lib_handle->op_list.data));
570 delete lib_handle;
571 }
572
TF_GetAllOpList()573 TF_Buffer* TF_GetAllOpList() {
574 std::vector<tensorflow::OpDef> op_defs;
575 tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs);
576 tensorflow::OpList op_list;
577 for (const auto& op : op_defs) {
578 *(op_list.add_op()) = op;
579 }
580 TF_Buffer* ret = TF_NewBuffer();
581 TF_CHECK_OK(MessageToBuffer(op_list, ret));
582 return ret;
583 }
584
585 // --------------------------------------------------------------------------
586 // ListDevices & SessionListDevices API
587
TF_DeleteDeviceList(TF_DeviceList * list)588 void TF_DeleteDeviceList(TF_DeviceList* list) { delete list; }
589
TF_SessionListDevices(TF_Session * session,TF_Status * status)590 TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
591 TF_DeviceList* response = new TF_DeviceList;
592 status->status = session->session->ListDevices(&response->response);
593 return response;
594 }
595
TF_DeprecatedSessionListDevices(TF_DeprecatedSession * session,TF_Status * status)596 TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session,
597 TF_Status* status) {
598 TF_DeviceList* response = new TF_DeviceList;
599 status->status = session->session->ListDevices(&response->response);
600 return response;
601 }
602
TF_DeviceListCount(const TF_DeviceList * list)603 int TF_DeviceListCount(const TF_DeviceList* list) {
604 return list->response.size();
605 }
606
607 #define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \
608 return_type method_name(const TF_DeviceList* list, const int index, \
609 TF_Status* status) { \
610 if (list == nullptr) { \
611 status->status = InvalidArgument("list is null!"); \
612 return err_val; \
613 } \
614 if (index < 0 || index >= list->response.size()) { \
615 status->status = InvalidArgument("index out of bounds"); \
616 return err_val; \
617 } \
618 status->status = Status::OK(); \
619 return list->response[index].accessor; \
620 }
621
622 TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
623 TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
624 nullptr);
625 TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);
626 TF_DEVICELIST_METHOD(uint64_t, TF_DeviceListIncarnation, incarnation(), 0);
627
628 #undef TF_DEVICELIST_METHOD
629
630 } // end extern "C"
631
632 // --------------------------------------------------------------------------
633 // New Graph and Session API
634
635 // Helper functions -----------------------------------------------------------
636
637 namespace {
638
ToOperation(Node * node)639 TF_Operation* ToOperation(Node* node) {
640 return static_cast<TF_Operation*>(static_cast<void*>(node));
641 }
642
OutputName(const TF_Output & output)643 string OutputName(const TF_Output& output) {
644 return StrCat(output.oper->node.name(), ":", output.index);
645 }
646
GetAttrValue(TF_Operation * oper,const char * attr_name,TF_Status * status)647 const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
648 const char* attr_name,
649 TF_Status* status) {
650 const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name);
651 if (attr == nullptr) {
652 status->status = InvalidArgument("Operation '", oper->node.name(),
653 "' has no attr named '", attr_name, "'.");
654 }
655 return attr;
656 }
657
ToTensorId(const TF_Output & output)658 TensorId ToTensorId(const TF_Output& output) {
659 return TensorId(output.oper->node.name(), output.index);
660 }
661
662 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
OutputsFromTFOutputs(TF_Output * tf_outputs,int n)663 std::vector<tensorflow::Output> OutputsFromTFOutputs(TF_Output* tf_outputs,
664 int n) {
665 std::vector<tensorflow::Output> outputs(n);
666 for (int i = 0; i < n; ++i) {
667 outputs[i] =
668 tensorflow::Output(&tf_outputs[i].oper->node, tf_outputs[i].index);
669 }
670 return outputs;
671 }
672
TFOutputsFromOutputs(const std::vector<tensorflow::Output> & outputs,TF_Output * tf_outputs)673 void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
674 TF_Output* tf_outputs) {
675 for (int i = 0; i < outputs.size(); i++) {
676 tf_outputs[i].oper = ToOperation(outputs[i].node());
677 tf_outputs[i].index = outputs[i].index();
678 }
679 }
680 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
681
682 } // namespace
683
684 // Shape functions -----------------------------------------------------------
685
TF_GraphSetTensorShape(TF_Graph * graph,TF_Output output,const int64_t * dims,const int num_dims,TF_Status * status)686 void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
687 const int64_t* dims, const int num_dims,
688 TF_Status* status) {
689 Node* node = &output.oper->node;
690
691 mutex_lock l(graph->mu);
692 tensorflow::shape_inference::InferenceContext* ic =
693 graph->refiner.GetContext(node);
694 if (ic == nullptr) {
695 status->status =
696 InvalidArgument("Node ", node->name(), " was not found in the graph");
697 return;
698 }
699 tensorflow::shape_inference::ShapeHandle new_shape =
700 tensorflow::ShapeHandleFromDims(ic, num_dims, dims);
701 status->status = graph->refiner.SetShape(node, output.index, new_shape);
702 }
703
TF_GraphGetTensorNumDims(TF_Graph * graph,TF_Output output,TF_Status * status)704 int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output,
705 TF_Status* status) {
706 Node* node = &output.oper->node;
707
708 mutex_lock l(graph->mu);
709 tensorflow::shape_inference::InferenceContext* ic =
710 graph->refiner.GetContext(node);
711 if (ic == nullptr) {
712 status->status =
713 InvalidArgument("Node ", node->name(), " was not found in the graph");
714 return -1;
715 }
716
717 tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
718
719 // Unknown rank means the number of dimensions is -1.
720 if (!ic->RankKnown(shape)) {
721 return -1;
722 }
723
724 return ic->Rank(shape);
725 }
726
TF_GraphGetTensorShape(TF_Graph * graph,TF_Output output,int64_t * dims,int num_dims,TF_Status * status)727 void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims,
728 int num_dims, TF_Status* status) {
729 Node* node = &output.oper->node;
730
731 mutex_lock l(graph->mu);
732 tensorflow::shape_inference::InferenceContext* ic =
733 graph->refiner.GetContext(node);
734 if (ic == nullptr) {
735 status->status =
736 InvalidArgument("Node ", node->name(), " was not found in the graph");
737 return;
738 }
739
740 tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
741
742 int rank = -1;
743 if (ic->RankKnown(shape)) {
744 rank = ic->Rank(shape);
745 }
746
747 if (num_dims != rank) {
748 status->status = InvalidArgument("Expected rank is ", num_dims,
749 " but actual rank is ", rank);
750 return;
751 }
752
753 if (num_dims == 0) {
754 // Output shape is a scalar.
755 return;
756 }
757
758 // Rank is greater than 0, so fill in the values, if known, and
759 // -1 for unknown values.
760 for (int i = 0; i < num_dims; ++i) {
761 auto dim = ic->Dim(shape, i);
762 tensorflow::int64 value = -1;
763 if (ic->ValueKnown(dim)) {
764 value = ic->Value(dim);
765 }
766 dims[i] = value;
767 }
768 }
769
770 // TF_OperationDescription functions ------------------------------------------
771
772 extern "C" {
773
TF_NewOperationLocked(TF_Graph * graph,const char * op_type,const char * oper_name)774 static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
775 const char* op_type,
776 const char* oper_name)
777 EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
778 return new TF_OperationDescription(graph, op_type, oper_name);
779 }
780
TF_NewOperation(TF_Graph * graph,const char * op_type,const char * oper_name)781 TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type,
782 const char* oper_name) {
783 mutex_lock l(graph->mu);
784 return TF_NewOperationLocked(graph, op_type, oper_name);
785 }
786
TF_SetDevice(TF_OperationDescription * desc,const char * device)787 void TF_SetDevice(TF_OperationDescription* desc, const char* device) {
788 desc->node_builder.Device(device);
789 }
790
TF_AddInput(TF_OperationDescription * desc,TF_Output input)791 void TF_AddInput(TF_OperationDescription* desc, TF_Output input) {
792 desc->node_builder.Input(&input.oper->node, input.index);
793 }
794
TF_AddInputList(TF_OperationDescription * desc,const TF_Output * inputs,int num_inputs)795 void TF_AddInputList(TF_OperationDescription* desc, const TF_Output* inputs,
796 int num_inputs) {
797 std::vector<NodeBuilder::NodeOut> input_list;
798 input_list.reserve(num_inputs);
799 for (int i = 0; i < num_inputs; ++i) {
800 input_list.emplace_back(&inputs[i].oper->node, inputs[i].index);
801 }
802 desc->node_builder.Input(input_list);
803 }
804
TF_AddControlInput(TF_OperationDescription * desc,TF_Operation * input)805 void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) {
806 desc->node_builder.ControlInput(&input->node);
807 }
808
TF_ColocateWith(TF_OperationDescription * desc,TF_Operation * op)809 void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) {
810 desc->colocation_constraints.emplace(
811 StrCat(tensorflow::kColocationGroupPrefix, op->node.name()));
812 }
813
TF_SetAttrString(TF_OperationDescription * desc,const char * attr_name,const void * value,size_t length)814 void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name,
815 const void* value, size_t length) {
816 tensorflow::StringPiece s(static_cast<const char*>(value), length);
817 desc->node_builder.Attr(attr_name, s);
818 }
819
TF_SetAttrStringList(TF_OperationDescription * desc,const char * attr_name,const void * const * values,const size_t * lengths,int num_values)820 void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name,
821 const void* const* values, const size_t* lengths,
822 int num_values) {
823 if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
824 desc->colocation_constraints.clear();
825 for (int i = 0; i < num_values; ++i) {
826 desc->colocation_constraints.emplace(static_cast<const char*>(values[i]),
827 lengths[i]);
828 }
829 } else {
830 std::vector<tensorflow::StringPiece> v;
831 v.reserve(num_values);
832 for (int i = 0; i < num_values; ++i) {
833 v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
834 }
835 desc->node_builder.Attr(attr_name, v);
836 }
837 }
838
TF_SetAttrInt(TF_OperationDescription * desc,const char * attr_name,int64_t value)839 void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name,
840 int64_t value) {
841 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
842 "64-bit int types should match in size");
843 desc->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value));
844 }
845
TF_SetAttrIntList(TF_OperationDescription * desc,const char * attr_name,const int64_t * values,int num_values)846 void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name,
847 const int64_t* values, int num_values) {
848 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
849 "64-bit int types should match in size");
850 desc->node_builder.Attr(
851 attr_name,
852 ArraySlice<const tensorflow::int64>(
853 reinterpret_cast<const tensorflow::int64*>(values), num_values));
854 }
855
TF_SetAttrFloat(TF_OperationDescription * desc,const char * attr_name,float value)856 void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name,
857 float value) {
858 desc->node_builder.Attr(attr_name, value);
859 }
860
TF_SetAttrFloatList(TF_OperationDescription * desc,const char * attr_name,const float * values,int num_values)861 void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name,
862 const float* values, int num_values) {
863 desc->node_builder.Attr(attr_name,
864 ArraySlice<const float>(values, num_values));
865 }
866
TF_SetAttrBool(TF_OperationDescription * desc,const char * attr_name,unsigned char value)867 void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name,
868 unsigned char value) {
869 desc->node_builder.Attr(attr_name, static_cast<bool>(value));
870 }
871
TF_SetAttrBoolList(TF_OperationDescription * desc,const char * attr_name,const unsigned char * values,int num_values)872 void TF_SetAttrBoolList(TF_OperationDescription* desc, const char* attr_name,
873 const unsigned char* values, int num_values) {
874 std::unique_ptr<bool[]> b(new bool[num_values]);
875 for (int i = 0; i < num_values; ++i) {
876 b[i] = values[i];
877 }
878 desc->node_builder.Attr(attr_name,
879 ArraySlice<const bool>(b.get(), num_values));
880 }
881
TF_SetAttrType(TF_OperationDescription * desc,const char * attr_name,TF_DataType value)882 void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name,
883 TF_DataType value) {
884 desc->node_builder.Attr(attr_name, static_cast<DataType>(value));
885 }
886
TF_SetAttrTypeList(TF_OperationDescription * desc,const char * attr_name,const TF_DataType * values,int num_values)887 void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name,
888 const TF_DataType* values, int num_values) {
889 desc->node_builder.Attr(
890 attr_name, ArraySlice<const DataType>(
891 reinterpret_cast<const DataType*>(values), num_values));
892 }
893
TF_SetAttrPlaceholder(TF_OperationDescription * desc,const char * attr_name,const char * placeholder)894 void TF_SetAttrPlaceholder(TF_OperationDescription* desc, const char* attr_name,
895 const char* placeholder) {
896 tensorflow::AttrValue attr_value;
897 attr_value.set_placeholder(placeholder);
898 desc->node_builder.Attr(attr_name, attr_value);
899 }
900
TF_SetAttrFuncName(TF_OperationDescription * desc,const char * attr_name,const char * value,size_t length)901 void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name,
902 const char* value, size_t length) {
903 tensorflow::NameAttrList func_name;
904 func_name.set_name(string(value, value + length));
905 desc->node_builder.Attr(attr_name, func_name);
906 }
907
TF_SetAttrShape(TF_OperationDescription * desc,const char * attr_name,const int64_t * dims,int num_dims)908 void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name,
909 const int64_t* dims, int num_dims) {
910 PartialTensorShape shape;
911 if (num_dims >= 0) {
912 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
913 "64-bit int types should match in size");
914 shape = PartialTensorShape(ArraySlice<tensorflow::int64>(
915 reinterpret_cast<const tensorflow::int64*>(dims), num_dims));
916 }
917 desc->node_builder.Attr(attr_name, shape);
918 }
919
TF_SetAttrShapeList(TF_OperationDescription * desc,const char * attr_name,const int64_t * const * dims,const int * num_dims,int num_shapes)920 void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name,
921 const int64_t* const* dims, const int* num_dims,
922 int num_shapes) {
923 std::vector<PartialTensorShape> shapes;
924 shapes.reserve(num_shapes);
925 for (int i = 0; i < num_shapes; ++i) {
926 if (num_dims[i] < 0) {
927 shapes.emplace_back();
928 } else {
929 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
930 "64-bit int types should match in size");
931 shapes.emplace_back(ArraySlice<tensorflow::int64>(
932 reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i]));
933 }
934 }
935 desc->node_builder.Attr(attr_name, shapes);
936 }
937
TF_SetAttrTensorShapeProto(TF_OperationDescription * desc,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)938 void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc,
939 const char* attr_name, const void* proto,
940 size_t proto_len, TF_Status* status) {
941 // shape.ParseFromArray takes an int as length, this function takes size_t,
942 // make sure there is no information loss.
943 if (proto_len > std::numeric_limits<int>::max()) {
944 status->status = InvalidArgument(
945 "proto_len (", proto_len,
946 " bytes) is too large to be parsed by the protocol buffer library");
947 return;
948 }
949 TensorShapeProto shape;
950 if (shape.ParseFromArray(proto, static_cast<int>(proto_len))) {
951 desc->node_builder.Attr(attr_name, shape);
952 status->status = Status::OK();
953 } else {
954 status->status = InvalidArgument("Unparseable TensorShapeProto");
955 }
956 }
957
TF_SetAttrTensorShapeProtoList(TF_OperationDescription * desc,const char * attr_name,const void * const * protos,const size_t * proto_lens,int num_shapes,TF_Status * status)958 void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc,
959 const char* attr_name,
960 const void* const* protos,
961 const size_t* proto_lens, int num_shapes,
962 TF_Status* status) {
963 std::vector<TensorShapeProto> shapes;
964 shapes.resize(num_shapes);
965 for (int i = 0; i < num_shapes; ++i) {
966 if (proto_lens[i] > std::numeric_limits<int>::max()) {
967 status->status = InvalidArgument(
968 "length of element ", i, " in the list (", proto_lens[i],
969 " bytes) is too large to be parsed by the protocol buffer library");
970 return;
971 }
972 if (!shapes[i].ParseFromArray(protos[i], static_cast<int>(proto_lens[i]))) {
973 status->status =
974 InvalidArgument("Unparseable TensorShapeProto at index ", i);
975 return;
976 }
977 }
978 desc->node_builder.Attr(attr_name, shapes);
979 status->status = Status::OK();
980 }
981
TF_SetAttrTensor(TF_OperationDescription * desc,const char * attr_name,TF_Tensor * value,TF_Status * status)982 void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
983 TF_Tensor* value, TF_Status* status) {
984 Tensor t;
985 status->status = TF_TensorToTensor(value, &t);
986 if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
987 }
988
TF_SetAttrTensorList(TF_OperationDescription * desc,const char * attr_name,TF_Tensor * const * values,int num_values,TF_Status * status)989 void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
990 TF_Tensor* const* values, int num_values,
991 TF_Status* status) {
992 status->status = Status::OK();
993 std::vector<Tensor> t;
994 t.reserve(num_values);
995
996 for (int i = 0; i < num_values && status->status.ok(); ++i) {
997 Tensor v;
998 status->status = TF_TensorToTensor(values[i], &v);
999 t.emplace_back(v);
1000 }
1001
1002 if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
1003 }
1004
TF_SetAttrValueProto(TF_OperationDescription * desc,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)1005 void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
1006 const void* proto, size_t proto_len,
1007 TF_Status* status) {
1008 tensorflow::AttrValue attr_value;
1009 if (!attr_value.ParseFromArray(proto, proto_len)) {
1010 status->status = InvalidArgument("Unparseable AttrValue proto");
1011 return;
1012 }
1013
1014 if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
1015 if (attr_value.value_case() != tensorflow::AttrValue::kList &&
1016 attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) {
1017 status->status =
1018 InvalidArgument("Expected \"list\" field for \"",
1019 tensorflow::kColocationAttrName, "\" attribute");
1020 return;
1021 }
1022 desc->colocation_constraints.clear();
1023 for (const string& location : attr_value.list().s()) {
1024 desc->colocation_constraints.insert(location);
1025 }
1026 } else {
1027 desc->node_builder.Attr(attr_name, std::move(attr_value));
1028 }
1029
1030 status->status = Status::OK();
1031 }
1032
// Turns `desc` into a node of desc->graph. Requires desc->graph->mu to be held.
// `desc` is deleted before returning in every case, so callers must not use it
// afterwards.
//
// On success the node is added to the graph, shape inference is run for it
// through the graph's refiner, and it is recorded in name_map.
// On failure `status` is set, any node that was already added is removed
// again, and `ret` stays nullptr. NOTE(review): ToOperation(nullptr) is
// presumably a null TF_Operation*; confirm in c_api_internal.h.
static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
                                              TF_Status* status)
    EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
  Node* ret = nullptr;

  // Names must be unique within the graph's name-to-node map.
  if (desc->graph->name_map.count(desc->node_builder.node_name())) {
    status->status = InvalidArgument("Duplicate node name in graph: '",
                                     desc->node_builder.node_name(), "'");
  } else {
    // Colocation constraints collected on the description (e.g. by
    // TF_ColocateWith) are emitted as one string-list attr at this point.
    if (!desc->colocation_constraints.empty()) {
      desc->node_builder.Attr(
          tensorflow::kColocationAttrName,
          std::vector<string>(desc->colocation_constraints.begin(),
                              desc->colocation_constraints.end()));
    }
    // consume=true presumably lets the builder hand its state to the new node;
    // the builder is not reused because `desc` is deleted below.
    status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret,
                                                 /*consume=*/true);

    if (status->status.ok()) {
      // Run shape inference function for newly added node.
      status->status = desc->graph->refiner.AddNode(ret);
    }
    if (status->status.ok()) {
      // Add the node to the name-to-node mapping.
      desc->graph->name_map[ret->name()] = ret;
    } else if (ret != nullptr) {
      // Finalize succeeded but shape inference failed: roll the node back so
      // the graph does not keep a node that is missing from name_map.
      desc->graph->graph.RemoveNode(ret);
      ret = nullptr;
    }
  }

  delete desc;

  return ToOperation(ret);
}
1068
TF_FinishOperation(TF_OperationDescription * desc,TF_Status * status)1069 TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
1070 TF_Status* status) {
1071 mutex_lock l(desc->graph->mu);
1072 return TF_FinishOperationLocked(desc, status);
1073 }
1074
1075 // TF_Operation functions
1076 // ----------------------------------------------------------
1077
TF_OperationName(TF_Operation * oper)1078 const char* TF_OperationName(TF_Operation* oper) {
1079 return oper->node.name().c_str();
1080 }
1081
TF_OperationOpType(TF_Operation * oper)1082 const char* TF_OperationOpType(TF_Operation* oper) {
1083 return oper->node.type_string().c_str();
1084 }
1085
TF_OperationDevice(TF_Operation * oper)1086 const char* TF_OperationDevice(TF_Operation* oper) {
1087 return oper->node.requested_device().c_str();
1088 }
1089
TF_OperationNumOutputs(TF_Operation * oper)1090 int TF_OperationNumOutputs(TF_Operation* oper) {
1091 return oper->node.num_outputs();
1092 }
1093
TF_OperationOutputType(TF_Output oper_out)1094 TF_DataType TF_OperationOutputType(TF_Output oper_out) {
1095 return static_cast<TF_DataType>(
1096 oper_out.oper->node.output_type(oper_out.index));
1097 }
1098
TF_OperationOutputListLength(TF_Operation * oper,const char * arg_name,TF_Status * status)1099 int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
1100 TF_Status* status) {
1101 NameRangeMap name_ranges;
1102 status->status =
1103 NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges);
1104 if (!status->status.ok()) return -1;
1105 auto iter = name_ranges.find(arg_name);
1106 if (iter == name_ranges.end()) {
1107 status->status = InvalidArgument("Output arg '", arg_name, "' not found");
1108 return -1;
1109 }
1110 return iter->second.second - iter->second.first;
1111 }
1112
TF_OperationNumInputs(TF_Operation * oper)1113 int TF_OperationNumInputs(TF_Operation* oper) {
1114 return oper->node.num_inputs();
1115 }
1116
TF_OperationInputType(TF_Input oper_in)1117 TF_DataType TF_OperationInputType(TF_Input oper_in) {
1118 return static_cast<TF_DataType>(oper_in.oper->node.input_type(oper_in.index));
1119 }
1120
TF_OperationInputListLength(TF_Operation * oper,const char * arg_name,TF_Status * status)1121 int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
1122 TF_Status* status) {
1123 NameRangeMap name_ranges;
1124 status->status =
1125 NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr);
1126 if (!status->status.ok()) return -1;
1127 auto iter = name_ranges.find(arg_name);
1128 if (iter == name_ranges.end()) {
1129 status->status = InvalidArgument("Input arg '", arg_name, "' not found");
1130 return -1;
1131 }
1132 return iter->second.second - iter->second.first;
1133 }
1134
TF_OperationInput(TF_Input oper_in)1135 TF_Output TF_OperationInput(TF_Input oper_in) {
1136 const tensorflow::Edge* edge;
1137 Status s = oper_in.oper->node.input_edge(oper_in.index, &edge);
1138 if (!s.ok()) {
1139 return {nullptr, -1};
1140 }
1141
1142 return {ToOperation(edge->src()), edge->src_output()};
1143 }
1144
TF_OperationAllInputs(TF_Operation * oper,TF_Output * inputs,int max_inputs)1145 void TF_OperationAllInputs(TF_Operation* oper, TF_Output* inputs,
1146 int max_inputs) {
1147 for (auto* edge : oper->node.in_edges()) {
1148 if (edge->dst_input() >= 0 && edge->dst_input() < max_inputs) {
1149 inputs[edge->dst_input()] = {ToOperation(edge->src()),
1150 edge->src_output()};
1151 }
1152 }
1153 }
1154
TF_OperationOutputNumConsumers(TF_Output oper_out)1155 int TF_OperationOutputNumConsumers(TF_Output oper_out) {
1156 int count = 0;
1157 for (const auto* edge : oper_out.oper->node.out_edges()) {
1158 if (edge->src_output() == oper_out.index) {
1159 ++count;
1160 }
1161 }
1162 return count;
1163 }
1164
TF_OperationOutputConsumers(TF_Output oper_out,TF_Input * consumers,int max_consumers)1165 int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers,
1166 int max_consumers) {
1167 int count = 0;
1168 for (const auto* edge : oper_out.oper->node.out_edges()) {
1169 if (edge->src_output() == oper_out.index) {
1170 if (count < max_consumers) {
1171 consumers[count] = {ToOperation(edge->dst()), edge->dst_input()};
1172 }
1173 ++count;
1174 }
1175 }
1176 return count;
1177 }
1178
TF_OperationNumControlInputs(TF_Operation * oper)1179 int TF_OperationNumControlInputs(TF_Operation* oper) {
1180 int count = 0;
1181 for (const auto* edge : oper->node.in_edges()) {
1182 if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1183 ++count;
1184 }
1185 }
1186 return count;
1187 }
1188
TF_OperationGetControlInputs(TF_Operation * oper,TF_Operation ** control_inputs,int max_control_inputs)1189 int TF_OperationGetControlInputs(TF_Operation* oper,
1190 TF_Operation** control_inputs,
1191 int max_control_inputs) {
1192 int count = 0;
1193 for (const auto* edge : oper->node.in_edges()) {
1194 if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1195 if (count < max_control_inputs) {
1196 control_inputs[count] = ToOperation(edge->src());
1197 }
1198 ++count;
1199 }
1200 }
1201 return count;
1202 }
1203
TF_OperationNumControlOutputs(TF_Operation * oper)1204 int TF_OperationNumControlOutputs(TF_Operation* oper) {
1205 int count = 0;
1206 for (const auto* edge : oper->node.out_edges()) {
1207 if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1208 ++count;
1209 }
1210 }
1211 return count;
1212 }
1213
TF_OperationGetControlOutputs(TF_Operation * oper,TF_Operation ** control_outputs,int max_control_outputs)1214 int TF_OperationGetControlOutputs(TF_Operation* oper,
1215 TF_Operation** control_outputs,
1216 int max_control_outputs) {
1217 int count = 0;
1218 for (const auto* edge : oper->node.out_edges()) {
1219 if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1220 if (count < max_control_outputs) {
1221 control_outputs[count] = ToOperation(edge->dst());
1222 }
1223 ++count;
1224 }
1225 }
1226 return count;
1227 }
1228
TF_OperationGetAttrMetadata(TF_Operation * oper,const char * attr_name,TF_Status * status)1229 TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
1230 const char* attr_name,
1231 TF_Status* status) {
1232 TF_AttrMetadata metadata;
1233 const auto* attr = GetAttrValue(oper, attr_name, status);
1234 if (!status->status.ok()) return metadata;
1235 switch (attr->value_case()) {
1236 #define SINGLE_CASE(kK, attr_type, size_expr) \
1237 case tensorflow::AttrValue::kK: \
1238 metadata.is_list = 0; \
1239 metadata.list_size = -1; \
1240 metadata.type = attr_type; \
1241 metadata.total_size = size_expr; \
1242 break;
1243
1244 SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
1245 SINGLE_CASE(kI, TF_ATTR_INT, -1);
1246 SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
1247 SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
1248 SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
1249 SINGLE_CASE(kShape, TF_ATTR_SHAPE,
1250 attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
1251 SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
1252 #undef SINGLE_CASE
1253
1254 case tensorflow::AttrValue::kList:
1255 metadata.is_list = 1;
1256 metadata.list_size = 0;
1257 metadata.total_size = -1;
1258 #define LIST_CASE(field, attr_type, ...) \
1259 if (attr->list().field##_size() > 0) { \
1260 metadata.type = attr_type; \
1261 metadata.list_size = attr->list().field##_size(); \
1262 __VA_ARGS__; \
1263 break; \
1264 }
1265
1266 LIST_CASE(
1267 s, TF_ATTR_STRING, metadata.total_size = 0;
1268 for (int i = 0; i < attr->list().s_size();
1269 ++i) { metadata.total_size += attr->list().s(i).size(); });
1270 LIST_CASE(i, TF_ATTR_INT);
1271 LIST_CASE(f, TF_ATTR_FLOAT);
1272 LIST_CASE(b, TF_ATTR_BOOL);
1273 LIST_CASE(type, TF_ATTR_TYPE);
1274 LIST_CASE(
1275 shape, TF_ATTR_SHAPE, metadata.total_size = 0;
1276 for (int i = 0; i < attr->list().shape_size(); ++i) {
1277 const auto& s = attr->list().shape(i);
1278 metadata.total_size += s.unknown_rank() ? 0 : s.dim_size();
1279 });
1280 LIST_CASE(tensor, TF_ATTR_TENSOR);
1281 LIST_CASE(tensor, TF_ATTR_FUNC);
1282 #undef LIST_CASE
1283 // All lists empty, determine the type from the OpDef.
1284 if (metadata.list_size == 0) {
1285 for (int i = 0; i < oper->node.op_def().attr_size(); ++i) {
1286 const auto& a = oper->node.op_def().attr(i);
1287 if (a.name() != attr_name) continue;
1288 const string& typestr = a.type();
1289 if (typestr == "list(string)") {
1290 metadata.type = TF_ATTR_STRING;
1291 } else if (typestr == "list(int)") {
1292 metadata.type = TF_ATTR_INT;
1293 } else if (typestr == "list(float)") {
1294 metadata.type = TF_ATTR_FLOAT;
1295 } else if (typestr == "list(bool)") {
1296 metadata.type = TF_ATTR_BOOL;
1297 } else if (typestr == "list(type)") {
1298 metadata.type = TF_ATTR_TYPE;
1299 } else if (typestr == "list(shape)") {
1300 metadata.type = TF_ATTR_SHAPE;
1301 } else if (typestr == "list(tensor)") {
1302 metadata.type = TF_ATTR_TENSOR;
1303 } else if (typestr == "list(func)") {
1304 metadata.type = TF_ATTR_FUNC;
1305 } else {
1306 status->status = InvalidArgument(
1307 "Attribute '", attr_name,
1308 "' has an empty value of an unrecognized type '", typestr, "'");
1309 return metadata;
1310 }
1311 }
1312 }
1313 break;
1314
1315 case tensorflow::AttrValue::kPlaceholder:
1316 metadata.is_list = 0;
1317 metadata.list_size = -1;
1318 metadata.type = TF_ATTR_PLACEHOLDER;
1319 metadata.total_size = -1;
1320 break;
1321
1322 case tensorflow::AttrValue::kFunc:
1323 metadata.is_list = 0;
1324 metadata.list_size = -1;
1325 metadata.type = TF_ATTR_FUNC;
1326 metadata.total_size = -1;
1327 break;
1328
1329 case tensorflow::AttrValue::VALUE_NOT_SET:
1330 status->status =
1331 InvalidArgument("Attribute '", attr_name, "' has no value set");
1332 break;
1333 }
1334 return metadata;
1335 }
1336
TF_OperationGetAttrString(TF_Operation * oper,const char * attr_name,void * value,size_t max_length,TF_Status * status)1337 void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name,
1338 void* value, size_t max_length,
1339 TF_Status* status) {
1340 const auto* attr = GetAttrValue(oper, attr_name, status);
1341 if (!status->status.ok()) return;
1342 if (attr->value_case() != tensorflow::AttrValue::kS) {
1343 status->status =
1344 InvalidArgument("Attribute '", attr_name, "' is not a string");
1345 return;
1346 }
1347 if (max_length <= 0) {
1348 return;
1349 }
1350 const auto& s = attr->s();
1351 std::memcpy(value, s.data(), std::min<size_t>(s.length(), max_length));
1352 }
1353
TF_OperationGetAttrStringList(TF_Operation * oper,const char * attr_name,void ** values,size_t * lengths,int max_values,void * storage,size_t storage_size,TF_Status * status)1354 void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
1355 void** values, size_t* lengths,
1356 int max_values, void* storage,
1357 size_t storage_size, TF_Status* status) {
1358 const auto* attr = GetAttrValue(oper, attr_name, status);
1359 if (!status->status.ok()) return;
1360 if (attr->value_case() != tensorflow::AttrValue::kList) {
1361 status->status =
1362 InvalidArgument("Value for '", attr_name, "' is not a list");
1363 return;
1364 }
1365 const auto len = std::min(max_values, attr->list().s_size());
1366 char* p = static_cast<char*>(storage);
1367 for (int i = 0; i < len; ++i) {
1368 const string& s = attr->list().s(i);
1369 values[i] = p;
1370 lengths[i] = s.size();
1371 if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
1372 status->status = InvalidArgument(
1373 "Not enough storage to hold the requested list of strings");
1374 return;
1375 }
1376 memcpy(values[i], s.data(), s.size());
1377 p += s.size();
1378 }
1379 }
1380
1381 #define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \
1382 void func(TF_Operation* oper, const char* attr_name, c_type* value, \
1383 TF_Status* status) { \
1384 cpp_type v; \
1385 status->status = \
1386 tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v); \
1387 *value = static_cast<c_type>(v); \
1388 } \
1389 void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
1390 int max_values, TF_Status* status) { \
1391 const auto* attr = GetAttrValue(oper, attr_name, status); \
1392 if (!status->status.ok()) return; \
1393 if (attr->value_case() != tensorflow::AttrValue::kList) { \
1394 status->status = \
1395 InvalidArgument("Value for '", attr_name, "' is not a list."); \
1396 return; \
1397 } \
1398 const auto len = std::min(max_values, attr->list().list_field##_size()); \
1399 for (int i = 0; i < len; ++i) { \
1400 values[i] = static_cast<c_type>(attr->list().list_field(i)); \
1401 } \
1402 }
1403 DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i);
1404 DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f);
1405 DEFINE_GETATTR(TF_OperationGetAttrBool, unsigned char, bool, b);
1406 DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type);
1407 #undef DEFINE_GETATTR
1408
TF_OperationGetAttrShape(TF_Operation * oper,const char * attr_name,int64_t * value,int num_dims,TF_Status * status)1409 void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
1410 int64_t* value, int num_dims, TF_Status* status) {
1411 PartialTensorShape shape;
1412 status->status =
1413 tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
1414 if (!status->status.ok()) return;
1415 auto len = std::min(shape.dims(), num_dims);
1416 for (int i = 0; i < len; ++i) {
1417 value[i] = shape.dim_size(i);
1418 }
1419 }
1420
TF_OperationGetAttrShapeList(TF_Operation * oper,const char * attr_name,int64_t ** dims,int * num_dims,int num_shapes,int64_t * storage,int storage_size,TF_Status * status)1421 void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
1422 int64_t** dims, int* num_dims, int num_shapes,
1423 int64_t* storage, int storage_size,
1424 TF_Status* status) {
1425 std::vector<PartialTensorShape> shapes;
1426 status->status =
1427 tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
1428 if (!status->status.ok()) return;
1429 auto len = std::min(static_cast<int>(shapes.size()), num_shapes);
1430 int64_t* p = storage;
1431 int storage_left = storage_size;
1432 for (int i = 0; i < len; ++i) {
1433 // shapes[i].dims() == -1 for shapes with an unknown rank.
1434 int64_t n = shapes[i].dims();
1435 num_dims[i] = n;
1436 dims[i] = p;
1437 if (n < 0) {
1438 continue;
1439 }
1440 if (storage_left < n) {
1441 status->status = InvalidArgument(
1442 "Not enough storage to hold the requested list of shapes");
1443 return;
1444 }
1445 storage_left -= n;
1446 for (int j = 0; j < n; ++j, ++p) {
1447 *p = shapes[i].dim_size(j);
1448 }
1449 }
1450 }
1451
TF_OperationGetAttrTensorShapeProto(TF_Operation * oper,const char * attr_name,TF_Buffer * value,TF_Status * status)1452 void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper,
1453 const char* attr_name,
1454 TF_Buffer* value, TF_Status* status) {
1455 const auto* attr = GetAttrValue(oper, attr_name, status);
1456 if (!status->status.ok()) return;
1457 if (attr->value_case() != tensorflow::AttrValue::kShape) {
1458 status->status =
1459 InvalidArgument("Value for '", attr_name, "' is not a shape.");
1460 return;
1461 }
1462 status->status = MessageToBuffer(attr->shape(), value);
1463 }
1464
TF_OperationGetAttrTensorShapeProtoList(TF_Operation * oper,const char * attr_name,TF_Buffer ** values,int max_values,TF_Status * status)1465 void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
1466 const char* attr_name,
1467 TF_Buffer** values, int max_values,
1468 TF_Status* status) {
1469 const auto* attr = GetAttrValue(oper, attr_name, status);
1470 if (!status->status.ok()) return;
1471 if (attr->value_case() != tensorflow::AttrValue::kList) {
1472 status->status =
1473 InvalidArgument("Value for '", attr_name, "' is not a list");
1474 return;
1475 }
1476 const auto len = std::min(max_values, attr->list().shape_size());
1477 for (int i = 0; i < len; ++i) {
1478 values[i] = TF_NewBuffer();
1479 status->status = MessageToBuffer(attr->list().shape(i), values[i]);
1480 if (!status->status.ok()) {
1481 // Delete everything allocated to far, the operation has failed.
1482 for (int j = 0; j <= i; ++j) {
1483 TF_DeleteBuffer(values[j]);
1484 }
1485 return;
1486 }
1487 }
1488 }
1489
TF_OperationGetAttrTensor(TF_Operation * oper,const char * attr_name,TF_Tensor ** value,TF_Status * status)1490 void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
1491 TF_Tensor** value, TF_Status* status) {
1492 *value = nullptr;
1493 Tensor t;
1494 status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
1495 if (!status->status.ok()) return;
1496 *value = TF_TensorFromTensor(t, &status->status);
1497 }
1498
TF_OperationGetAttrTensorList(TF_Operation * oper,const char * attr_name,TF_Tensor ** values,int max_values,TF_Status * status)1499 void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
1500 TF_Tensor** values, int max_values,
1501 TF_Status* status) {
1502 std::vector<Tensor> ts;
1503 status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
1504 if (!status->status.ok()) return;
1505 const auto len = std::min(max_values, static_cast<int>(ts.size()));
1506 for (int i = 0; i < len; ++i) {
1507 values[i] = TF_TensorFromTensor(ts[i], &status->status);
1508 }
1509 }
1510
TF_OperationGetAttrValueProto(TF_Operation * oper,const char * attr_name,TF_Buffer * output_attr_value,TF_Status * status)1511 void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
1512 TF_Buffer* output_attr_value,
1513 TF_Status* status) {
1514 const auto* attr = GetAttrValue(oper, attr_name, status);
1515 if (!status->status.ok()) return;
1516 status->status = MessageToBuffer(*attr, output_attr_value);
1517 }
1518
TF_OperationToNodeDef(TF_Operation * oper,TF_Buffer * output_node_def,TF_Status * status)1519 void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
1520 TF_Status* status) {
1521 status->status = MessageToBuffer(oper->node.def(), output_node_def);
1522 }
1523
1524 // TF_Graph functions ---------------------------------------------------------
1525
// Creates an empty graph that resolves ops through the global op registry.
// The shape refiner is configured to run shape inference on functions too.
TF_Graph::TF_Graph()
    : graph(tensorflow::OpRegistry::Global()),
      // `refiner` reads graph.versions() and graph.op_registry(), so this
      // relies on `graph` being declared before `refiner` in TF_Graph.
      // NOTE(review): confirm declaration order in c_api_internal.h.
      refiner(graph.versions().producer(), graph.op_registry()),
      delete_requested(false),
      parent(nullptr),
      parent_inputs(nullptr) {
  // Tell the shape refiner to also run shape inference on functions.
  refiner.set_function_library_for_shape_inference(&graph.flib_def());
}
1535
TF_NewGraph()1536 TF_Graph* TF_NewGraph() { return new TF_Graph; }
1537
TF_DeleteGraph(TF_Graph * g)1538 void TF_DeleteGraph(TF_Graph* g) {
1539 if (g == nullptr) return;
1540 g->mu.lock();
1541 g->delete_requested = true;
1542 const bool del = g->sessions.empty();
1543 g->mu.unlock();
1544 if (del) delete g;
1545 }
1546
TF_GraphOperationByName(TF_Graph * graph,const char * oper_name)1547 TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) {
1548 mutex_lock l(graph->mu);
1549 auto iter = graph->name_map.find(oper_name);
1550 if (iter == graph->name_map.end()) {
1551 return nullptr;
1552 } else {
1553 return ToOperation(iter->second);
1554 }
1555 }
1556
TF_GraphNextOperation(TF_Graph * graph,size_t * pos)1557 TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) {
1558 if (*pos == 0) {
1559 // Advance past the first sentinel nodes in every graph (the source & sink).
1560 *pos += 2;
1561 } else {
1562 // Advance to the next node.
1563 *pos += 1;
1564 }
1565
1566 mutex_lock l(graph->mu);
1567 while (*pos < static_cast<size_t>(graph->graph.num_node_ids())) {
1568 Node* node = graph->graph.FindNodeId(*pos);
1569 // FindNodeId() returns nullptr for nodes that have been deleted.
1570 // We aren't currently allowing nodes to be deleted, but it is safer
1571 // to still check.
1572 if (node != nullptr) return ToOperation(node);
1573 *pos += 1;
1574 }
1575
1576 // No more nodes.
1577 return nullptr;
1578 }
1579
TF_GraphToGraphDef(TF_Graph * graph,TF_Buffer * output_graph_def,TF_Status * status)1580 void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
1581 TF_Status* status) {
1582 GraphDef def;
1583 {
1584 mutex_lock l(graph->mu);
1585 graph->graph.ToGraphDef(&def);
1586 }
1587 status->status = MessageToBuffer(def, output_graph_def);
1588 }
1589
TF_GraphGetOpDef(TF_Graph * graph,const char * op_name,TF_Buffer * output_op_def,TF_Status * status)1590 void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
1591 TF_Buffer* output_op_def, TF_Status* status) {
1592 const OpDef* op_def;
1593 {
1594 mutex_lock l(graph->mu);
1595 status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def);
1596 if (!status->status.ok()) return;
1597 }
1598 status->status = MessageToBuffer(*op_def, output_op_def);
1599 }
1600
TF_GraphVersions(TF_Graph * graph,TF_Buffer * output_version_def,TF_Status * status)1601 void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def,
1602 TF_Status* status) {
1603 VersionDef versions;
1604 {
1605 mutex_lock l(graph->mu);
1606 versions = graph->graph.versions();
1607 }
1608 status->status = MessageToBuffer(versions, output_version_def);
1609 }
1610
TF_NewImportGraphDefOptions()1611 TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
1612 return new TF_ImportGraphDefOptions;
1613 }
TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions * opts)1614 void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) {
1615 delete opts;
1616 }
TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions * opts,const char * prefix)1617 void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
1618 const char* prefix) {
1619 opts->opts.prefix = prefix;
1620 }
TF_ImportGraphDefOptionsSetDefaultDevice(TF_ImportGraphDefOptions * opts,const char * device)1621 void TF_ImportGraphDefOptionsSetDefaultDevice(TF_ImportGraphDefOptions* opts,
1622 const char* device) {
1623 opts->opts.default_device = device;
1624 }
1625
TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions * opts,unsigned char uniquify_names)1626 void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts,
1627 unsigned char uniquify_names) {
1628 opts->opts.uniquify_names = uniquify_names;
1629 }
1630
TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions * opts,unsigned char uniquify_prefix)1631 void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts,
1632 unsigned char uniquify_prefix) {
1633 opts->opts.uniquify_prefix = uniquify_prefix;
1634 }
1635
TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions * opts,const char * src_name,int src_index,TF_Output dst)1636 void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
1637 const char* src_name,
1638 int src_index, TF_Output dst) {
1639 opts->tensor_id_data.push_back(src_name);
1640 const string& src_name_str = opts->tensor_id_data.back();
1641 // We don't need to store dst's name in tensor_id_data, since `dst` must
1642 // outlive the ImportGraphDef call.
1643 opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst);
1644 }
1645
TF_ImportGraphDefOptionsRemapControlDependency(TF_ImportGraphDefOptions * opts,const char * src_name,TF_Operation * dst)1646 void TF_ImportGraphDefOptionsRemapControlDependency(
1647 TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) {
1648 opts->opts.input_map[TensorId(src_name, tensorflow::Graph::kControlSlot)] =
1649 TensorId(dst->node.name(), tensorflow::Graph::kControlSlot);
1650 }
1651
TF_ImportGraphDefOptionsAddControlDependency(TF_ImportGraphDefOptions * opts,TF_Operation * oper)1652 extern void TF_ImportGraphDefOptionsAddControlDependency(
1653 TF_ImportGraphDefOptions* opts, TF_Operation* oper) {
1654 opts->opts.control_dependencies.push_back(oper->node.name());
1655 }
1656
TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions * opts,const char * oper_name,int index)1657 void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts,
1658 const char* oper_name, int index) {
1659 opts->tensor_id_data.push_back(oper_name);
1660 const string& oper_name_str = opts->tensor_id_data.back();
1661 opts->opts.return_tensors.emplace_back(oper_name_str, index);
1662 }
1663
TF_ImportGraphDefOptionsNumReturnOutputs(const TF_ImportGraphDefOptions * opts)1664 int TF_ImportGraphDefOptionsNumReturnOutputs(
1665 const TF_ImportGraphDefOptions* opts) {
1666 return opts->opts.return_tensors.size();
1667 }
1668
TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions * opts,const char * oper_name)1669 void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts,
1670 const char* oper_name) {
1671 opts->opts.return_nodes.push_back(oper_name);
1672 }
1673
TF_ImportGraphDefOptionsNumReturnOperations(const TF_ImportGraphDefOptions * opts)1674 int TF_ImportGraphDefOptionsNumReturnOperations(
1675 const TF_ImportGraphDefOptions* opts) {
1676 return opts->opts.return_nodes.size();
1677 }
1678
TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults * results,int * num_outputs,TF_Output ** outputs)1679 void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results,
1680 int* num_outputs,
1681 TF_Output** outputs) {
1682 *num_outputs = results->return_tensors.size();
1683 *outputs = results->return_tensors.data();
1684 }
1685
TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults * results,int * num_opers,TF_Operation *** opers)1686 void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results,
1687 int* num_opers,
1688 TF_Operation*** opers) {
1689 *num_opers = results->return_nodes.size();
1690 *opers = results->return_nodes.data();
1691 }
1692
TF_ImportGraphDefResultsMissingUnusedInputMappings(TF_ImportGraphDefResults * results,int * num_missing_unused_input_mappings,const char *** src_names,int ** src_indexes)1693 void TF_ImportGraphDefResultsMissingUnusedInputMappings(
1694 TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings,
1695 const char*** src_names, int** src_indexes) {
1696 *num_missing_unused_input_mappings = results->missing_unused_key_names.size();
1697 *src_names = results->missing_unused_key_names.data();
1698 *src_indexes = results->missing_unused_key_indexes.data();
1699 }
1700
TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults * results)1701 void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) {
1702 delete results;
1703 }
1704
GraphImportGraphDefLocked(TF_Graph * graph,const GraphDef & def,const TF_ImportGraphDefOptions * opts,TF_ImportGraphDefResults * tf_results,TF_Status * status)1705 static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
1706 const TF_ImportGraphDefOptions* opts,
1707 TF_ImportGraphDefResults* tf_results,
1708 TF_Status* status)
1709 EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
1710 const int last_node_id = graph->graph.num_node_ids();
1711 tensorflow::ImportGraphDefResults results;
1712 status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
1713 &graph->refiner, &results);
1714 if (!status->status.ok()) return;
1715
1716 // Add new nodes to name_map
1717 for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
1718 auto* node = graph->graph.FindNodeId(i);
1719 if (node != nullptr) graph->name_map[node->name()] = node;
1720 }
1721
1722 // Populate return_tensors
1723 DCHECK(tf_results->return_tensors.empty());
1724 tf_results->return_tensors.resize(results.return_tensors.size());
1725 for (int i = 0; i < results.return_tensors.size(); ++i) {
1726 tf_results->return_tensors[i].oper =
1727 ToOperation(results.return_tensors[i].first);
1728 tf_results->return_tensors[i].index = results.return_tensors[i].second;
1729 }
1730
1731 // Populate return_nodes
1732 DCHECK(tf_results->return_nodes.empty());
1733 tf_results->return_nodes.resize(results.return_nodes.size());
1734 for (int i = 0; i < results.return_nodes.size(); ++i) {
1735 tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]);
1736 }
1737
1738 // Populate missing unused map keys
1739 DCHECK(tf_results->missing_unused_key_names.empty());
1740 DCHECK(tf_results->missing_unused_key_indexes.empty());
1741 DCHECK(tf_results->missing_unused_key_names_data.empty());
1742
1743 size_t size = results.missing_unused_input_map_keys.size();
1744 tf_results->missing_unused_key_names.resize(size);
1745 tf_results->missing_unused_key_indexes.resize(size);
1746
1747 for (int i = 0; i < size; ++i) {
1748 TensorId id = results.missing_unused_input_map_keys[i];
1749 tf_results->missing_unused_key_names_data.emplace_back(id.first);
1750 tf_results->missing_unused_key_names[i] =
1751 tf_results->missing_unused_key_names_data.back().c_str();
1752 tf_results->missing_unused_key_indexes[i] = id.second;
1753 }
1754 }
1755
TF_GraphImportGraphDefWithResults(TF_Graph * graph,const TF_Buffer * graph_def,const TF_ImportGraphDefOptions * options,TF_Status * status)1756 TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
1757 TF_Graph* graph, const TF_Buffer* graph_def,
1758 const TF_ImportGraphDefOptions* options, TF_Status* status) {
1759 GraphDef def;
1760 if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
1761 graph_def->length)) {
1762 status->status = InvalidArgument("Invalid GraphDef");
1763 return nullptr;
1764 }
1765 auto results = new TF_ImportGraphDefResults();
1766 mutex_lock l(graph->mu);
1767 GraphImportGraphDefLocked(graph, def, options, results, status);
1768 if (!status->status.ok()) {
1769 delete results;
1770 return nullptr;
1771 }
1772 return results;
1773 }
1774
TF_GraphImportGraphDefWithReturnOutputs(TF_Graph * graph,const TF_Buffer * graph_def,const TF_ImportGraphDefOptions * options,TF_Output * return_outputs,int num_return_outputs,TF_Status * status)1775 void TF_GraphImportGraphDefWithReturnOutputs(
1776 TF_Graph* graph, const TF_Buffer* graph_def,
1777 const TF_ImportGraphDefOptions* options, TF_Output* return_outputs,
1778 int num_return_outputs, TF_Status* status) {
1779 if (num_return_outputs != options->opts.return_tensors.size()) {
1780 status->status = InvalidArgument("Expected 'num_return_outputs' to be ",
1781 options->opts.return_tensors.size(),
1782 ", got ", num_return_outputs);
1783 return;
1784 }
1785 if (num_return_outputs > 0 && return_outputs == nullptr) {
1786 status->status = InvalidArgument(
1787 "'return_outputs' must be preallocated to length ", num_return_outputs);
1788 return;
1789 }
1790 GraphDef def;
1791 if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
1792 graph_def->length)) {
1793 status->status = InvalidArgument("Invalid GraphDef");
1794 return;
1795 }
1796 TF_ImportGraphDefResults results;
1797 mutex_lock l(graph->mu);
1798 GraphImportGraphDefLocked(graph, def, options, &results, status);
1799 DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
1800 memcpy(return_outputs, results.return_tensors.data(),
1801 num_return_outputs * sizeof(TF_Output));
1802 }
1803
TF_GraphImportGraphDef(TF_Graph * graph,const TF_Buffer * graph_def,const TF_ImportGraphDefOptions * options,TF_Status * status)1804 void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
1805 const TF_ImportGraphDefOptions* options,
1806 TF_Status* status) {
1807 TF_ImportGraphDefResults* results =
1808 TF_GraphImportGraphDefWithResults(graph, graph_def, options, status);
1809 TF_DeleteImportGraphDefResults(results);
1810 }
1811
1812 // While loop functions -------------------------------------------------------
1813
1814 namespace {
1815
1816 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
1817
1818 // Creates a placeholder representing an input to the cond or body graph.
1819 // TODO(skyewm): remove these from final graph
CreateInput(const TF_Output & parent_input,TF_Graph * g,const char * name,TF_Output * input,TF_Status * status)1820 bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
1821 TF_Output* input, TF_Status* status) {
1822 TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name);
1823 TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input));
1824 // TODO(skyewm): set placeholder shape
1825 TF_Operation* oper = TF_FinishOperation(desc, status);
1826 if (!status->status.ok()) return false;
1827 *input = {oper, 0};
1828 return true;
1829 }
1830
1831 // Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input
1832 // `src_inputs[i]` will have that input replaced with `dst_inputs[i]`. `prefix`
1833 // will be prepended to copied node names. `control_deps` are nodes in
1834 // `dst_graph` that the copied `src_graph` nodes will have control dependencies
// on. `nodes_to_return` lists `nreturn_nodes` outputs in `src_graph`; the
// corresponding new outputs in `dst_graph` are appended to `return_nodes`,
// which must be non-null.
CopyGraph(Graph * src_graph,Graph * dst_graph,tensorflow::ShapeRefiner * dst_refiner,const TF_Output * src_inputs,const std::vector<tensorflow::Output> & dst_inputs,const string & prefix,const std::vector<tensorflow::Operation> & control_deps,const TF_Output * nodes_to_return,int nreturn_nodes,std::vector<tensorflow::Output> * return_nodes)1837 Status CopyGraph(Graph* src_graph, Graph* dst_graph,
1838 tensorflow::ShapeRefiner* dst_refiner,
1839 const TF_Output* src_inputs,
1840 const std::vector<tensorflow::Output>& dst_inputs,
1841 const string& prefix,
1842 const std::vector<tensorflow::Operation>& control_deps,
1843 const TF_Output* nodes_to_return, int nreturn_nodes,
1844 std::vector<tensorflow::Output>* return_nodes) {
1845 DCHECK(return_nodes != nullptr);
1846 GraphDef gdef;
1847 src_graph->ToGraphDef(&gdef);
1848
1849 tensorflow::ImportGraphDefOptions opts;
1850 opts.prefix = prefix;
1851
1852 for (int i = 0; i < dst_inputs.size(); ++i) {
1853 opts.input_map[ToTensorId(src_inputs[i])] =
1854 TensorId(dst_inputs[i].node()->name(), dst_inputs[i].index());
1855 }
1856 opts.skip_mapped_nodes = true;
1857
1858 for (const tensorflow::Operation& op : control_deps) {
1859 opts.control_dependencies.push_back(op.node()->name());
1860 }
1861
1862 for (int i = 0; i < nreturn_nodes; ++i) {
1863 opts.return_tensors.push_back(ToTensorId(nodes_to_return[i]));
1864 }
1865
1866 // TODO(skyewm): change to OutputTensor
1867 tensorflow::ImportGraphDefResults results;
1868 TF_RETURN_IF_ERROR(
1869 ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results));
1870
1871 for (const auto& pair : results.return_tensors) {
1872 return_nodes->emplace_back(pair.first, pair.second);
1873 }
1874 return Status::OK();
1875 }
1876
ValidateConstWhileParams(const TF_WhileParams & params,TF_Status * s)1877 bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) {
1878 if (params.cond_graph == nullptr || params.body_graph == nullptr ||
1879 params.cond_graph->parent == nullptr ||
1880 params.cond_graph->parent != params.body_graph->parent ||
1881 params.cond_graph->parent_inputs != params.body_graph->parent_inputs ||
1882 params.ninputs <= 0 || params.cond_inputs == nullptr ||
1883 params.body_inputs == nullptr || params.body_outputs == nullptr) {
1884 s->status = InvalidArgument(
1885 "TF_WhileParams must be created by successful TF_NewWhile() call");
1886 return false;
1887 }
1888 return true;
1889 }
1890
ValidateInputWhileParams(const TF_WhileParams & params,TF_Status * s)1891 bool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) {
1892 if (params.cond_output.oper == nullptr) {
1893 s->status = InvalidArgument("TF_WhileParams `cond_output` field isn't set");
1894 return false;
1895 }
1896 for (int i = 0; i < params.ninputs; ++i) {
1897 if (params.body_outputs[i].oper == nullptr) {
1898 s->status = InvalidArgument("TF_WhileParams `body_outputs[", i, "]` ",
1899 "field isn't set");
1900 return false;
1901 }
1902 }
1903 if (params.name == nullptr) {
1904 s->status = InvalidArgument("TF_WhileParams `name` field is null");
1905 return false;
1906 }
1907 return true;
1908 }
1909
1910 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
1911
FreeWhileResources(const TF_WhileParams * params)1912 void FreeWhileResources(const TF_WhileParams* params) {
1913 TF_DeleteGraph(params->cond_graph);
1914 TF_DeleteGraph(params->body_graph);
1915 delete[] params->cond_inputs;
1916 delete[] params->body_inputs;
1917 delete[] params->body_outputs;
1918 }
1919
EmptyWhileParams()1920 TF_WhileParams EmptyWhileParams() {
1921 return {0, nullptr, nullptr, {nullptr, 0},
1922 nullptr, nullptr, nullptr, nullptr};
1923 }
1924
1925 } // namespace
1926
TF_NewWhile(TF_Graph * g,TF_Output * inputs,int ninputs,TF_Status * status)1927 TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
1928 TF_Status* status) {
1929 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
1930 status->status = tensorflow::errors::Unimplemented(
1931 "Creating while loops is not supported on mobile. File a bug at "
1932 "https://github.com/tensorflow/tensorflow/issues if this feature is "
1933 "important to you");
1934 return EmptyWhileParams();
1935 #else
1936 if (ninputs == 0) {
1937 status->status =
1938 InvalidArgument("TF_NewWhile() must be passed at least one input");
1939 return EmptyWhileParams();
1940 }
1941
1942 TF_Graph* cond_graph = TF_NewGraph();
1943 TF_Graph* body_graph = TF_NewGraph();
1944 cond_graph->parent = g;
1945 cond_graph->parent_inputs = inputs;
1946 body_graph->parent = g;
1947 body_graph->parent_inputs = inputs;
1948
1949 TF_Output* cond_inputs = new TF_Output[ninputs];
1950 TF_Output cond_output = {nullptr, -1};
1951 TF_Output* body_inputs = new TF_Output[ninputs];
1952 TF_Output* body_outputs = new TF_Output[ninputs];
1953 for (int i = 0; i < ninputs; ++i) body_outputs[i] = {nullptr, -1};
1954 const char* name = nullptr;
1955
1956 for (int i = 0; i < ninputs; ++i) {
1957 // TODO(skyewm): prefix names with underscore (requires some plumbing)
1958 if (!CreateInput(inputs[i], cond_graph, StrCat("cond_input", i).c_str(),
1959 &cond_inputs[i], status)) {
1960 break;
1961 }
1962 if (!CreateInput(inputs[i], body_graph, StrCat("body_input", i).c_str(),
1963 &body_inputs[i], status)) {
1964 break;
1965 }
1966 }
1967
1968 TF_WhileParams params = {ninputs, cond_graph, cond_inputs, cond_output,
1969 body_graph, body_inputs, body_outputs, name};
1970
1971 if (!status->status.ok()) {
1972 FreeWhileResources(¶ms);
1973 return EmptyWhileParams();
1974 }
1975 return params;
1976 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
1977 }
1978
1979 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
1980 namespace {
1981
1982 // TODO(skyewm): make nodes in while loop unfetchable like in Python version
TF_FinishWhileHelper(const TF_WhileParams * params,TF_Status * status,TF_Output * outputs)1983 void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status,
1984 TF_Output* outputs) {
1985 if (!ValidateInputWhileParams(*params, status)) return;
1986
1987 TF_Graph* parent = params->cond_graph->parent;
1988 TF_Output* parent_inputs = params->cond_graph->parent_inputs;
1989 int num_loop_vars = params->ninputs;
1990
1991 mutex_lock l(parent->mu);
1992
1993 // 'cond_fn' copies the cond graph into the parent graph.
1994 tensorflow::ops::CondGraphBuilderFn cond_fn =
1995 [params, parent](const tensorflow::Scope& scope,
1996 const std::vector<tensorflow::Output>& inputs,
1997 tensorflow::Output* output) {
1998 DCHECK_EQ(scope.graph(), &parent->graph);
1999 std::vector<tensorflow::Output> cond_output;
2000 TF_RETURN_IF_ERROR(CopyGraph(
2001 ¶ms->cond_graph->graph, &parent->graph, &parent->refiner,
2002 params->cond_inputs, inputs, scope.impl()->name(),
2003 scope.impl()->control_deps(), ¶ms->cond_output,
2004 /* nreturn_nodes */ 1, &cond_output));
2005 *output = cond_output[0];
2006 return Status::OK();
2007 };
2008
2009 // 'body_fn' copies the body graph into the parent graph.
2010 tensorflow::ops::BodyGraphBuilderFn body_fn =
2011 [params, parent, num_loop_vars](
2012 const tensorflow::Scope& scope,
2013 const std::vector<tensorflow::Output>& inputs,
2014 std::vector<tensorflow::Output>* outputs) {
2015 DCHECK_EQ(scope.graph(), &parent->graph);
2016 TF_RETURN_IF_ERROR(
2017 CopyGraph(¶ms->body_graph->graph, &parent->graph,
2018 &parent->refiner, params->body_inputs, inputs,
2019 scope.impl()->name(), scope.impl()->control_deps(),
2020 params->body_outputs, num_loop_vars, outputs));
2021 return Status::OK();
2022 };
2023
2024 // Create the while loop using an internal scope.
2025 tensorflow::Scope scope =
2026 NewInternalScope(&parent->graph, &status->status, &parent->refiner)
2027 .NewSubScope(params->name);
2028
2029 const int first_new_node_id = parent->graph.num_node_ids();
2030
2031 tensorflow::OutputList loop_outputs;
2032 status->status = tensorflow::ops::BuildWhileLoop(
2033 scope, OutputsFromTFOutputs(parent_inputs, num_loop_vars), cond_fn,
2034 body_fn, params->name, &loop_outputs);
2035
2036 // Update name_map with newly-created ops.
2037 // TODO(skyewm): right now BuildWhileLoop() may alter the graph if it returns
2038 // a bad status. Once we fix this, we may want to return early instead of
2039 // executing the following code.
2040 for (int i = first_new_node_id; i < parent->graph.num_node_ids(); ++i) {
2041 Node* new_node = parent->graph.FindNodeId(i);
2042 if (new_node == nullptr) continue;
2043 parent->name_map[new_node->name()] = new_node;
2044 }
2045
2046 // Populate 'outputs'.
2047 DCHECK_LE(loop_outputs.size(), num_loop_vars);
2048 for (int i = 0; i < loop_outputs.size(); ++i) {
2049 outputs[i] = {ToOperation(loop_outputs[i].node()), loop_outputs[i].index()};
2050 }
2051 }
2052
2053 } // namespace
2054 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2055
TF_FinishWhile(const TF_WhileParams * params,TF_Status * status,TF_Output * outputs)2056 void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
2057 TF_Output* outputs) {
2058 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2059 status->status = tensorflow::errors::Unimplemented(
2060 "Creating while loops is not supported on mobile. File a bug at "
2061 "https://github.com/tensorflow/tensorflow/issues if this feature is "
2062 "important to you");
2063 #else
2064 // If it appears the caller created or modified `params`, don't free resources
2065 if (!ValidateConstWhileParams(*params, status)) return;
2066 TF_FinishWhileHelper(params, status, outputs);
2067 FreeWhileResources(params);
2068 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2069 }
2070
TF_AbortWhile(const TF_WhileParams * params)2071 void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
2072
TF_AddGradients(TF_Graph * g,TF_Output * y,int ny,TF_Output * x,int nx,TF_Output * dx,TF_Status * status,TF_Output * dy)2073 void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
2074 TF_Output* dx, TF_Status* status, TF_Output* dy) {
2075 TF_AddGradientsWithPrefix(g, nullptr, y, ny, x, nx, dx, status, dy);
2076 }
2077
TF_AddGradientsWithPrefix(TF_Graph * g,const char * prefix,TF_Output * y,int ny,TF_Output * x,int nx,TF_Output * dx,TF_Status * status,TF_Output * dy)2078 void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
2079 int ny, TF_Output* x, int nx, TF_Output* dx,
2080 TF_Status* status, TF_Output* dy) {
2081 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2082 status->status = tensorflow::errors::Unimplemented(
2083 "Adding gradients is not supported on mobile. File a bug at "
2084 "https://github.com/tensorflow/tensorflow/issues if this feature is "
2085 "important to you");
2086 #else
2087 std::vector<tensorflow::Output> y_arg = OutputsFromTFOutputs(y, ny);
2088 std::vector<tensorflow::Output> x_arg = OutputsFromTFOutputs(x, nx);
2089 std::vector<tensorflow::Output> dy_arg;
2090
2091 {
2092 // We need to hold on to the lock while we have a scope that uses TF_Graph.
2093 mutex_lock graph_lock(g->mu);
2094
2095 const int first_new_node_id = g->graph.num_node_ids();
2096
2097 string prefix_cmp;
2098 const char* child_scope_name;
2099 if (prefix == nullptr) {
2100 child_scope_name = "gradients";
2101 } else {
2102 prefix_cmp = string(prefix) + "/";
2103 // The operation should fail if the provided name prefix has already been
2104 // used in this graph
2105 for (const auto& pair : g->name_map) {
2106 const string& name = pair.first;
2107 if ((name == prefix) || absl::StartsWith(name, prefix_cmp)) {
2108 status->status = InvalidArgument(
2109 "prefix [", prefix,
2110 "] conflicts with existing node in the graph named [", name, "]");
2111 return;
2112 }
2113 }
2114 child_scope_name = prefix;
2115 }
2116 tensorflow::Scope scope =
2117 NewInternalScope(&g->graph, &status->status, &g->refiner)
2118 .NewSubScope(child_scope_name);
2119
2120 if (dx != nullptr) {
2121 std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny);
2122 status->status =
2123 AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg);
2124 } else {
2125 status->status = AddSymbolicGradients(scope, y_arg, x_arg, &dy_arg);
2126 }
2127
2128 // Update g->name_map with the name_map from the scope, which will contain
2129 // the new gradient ops.
2130 for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
2131 Node* n = g->graph.FindNodeId(i);
2132 if (n == nullptr) continue;
2133
2134 // Adding the gradients to the graph can alter the prefix to prevent
2135 // name collisions only if this prefix has not been provided explicitly
2136 // by the user. If it was provided, assert that it remained intact.
2137 if (prefix != nullptr && !absl::StartsWith(n->name(), prefix_cmp)) {
2138 status->status = tensorflow::errors::Internal(
2139 "BUG: The gradients prefix have been unexpectedly altered when "
2140 "adding the nodes to the graph. This is a bug. Please file an "
2141 "issue at https://github.com/tensorflow/tensorflow/issues.");
2142 return;
2143 }
2144 // We have a convoluted scheme here: Using the C++ graph construction API
2145 // to add potentially many nodes to the graph without running the checks
2146 // (such as uniqueness of the names of nodes) we run with other functions
2147 // that add a node to the graph (like TF_FinishOperation).
2148 if (!g->name_map.insert(std::make_pair(n->name(), n)).second) {
2149 status->status = tensorflow::errors::Internal(
2150 "BUG: The API allowed construction of a graph with duplicate node "
2151 "names (",
2152 n->name(),
2153 "). This is a bug. Please file an issue at "
2154 "https://github.com/tensorflow/tensorflow/issues.");
2155 }
2156 }
2157 }
2158
2159 // Unpack the results from grad_outputs_arg.
2160 TFOutputsFromOutputs(dy_arg, dy);
2161 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2162 }
2163
2164 // TF_Session functions ----------------------------------------------
2165
TF_Session(tensorflow::Session * s,TF_Graph * g)2166 TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g)
2167 : session(s), graph(g), last_num_graph_nodes(0), extend_before_run(true) {}
2168
TF_NewSession(TF_Graph * graph,const TF_SessionOptions * opt,TF_Status * status)2169 TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
2170 TF_Status* status) {
2171 Session* session;
2172 status->status = NewSession(opt->options, &session);
2173 if (status->status.ok()) {
2174 TF_Session* new_session = new TF_Session(session, graph);
2175 if (graph != nullptr) {
2176 mutex_lock l(graph->mu);
2177 graph->sessions[new_session] = "";
2178 }
2179 return new_session;
2180 } else {
2181 DCHECK_EQ(nullptr, session);
2182 return nullptr;
2183 }
2184 }
2185
TF_LoadSessionFromSavedModel(const TF_SessionOptions * session_options,const TF_Buffer * run_options,const char * export_dir,const char * const * tags,int tags_len,TF_Graph * graph,TF_Buffer * meta_graph_def,TF_Status * status)2186 TF_Session* TF_LoadSessionFromSavedModel(
2187 const TF_SessionOptions* session_options, const TF_Buffer* run_options,
2188 const char* export_dir, const char* const* tags, int tags_len,
2189 TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) {
2190 // TODO(sjr): Remove the IS_MOBILE_PLATFORM guard. This will require ensuring
2191 // that the tensorflow/cc/saved_model:loader build target is mobile friendly.
2192 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2193 status->status = tensorflow::errors::Unimplemented(
2194 "Loading a SavedModel is not supported on mobile. File a bug at "
2195 "https://github.com/tensorflow/tensorflow/issues if this feature is "
2196 "important to you");
2197 return nullptr;
2198 #else
2199 mutex_lock l(graph->mu);
2200 if (!graph->name_map.empty()) {
2201 status->status = InvalidArgument("Graph is non-empty.");
2202 return nullptr;
2203 }
2204
2205 RunOptions run_options_proto;
2206 if (run_options != nullptr && !run_options_proto.ParseFromArray(
2207 run_options->data, run_options->length)) {
2208 status->status = InvalidArgument("Unparseable RunOptions proto");
2209 return nullptr;
2210 }
2211
2212 std::unordered_set<string> tag_set;
2213 for (int i = 0; i < tags_len; i++) {
2214 tag_set.insert(string(tags[i]));
2215 }
2216
2217 tensorflow::SavedModelBundle bundle;
2218 status->status =
2219 tensorflow::LoadSavedModel(session_options->options, run_options_proto,
2220 export_dir, tag_set, &bundle);
2221 if (!status->status.ok()) return nullptr;
2222
2223 // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session
2224 // extends using GraphDefs. The Graph instance is different, but equivalent
2225 // to the one used to create the session.
2226 //
2227 // TODO(jhseu): When Session is modified to take Graphs instead of
2228 // GraphDefs, return the Graph generated in LoadSavedModel().
2229 TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions();
2230 TF_ImportGraphDefResults results;
2231 GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
2232 import_opts, &results, status);
2233 TF_DeleteImportGraphDefOptions(import_opts);
2234 if (!status->status.ok()) return nullptr;
2235
2236 if (meta_graph_def != nullptr) {
2237 status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def);
2238 if (!status->status.ok()) return nullptr;
2239 }
2240
2241 TF_Session* session = new TF_Session(bundle.session.release(), graph);
2242
2243 graph->sessions[session] = "";
2244 session->last_num_graph_nodes = graph->graph.num_node_ids();
2245 return session;
2246 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2247 }
2248
TF_CloseSession(TF_Session * s,TF_Status * status)2249 void TF_CloseSession(TF_Session* s, TF_Status* status) {
2250 status->status = s->session->Close();
2251 }
2252
TF_DeleteSession(TF_Session * s,TF_Status * status)2253 void TF_DeleteSession(TF_Session* s, TF_Status* status) {
2254 status->status = Status::OK();
2255 if (s == nullptr) return;
2256 TF_Graph* const graph = s->graph;
2257 if (graph != nullptr) {
2258 graph->mu.lock();
2259 graph->sessions.erase(s);
2260 const bool del = graph->delete_requested && graph->sessions.empty();
2261 graph->mu.unlock();
2262 if (del) delete graph;
2263 }
2264 delete s->session;
2265 delete s;
2266 }
2267
TF_SessionRun(TF_Session * session,const TF_Buffer * run_options,const TF_Output * inputs,TF_Tensor * const * input_values,int ninputs,const TF_Output * outputs,TF_Tensor ** output_values,int noutputs,const TF_Operation * const * target_opers,int ntargets,TF_Buffer * run_metadata,TF_Status * status)2268 void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
2269 const TF_Output* inputs, TF_Tensor* const* input_values,
2270 int ninputs, const TF_Output* outputs,
2271 TF_Tensor** output_values, int noutputs,
2272 const TF_Operation* const* target_opers, int ntargets,
2273 TF_Buffer* run_metadata, TF_Status* status) {
2274 // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2275 // directly, instead of requiring us to serialize to a GraphDef and
2276 // call Session::Extend().
2277 if (session->extend_before_run &&
2278 !ExtendSessionGraphHelper(session, status)) {
2279 return;
2280 }
2281
2282 TF_Run_Setup(noutputs, output_values, status);
2283
2284 // Convert from TF_Output and TF_Tensor to a string and Tensor.
2285 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2286 if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2287 for (int i = 0; i < ninputs; ++i) {
2288 input_pairs[i].first = OutputName(inputs[i]);
2289 }
2290
2291 // Convert from TF_Output to string names.
2292 std::vector<string> output_names(noutputs);
2293 for (int i = 0; i < noutputs; ++i) {
2294 output_names[i] = OutputName(outputs[i]);
2295 }
2296
2297 // Convert from TF_Operation* to string names.
2298 std::vector<string> target_names(ntargets);
2299 for (int i = 0; i < ntargets; ++i) {
2300 target_names[i] = target_opers[i]->node.name();
2301 }
2302
2303 // Actually run.
2304 TF_Run_Helper(session->session, nullptr, run_options, input_pairs,
2305 output_names, output_values, target_names, run_metadata,
2306 status);
2307 }
2308
TF_SessionPRunSetup(TF_Session * session,const TF_Output * inputs,int ninputs,const TF_Output * outputs,int noutputs,const TF_Operation * const * target_opers,int ntargets,const char ** handle,TF_Status * status)2309 void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
2310 int ninputs, const TF_Output* outputs, int noutputs,
2311 const TF_Operation* const* target_opers, int ntargets,
2312 const char** handle, TF_Status* status) {
2313 *handle = nullptr;
2314
2315 if (session->extend_before_run &&
2316 !ExtendSessionGraphHelper(session, status)) {
2317 return;
2318 }
2319
2320 std::vector<string> input_names(ninputs);
2321 for (int i = 0; i < ninputs; ++i) {
2322 input_names[i] = OutputName(inputs[i]);
2323 }
2324
2325 std::vector<string> output_names(noutputs);
2326 for (int i = 0; i < noutputs; ++i) {
2327 output_names[i] = OutputName(outputs[i]);
2328 }
2329
2330 std::vector<string> target_names(ntargets);
2331 for (int i = 0; i < ntargets; ++i) {
2332 target_names[i] = target_opers[i]->node.name();
2333 }
2334
2335 string new_handle;
2336 status->status = session->session->PRunSetup(input_names, output_names,
2337 target_names, &new_handle);
2338 if (status->status.ok()) {
2339 char* buf = new char[new_handle.size() + 1];
2340 memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
2341 *handle = buf;
2342 }
2343 }
2344
// Releases a handle returned by TF_SessionPRunSetup. The handle was
// allocated with new[], hence delete[]. A nullptr handle is a no-op.
void TF_DeletePRunHandle(const char* handle) {
  // TODO(suharshs): Free up any resources held by the partial run state.
  delete[] handle;
}
2349
TF_SessionPRun(TF_Session * session,const char * handle,const TF_Output * inputs,TF_Tensor * const * input_values,int ninputs,const TF_Output * outputs,TF_Tensor ** output_values,int noutputs,const TF_Operation * const * target_opers,int ntargets,TF_Status * status)2350 void TF_SessionPRun(TF_Session* session, const char* handle,
2351 const TF_Output* inputs, TF_Tensor* const* input_values,
2352 int ninputs, const TF_Output* outputs,
2353 TF_Tensor** output_values, int noutputs,
2354 const TF_Operation* const* target_opers, int ntargets,
2355 TF_Status* status) {
2356 // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2357 // directly, instead of requiring us to serialize to a GraphDef and
2358 // call Session::Extend().
2359 if (session->extend_before_run &&
2360 !ExtendSessionGraphHelper(session, status)) {
2361 return;
2362 }
2363
2364 TF_Run_Setup(noutputs, output_values, status);
2365
2366 // Convert from TF_Output and TF_Tensor to a string and Tensor.
2367 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2368 if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2369 for (int i = 0; i < ninputs; ++i) {
2370 input_pairs[i].first = OutputName(inputs[i]);
2371 }
2372
2373 // Convert from TF_Output to string names.
2374 std::vector<string> output_names(noutputs);
2375 for (int i = 0; i < noutputs; ++i) {
2376 output_names[i] = OutputName(outputs[i]);
2377 }
2378
2379 // Convert from TF_Operation* to string names.
2380 std::vector<string> target_names(ntargets);
2381 for (int i = 0; i < ntargets; ++i) {
2382 target_names[i] = target_opers[i]->node.name();
2383 }
2384
2385 TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names,
2386 output_values, target_names, nullptr, status);
2387 }
2388
TF_TryEvaluateConstant(TF_Graph * graph,TF_Output output,TF_Tensor ** result,TF_Status * status)2389 unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
2390 TF_Tensor** result, TF_Status* status) {
2391 *result = nullptr;
2392 mutex_lock l(graph->mu);
2393 OutputTensor tensor(&output.oper->node, output.index);
2394 bool evaluated;
2395 Tensor result_tensor;
2396 status->status = EvaluateConstantTensor(
2397 tensor, graph->refiner, *graph->graph.op_registry(),
2398 graph->graph.versions().producer(), &evaluated, &result_tensor);
2399 if (evaluated) {
2400 DCHECK(status->status.ok());
2401 *result = TF_TensorFromTensor(result_tensor, &status->status);
2402 if (!status->status.ok()) evaluated = false;
2403 }
2404 return evaluated;
2405 }
2406
TF_NewApiDefMap(TF_Buffer * op_list_buffer,TF_Status * status)2407 TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) {
2408 tensorflow::OpList op_list;
2409 if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) {
2410 status->status = InvalidArgument("Unparseable OpList");
2411 return nullptr;
2412 }
2413 status->status = Status::OK();
2414 return new TF_ApiDefMap(op_list);
2415 }
2416
TF_DeleteApiDefMap(TF_ApiDefMap * apimap)2417 void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; }
2418
TF_ApiDefMapPut(TF_ApiDefMap * api_def_map,const char * text,size_t text_len,TF_Status * status)2419 void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text,
2420 size_t text_len, TF_Status* status) {
2421 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2422 status->status = tensorflow::errors::Unimplemented(
2423 "ApiDefMap is not supported on mobile.");
2424 #else
2425 mutex_lock l(api_def_map->lock);
2426 if (api_def_map->update_docs_called) {
2427 status->status = FailedPrecondition(
2428 "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been "
2429 "called.");
2430 return;
2431 }
2432 string api_def_text(text, text_len);
2433 status->status = api_def_map->api_def_map.LoadApiDef(api_def_text);
2434 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2435 }
2436
TF_ApiDefMapGet(TF_ApiDefMap * api_def_map,const char * name,size_t name_len,TF_Status * status)2437 TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
2438 size_t name_len, TF_Status* status) {
2439 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2440 status->status = tensorflow::errors::Unimplemented(
2441 "ApiDefMap is not supported on mobile.");
2442 return nullptr;
2443 #else
2444 mutex_lock l(api_def_map->lock);
2445 if (!api_def_map->update_docs_called) {
2446 api_def_map->api_def_map.UpdateDocs();
2447 api_def_map->update_docs_called = true;
2448 }
2449 string name_str(name, name_len);
2450 const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str);
2451 if (api_def == nullptr) {
2452 return nullptr;
2453 }
2454
2455 TF_Buffer* ret = TF_NewBuffer();
2456 status->status = MessageToBuffer(*api_def, ret);
2457 if (!status->status.ok()) {
2458 TF_DeleteBuffer(ret);
2459 return nullptr;
2460 }
2461 return ret;
2462 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2463 }
2464
TF_GetAllRegisteredKernels(TF_Status * status)2465 TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) {
2466 tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels();
2467 TF_Buffer* ret = TF_NewBuffer();
2468 status->status = MessageToBuffer(kernel_list, ret);
2469 if (!status->status.ok()) {
2470 TF_DeleteBuffer(ret);
2471 return nullptr;
2472 }
2473 return ret;
2474 }
2475
TF_GetRegisteredKernelsForOp(const char * name,TF_Status * status)2476 TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
2477 tensorflow::KernelList kernel_list =
2478 tensorflow::GetRegisteredKernelsForOp(name);
2479 TF_Buffer* ret = TF_NewBuffer();
2480 status->status = MessageToBuffer(kernel_list, ret);
2481 if (!status->status.ok()) {
2482 TF_DeleteBuffer(ret);
2483 return nullptr;
2484 }
2485 return ret;
2486 }
2487
2488 // TF_Server functions ----------------------------------------------
2489
2490 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
// Takes ownership of `server` and caches its target string in `target`.
// NOTE(review): inside the initializer list, `server` names the constructor
// *parameter* (it shadows the member). `server->target()` therefore reads the
// parameter before it is moved into the member. This is only correct if
// `target` is declared before `server` in struct TF_Server, because
// declaration order (not initializer-list order) decides initialization
// order. Confirm against c_api_internal.h.
TF_Server::TF_Server(std::unique_ptr<tensorflow::ServerInterface> server)
    : target(server->target()), server(std::move(server)) {}
2493 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2494
TF_NewServer(const void * proto,size_t proto_len,TF_Status * status)2495 TF_Server* TF_NewServer(const void* proto, size_t proto_len,
2496 TF_Status* status) {
2497 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2498 status->status = tensorflow::errors::Unimplemented(
2499 "Server functionality is not supported on mobile");
2500 return nullptr;
2501 #else
2502 tensorflow::ServerDef server_def;
2503 if (!server_def.ParseFromArray(proto, static_cast<int>(proto_len))) {
2504 status->status = InvalidArgument(
2505 "Could not parse provided bytes into a ServerDef protocol buffer");
2506 return nullptr;
2507 }
2508
2509 std::unique_ptr<tensorflow::ServerInterface> out_server;
2510 status->status = tensorflow::NewServer(server_def, &out_server);
2511 if (!status->status.ok()) return nullptr;
2512
2513 return new TF_Server(std::move(out_server));
2514 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2515 }
2516
TF_ServerStart(TF_Server * server,TF_Status * status)2517 void TF_ServerStart(TF_Server* server, TF_Status* status) {
2518 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2519 status->status = tensorflow::errors::Unimplemented(
2520 "Server functionality is not supported on mobile");
2521 #else
2522 status->status = server->server->Start();
2523 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2524 }
2525
TF_ServerStop(TF_Server * server,TF_Status * status)2526 void TF_ServerStop(TF_Server* server, TF_Status* status) {
2527 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2528 status->status = tensorflow::errors::Unimplemented(
2529 "Server functionality is not supported on mobile");
2530 #else
2531 status->status = server->server->Stop();
2532 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2533 }
2534
TF_ServerJoin(TF_Server * server,TF_Status * status)2535 void TF_ServerJoin(TF_Server* server, TF_Status* status) {
2536 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2537 status->status = tensorflow::errors::Unimplemented(
2538 "Server functionality is not supported on mobile");
2539 #else
2540 status->status = server->server->Join();
2541 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2542 }
2543
TF_ServerTarget(TF_Server * server)2544 const char* TF_ServerTarget(TF_Server* server) {
2545 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2546 return nullptr;
2547 #else
2548 return server->target.c_str();
2549 #endif
2550 }
2551
TF_DeleteServer(TF_Server * server)2552 void TF_DeleteServer(TF_Server* server) {
2553 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2554 delete server;
2555 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2556 }
2557
TF_RegisterLogListener(void (* listener)(const char *))2558 void TF_RegisterLogListener(void (*listener)(const char*)) {
2559 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2560 tensorflow::logging::RegisterListener(listener);
2561 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2562 }
2563
2564 } // end extern "C"
2565