1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/c_api.h"
17 
18 #include <algorithm>
19 #include <limits>
20 #include <memory>
21 #include <vector>
22 
23 #include "absl/strings/match.h"
24 // Required for IS_MOBILE_PLATFORM
25 #include "tensorflow/core/platform/platform.h"  // NOLINT
26 
27 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
28 #include "tensorflow/cc/framework/gradients.h"
29 #include "tensorflow/cc/framework/ops.h"
30 #include "tensorflow/cc/framework/scope_internal.h"
31 #include "tensorflow/cc/ops/while_loop.h"
32 #include "tensorflow/cc/saved_model/loader.h"
33 #include "tensorflow/core/distributed_runtime/server_lib.h"
34 #include "tensorflow/core/framework/logging.h"
35 #include "tensorflow/core/framework/op_gen_lib.h"
36 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
37 #include "tensorflow/c/c_api_internal.h"
38 #include "tensorflow/c/tf_status_internal.h"
39 #include "tensorflow/c/tf_tensor.h"
40 #include "tensorflow/core/common_runtime/device_mgr.h"
41 #include "tensorflow/core/common_runtime/eval_const_tensor.h"
42 #include "tensorflow/core/common_runtime/shape_refiner.h"
43 #include "tensorflow/core/framework/allocation_description.pb.h"
44 #include "tensorflow/core/framework/kernel_def.pb.h"
45 #include "tensorflow/core/framework/log_memory.h"
46 #include "tensorflow/core/framework/node_def_util.h"
47 #include "tensorflow/core/framework/op_kernel.h"
48 #include "tensorflow/core/framework/partial_tensor_shape.h"
49 #include "tensorflow/core/framework/tensor.h"
50 #include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
51 #include "tensorflow/core/framework/tensor_shape.h"
52 #include "tensorflow/core/framework/tensor_shape.pb.h"
53 #include "tensorflow/core/framework/types.h"
54 #include "tensorflow/core/framework/versions.pb.h"
55 #include "tensorflow/core/graph/graph.h"
56 #include "tensorflow/core/graph/graph_constructor.h"
57 #include "tensorflow/core/graph/node_builder.h"
58 #include "tensorflow/core/graph/validate.h"
59 #include "tensorflow/core/lib/core/coding.h"
60 #include "tensorflow/core/lib/core/errors.h"
61 #include "tensorflow/core/lib/core/status.h"
62 #include "tensorflow/core/lib/core/stringpiece.h"
63 #include "tensorflow/core/lib/gtl/array_slice.h"
64 #include "tensorflow/core/lib/strings/str_util.h"
65 #include "tensorflow/core/lib/strings/strcat.h"
66 #include "tensorflow/core/platform/mem.h"
67 #include "tensorflow/core/platform/mutex.h"
68 #include "tensorflow/core/platform/protobuf.h"
69 #include "tensorflow/core/platform/thread_annotations.h"
70 #include "tensorflow/core/platform/types.h"
71 #include "tensorflow/core/public/session.h"
72 #include "tensorflow/core/public/version.h"
73 
74 // The implementation below is at the top level instead of the
75 // brain namespace because we are defining 'extern "C"' functions.
76 using tensorflow::AllocationDescription;
77 using tensorflow::DataType;
78 using tensorflow::ExtendSessionGraphHelper;
79 using tensorflow::Graph;
80 using tensorflow::GraphDef;
81 using tensorflow::mutex_lock;
82 using tensorflow::NameRangeMap;
83 using tensorflow::NameRangesForNode;
84 using tensorflow::NewSession;
85 using tensorflow::Node;
86 using tensorflow::NodeBuilder;
87 using tensorflow::NodeDef;
88 using tensorflow::OpDef;
89 using tensorflow::OpRegistry;
90 using tensorflow::OutputTensor;
91 using tensorflow::PartialTensorShape;
92 using tensorflow::RunMetadata;
93 using tensorflow::RunOptions;
94 using tensorflow::Session;
95 using tensorflow::Status;
96 using tensorflow::string;
97 using tensorflow::Tensor;
98 using tensorflow::TensorBuffer;
99 using tensorflow::TensorId;
100 using tensorflow::TensorShape;
101 using tensorflow::TensorShapeProto;
102 using tensorflow::VersionDef;
103 using tensorflow::errors::FailedPrecondition;
104 using tensorflow::errors::InvalidArgument;
105 using tensorflow::gtl::ArraySlice;
106 using tensorflow::strings::StrCat;
107 
108 extern "C" {
109 
110 // --------------------------------------------------------------------------
111 const char* TF_Version() { return TF_VERSION_STRING; }
112 
113 // --------------------------------------------------------------------------
114 
115 // --------------------------------------------------------------------------
116 TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; }
117 void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; }
118 
119 void TF_SetTarget(TF_SessionOptions* options, const char* target) {
120   options->options.target = target;
121 }
122 
123 void TF_SetConfig(TF_SessionOptions* options, const void* proto,
124                   size_t proto_len, TF_Status* status) {
125   if (!options->options.config.ParseFromArray(proto, proto_len)) {
126     status->status = InvalidArgument("Unparseable ConfigProto");
127   }
128 }
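// Minimal usage sketch for the two setters above, assuming `config_bytes` and
// `config_len` hold a serialized tensorflow.ConfigProto (hypothetical inputs):
//
//   TF_Status* status = TF_NewStatus();
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   TF_SetTarget(opts, "local");
//   TF_SetConfig(opts, config_bytes, config_len, status);
//   if (TF_GetCode(status) != TF_OK) {
//     fprintf(stderr, "TF_SetConfig failed: %s\n", TF_Message(status));
//   }
//   // ... create a session with `opts` ...
//   TF_DeleteSessionOptions(opts);
//   TF_DeleteStatus(status);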
129 // --------------------------------------------------------------------------
130 TF_Buffer* TF_NewBuffer() { return new TF_Buffer{nullptr, 0, nullptr}; }
131 
132 TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) {
133   void* copy = tensorflow::port::Malloc(proto_len);
134   memcpy(copy, proto, proto_len);
135 
136   TF_Buffer* buf = new TF_Buffer;
137   buf->data = copy;
138   buf->length = proto_len;
139   buf->data_deallocator = [](void* data, size_t length) {
140     tensorflow::port::Free(data);
141   };
142   return buf;
143 }
144 
145 void TF_DeleteBuffer(TF_Buffer* buffer) {
146   if (buffer == nullptr) return;
147   if (buffer->data_deallocator != nullptr) {
148     (*buffer->data_deallocator)(const_cast<void*>(buffer->data),
149                                 buffer->length);
150   }
151   delete buffer;
152 }
153 
154 TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; }
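// Brief ownership sketch for the buffer helpers above; the proto bytes are a
// hypothetical placeholder. TF_NewBufferFromString copies its input and frees
// the copy via the installed data_deallocator when TF_DeleteBuffer runs.
//
//   const char proto_bytes[] = "...";                     // assumed serialized proto
//   TF_Buffer* buf = TF_NewBufferFromString(proto_bytes, sizeof(proto_bytes));
//   TF_Buffer view = TF_GetBuffer(buf);                   // shallow copy, no ownership
//   // ... read view.data / view.length ...
//   TF_DeleteBuffer(buf);                                 // frees the copied bytes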
155 
156 // --------------------------------------------------------------------------
157 
158 TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
159                                               TF_Status* status) {
160   Session* session;
161   status->status = NewSession(opt->options, &session);
162   if (status->status.ok()) {
163     return new TF_DeprecatedSession({session});
164   } else {
165     DCHECK_EQ(nullptr, session);
166     return nullptr;
167   }
168 }
169 
170 void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
171   status->status = s->session->Close();
172 }
173 
174 void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
175   status->status = Status::OK();
176   if (s == nullptr) return;
177   delete s->session;
178   delete s;
179 }
180 
181 void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,
182                     size_t proto_len, TF_Status* status) {
183   GraphDef g;
184   if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {
185     status->status = InvalidArgument("Invalid GraphDef");
186     return;
187   }
188   status->status = s->session->Extend(g);
189 }
190 
191 }  // end extern "C"
192 
193 // Reset helper for converting character arrays to string vectors.
194 static void TF_Reset_Helper(const TF_SessionOptions* opt,
195                             const char** containers, int ncontainers,
196                             TF_Status* status) {
197   std::vector<string> container_names(ncontainers);
198   for (int i = 0; i < ncontainers; ++i) {
199     container_names[i] = containers[i];
200   }
201 
202   status->status = Reset(opt->options, container_names);
203 }
204 
205 extern "C" {
206 
207 void TF_Reset(const TF_SessionOptions* opt, const char** containers,
208               int ncontainers, TF_Status* status) {
209   TF_Reset_Helper(opt, containers, ncontainers, status);
210 }
211 
212 }  // end extern "C"
213 
214 namespace tensorflow {
215 
216 
217 Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
218                        TF_Buffer* out) {
219   if (out->data != nullptr) {
220     return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
221   }
222   const size_t proto_size = in.ByteSizeLong();
223   void* buf = port::Malloc(proto_size);
224   if (buf == nullptr) {
225     return tensorflow::errors::ResourceExhausted(
226         "Failed to allocate memory to serialize message of type '",
227         in.GetTypeName(), "' and size ", proto_size);
228   }
229   if (!in.SerializeWithCachedSizesToArray(static_cast<uint8*>(buf))) {
230     port::Free(buf);
231     return InvalidArgument("Unable to serialize ", in.GetTypeName(),
232                            " protocol buffer, perhaps the serialized size (",
233                            proto_size, " bytes) is too large?");
234   }
235   out->data = buf;
236   out->length = proto_size;
237   out->data_deallocator = [](void* data, size_t length) { port::Free(data); };
238   return Status::OK();
239 }
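// Sketch of the MessageToBuffer contract, assuming a caller inside the
// tensorflow namespace: the output buffer must start empty, and the installed
// deallocator lets the result be released with TF_DeleteBuffer.
//
//   tensorflow::GraphDef def;                 // any MessageLite-derived proto
//   TF_Buffer* out = TF_NewBuffer();          // out->data must be nullptr here
//   tensorflow::Status s = tensorflow::MessageToBuffer(def, out);
//   if (s.ok()) { /* out->data / out->length hold the serialization */ }
//   TF_DeleteBuffer(out);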
240 
241 void RecordMutation(TF_Graph* graph, const TF_Operation& op,
242                     const char* mutation_type) {
243   // If any session has already run this node_id, mark this session as
244   // unrunnable.
245   for (auto it : graph->sessions) {
246     mutex_lock session_lock(it.first->mu);
247     if (it.first->last_num_graph_nodes > op.node.id()) {
248       it.second = strings::StrCat(
249           "Operation '", op.node.DebugString(), "' was changed by ",
250           mutation_type,
251           " after it was run by a session. This mutation will have no effect, "
252           "and will trigger an error in the future. Either don't modify "
253           "nodes after running them or create a new session.");
254     }
255   }
256 }
257 
258 namespace {
259 
260 // Helper method that creates a shape handle for a shape described by dims.
261 tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims(
262     tensorflow::shape_inference::InferenceContext* ic, int num_dims,
263     const int64_t* dims) {
264   if (num_dims != -1) {
265     std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
266     dim_vec.reserve(num_dims);
267     for (int i = 0; i < num_dims; ++i) {
268       dim_vec.push_back(ic->MakeDim(dims[i]));
269     }
270     return ic->MakeShape(dim_vec);
271   } else {
272     return ic->UnknownShape();
273   }
274 }
275 
276 }  // namespace
277 
278 void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
279                                            int num_shapes_and_types,
280                                            const int64_t** shapes,
281                                            const int* ranks,
282                                            const TF_DataType* types,
283                                            TF_Status* status) {
284   Node* node = &output.oper->node;
285 
286   mutex_lock l(graph->mu);
287   tensorflow::shape_inference::InferenceContext* ic =
288       graph->refiner.GetContext(node);
289   if (ic == nullptr) {
290     status->status =
291         InvalidArgument("Node ", node->name(), " was not found in the graph");
292     return;
293   }
294 
295   auto shape_and_type_vec =
296       std::vector<tensorflow::shape_inference::ShapeAndType>(
297           num_shapes_and_types);
298   for (int i = 0; i < num_shapes_and_types; ++i) {
299     tensorflow::shape_inference::ShapeHandle shape_handle =
300         ShapeHandleFromDims(ic, ranks[i], shapes[i]);
301     shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType(
302         shape_handle, static_cast<DataType>(types[i]));
303   }
304 
305   ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec);
306 }
307 
308 // Helpers for loading a TensorFlow plugin (a .so file).
309 Status LoadLibrary(const char* library_filename, void** result,
310                    const void** buf, size_t* len);
311 
312 // TODO(josh11b,mrry): Change Session to be able to use a Graph*
313 // directly, instead of requiring us to serialize to a GraphDef and
314 // call Session::Extend().
315 bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
316   if (session->graph != nullptr) {
317     // Take the graph lock before the session lock to avoid deadlock. This is
318     // safe since session->graph does not change.
319     session->graph->mu.lock();
320     mutex_lock session_lock(session->mu);
321     const Graph& graph = session->graph->graph;
322 
323     const string& mutation_warning = session->graph->sessions[session];
324     if (!mutation_warning.empty()) {
325       // TODO(b/74949947): turn this back into an error status
326       LOG(WARNING) << mutation_warning;
327       session->graph->sessions[session].clear();
328     }
329 
330     const auto num_nodes = graph.num_node_ids();
331     if (session->last_num_graph_nodes < num_nodes) {
332       // TODO(nolivia): check this on a subset of the graph instead of all of
333       // it.
334       status->status = graph::ValidateGraphHasNoCycle(session->graph->graph);
335       if (!status->status.ok()) {
336         session->graph->mu.unlock();
337         return false;
338       }
339 
340       GraphDef graph_def;
341       *graph_def.mutable_versions() = graph.versions();
342       // Fill graph_def with nodes with ids in the range
343       // [session->last_num_graph_nodes, num_nodes), that is the nodes
344       // added since the last TF_SessionRun() call.
345       for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) {
346         Node* const node = graph.FindNodeId(id);
347         if (node != nullptr && node->IsOp()) {
348           NodeDef* const node_def = graph_def.add_node();
349           *node_def = node->def();
350         }
351       }
352       *graph_def.mutable_library() = graph.flib_def().ToProto();
353       session->graph->mu.unlock();
354       status->status = session->session->Extend(std::move(graph_def));
355       if (!status->status.ok()) {
356         // Contract is we always delete input_values[i].
357         return false;
358       }
359       // Note: session->session is not modified if Extend() fails, so
360       // we only set last_num_graph_nodes if it succeeds.
361       session->last_num_graph_nodes = num_nodes;
362     } else {
363       session->graph->mu.unlock();
364     }
365   }
366   return true;
367 }
368 
369 }  // namespace tensorflow
370 
371 static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs,
372                          TF_Status* status) {
373   status->status = Status::OK();
374   for (int i = 0; i < noutputs; ++i) {
375     c_outputs[i] = nullptr;
376   }
377 }
378 
379 static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
380                           std::vector<std::pair<string, Tensor>>* input_pairs,
381                           TF_Status* status) {
382   const int ninputs = input_pairs->size();
383   for (int i = 0; i < ninputs; ++i) {
384     status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second);
385     if (!status->status.ok()) return false;
386   }
387   return true;
388 }
389 
390 // Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to
391 // result in a zero-sized tensor.
392 static TF_Tensor* EmptyTensor(TF_DataType dtype,
393                               const tensorflow::TensorShape& shape) {
394   static char empty;
395   tensorflow::int64 nelems = 1;
396   std::vector<tensorflow::int64> dims;
397   for (int i = 0; i < shape.dims(); ++i) {
398     dims.push_back(shape.dim_size(i));
399     nelems *= shape.dim_size(i);
400   }
401   CHECK_EQ(nelems, 0);
402   static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
403                 "64-bit int types should match in size");
404   return TF_NewTensor(
405       dtype, reinterpret_cast<const int64_t*>(dims.data()), shape.dims(),
406       reinterpret_cast<void*>(&empty), 0, [](void*, size_t, void*) {}, nullptr);
407 }
408 
409 static void TF_Run_Helper(
410     Session* session, const char* handle, const TF_Buffer* run_options,
411     // Input tensors
412     const std::vector<std::pair<string, Tensor>>& input_pairs,
413     // Output tensors
414     const std::vector<string>& output_tensor_names, TF_Tensor** c_outputs,
415     // Target nodes
416     const std::vector<string>& target_oper_names, TF_Buffer* run_metadata,
417     TF_Status* status) {
418   const int noutputs = output_tensor_names.size();
419   std::vector<Tensor> outputs(noutputs);
420   Status result;
421 
422   if (handle == nullptr) {
423     RunOptions run_options_proto;
424     if (run_options != nullptr && !run_options_proto.ParseFromArray(
425                                       run_options->data, run_options->length)) {
426       status->status = InvalidArgument("Unparseable RunOptions proto");
427       return;
428     }
429     if (run_metadata != nullptr && run_metadata->data != nullptr) {
430       status->status =
431           InvalidArgument("Passing non-empty run_metadata is invalid.");
432       return;
433     }
434 
435     RunMetadata run_metadata_proto;
436     result = session->Run(run_options_proto, input_pairs, output_tensor_names,
437                           target_oper_names, &outputs, &run_metadata_proto);
438 
439     // Serialize back to upstream client, who now owns the new buffer
440     if (run_metadata != nullptr) {
441       status->status = MessageToBuffer(run_metadata_proto, run_metadata);
442       if (!status->status.ok()) return;
443     }
444   } else {
445     // NOTE(zongheng): PRun does not support RunOptions yet.
446     result = session->PRun(handle, input_pairs, output_tensor_names, &outputs);
447   }
448   if (!result.ok()) {
449     status->status = result;
450     return;
451   }
452 
453   // Store results in c_outputs[]
454   for (int i = 0; i < noutputs; ++i) {
455     const Tensor& src = outputs[i];
456     if (!src.IsInitialized() || src.NumElements() == 0) {
457       c_outputs[i] =
458           EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
459       continue;
460     }
461     c_outputs[i] = TF_TensorFromTensor(src, &status->status);
462     if (!status->status.ok()) return;
463   }
464 }
465 
466 extern "C" {
467 
468 void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options,
469             // Input tensors
470             const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
471             // Output tensors
472             const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
473             // Target nodes
474             const char** c_target_oper_names, int ntargets,
475             TF_Buffer* run_metadata, TF_Status* status) {
476   TF_Run_Setup(noutputs, c_outputs, status);
477   std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
478   if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
479   for (int i = 0; i < ninputs; ++i) {
480     input_pairs[i].first = c_input_names[i];
481   }
482   std::vector<string> output_names(noutputs);
483   for (int i = 0; i < noutputs; ++i) {
484     output_names[i] = c_output_names[i];
485   }
486   std::vector<string> target_oper_names(ntargets);
487   for (int i = 0; i < ntargets; ++i) {
488     target_oper_names[i] = c_target_oper_names[i];
489   }
490   TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names,
491                 c_outputs, target_oper_names, run_metadata, status);
492 }
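// Minimal calling sketch for TF_Run; the tensor names, feed tensor, and counts
// below are hypothetical. Feeds are matched to names by position, and each
// fetched tensor is returned in c_outputs (set to nullptr on error).
//
//   const char* input_names[1] = {"x:0"};
//   TF_Tensor* input_tensors[1] = {x_tensor};           // assumed TF_Tensor*
//   const char* output_names[1] = {"y:0"};
//   TF_Tensor* out_tensors[1] = {nullptr};
//   const char* target_names[1] = {"init"};
//   TF_Run(session, /*run_options=*/nullptr,
//          input_names, input_tensors, 1,
//          output_names, out_tensors, 1,
//          target_names, 1,
//          /*run_metadata=*/nullptr, status);
//   if (TF_GetCode(status) == TF_OK) TF_DeleteTensor(out_tensors[0]);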
493 
494 void TF_PRunSetup(TF_DeprecatedSession* s,
495                   // Input names
496                   const char** c_input_names, int ninputs,
497                   // Output names
498                   const char** c_output_names, int noutputs,
499                   // Target nodes
500                   const char** c_target_oper_names, int ntargets,
501                   const char** handle, TF_Status* status) {
502   *handle = nullptr;
503 
504   std::vector<string> input_names(ninputs);
505   std::vector<string> output_names(noutputs);
506   std::vector<string> target_oper_names(ntargets);
507   for (int i = 0; i < ninputs; ++i) {
508     input_names[i] = c_input_names[i];
509   }
510   for (int i = 0; i < noutputs; ++i) {
511     output_names[i] = c_output_names[i];
512   }
513   for (int i = 0; i < ntargets; ++i) {
514     target_oper_names[i] = c_target_oper_names[i];
515   }
516   string new_handle;
517   status->status = s->session->PRunSetup(input_names, output_names,
518                                          target_oper_names, &new_handle);
519   if (status->status.ok()) {
520     char* buf = new char[new_handle.size() + 1];
521     memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
522     *handle = buf;
523   }
524 }
525 
526 void TF_PRun(TF_DeprecatedSession* s, const char* handle,
527              // Input tensors
528              const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
529              // Output tensors
530              const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
531              // Target nodes
532              const char** c_target_oper_names, int ntargets,
533              TF_Status* status) {
534   TF_Run_Setup(noutputs, c_outputs, status);
535   std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
536   if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
537   for (int i = 0; i < ninputs; ++i) {
538     input_pairs[i].first = c_input_names[i];
539   }
540 
541   std::vector<string> output_names(noutputs);
542   for (int i = 0; i < noutputs; ++i) {
543     output_names[i] = c_output_names[i];
544   }
545   std::vector<string> target_oper_names(ntargets);
546   for (int i = 0; i < ntargets; ++i) {
547     target_oper_names[i] = c_target_oper_names[i];
548   }
549   TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names,
550                 c_outputs, target_oper_names, nullptr, status);
551 }
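// Partial-run flow sketch: TF_PRunSetup declares all feeds/fetches up front and
// returns a handle; later TF_PRun calls feed subsets against that handle. The
// name and tensor arrays here are hypothetical.
//
//   const char* handle = nullptr;
//   TF_PRunSetup(session, feed_names, 2, fetch_names, 1, target_names, 0,
//                &handle, status);
//   // ... later, feed what is available and fetch what is now computable ...
//   TF_PRun(session, handle, feed_names, feed_tensors, 1,
//           fetch_names, fetch_tensors, 1, nullptr, 0, status);
//   TF_DeletePRunHandle(handle);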
552 
553 TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
554   TF_Library* lib_handle = new TF_Library;
555   status->status = tensorflow::LoadLibrary(
556       library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
557       &lib_handle->op_list.length);
558   if (!status->status.ok()) {
559     delete lib_handle;
560     return nullptr;
561   }
562   return lib_handle;
563 }
564 
565 TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; }
566 
567 void TF_DeleteLibraryHandle(TF_Library* lib_handle) {
568   if (lib_handle == nullptr) return;
569   tensorflow::port::Free(const_cast<void*>(lib_handle->op_list.data));
570   delete lib_handle;
571 }
572 
573 TF_Buffer* TF_GetAllOpList() {
574   std::vector<tensorflow::OpDef> op_defs;
575   tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs);
576   tensorflow::OpList op_list;
577   for (const auto& op : op_defs) {
578     *(op_list.add_op()) = op;
579   }
580   TF_Buffer* ret = TF_NewBuffer();
581   TF_CHECK_OK(MessageToBuffer(op_list, ret));
582   return ret;
583 }
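// Plugin-loading sketch for the three functions above; the .so path is
// hypothetical. The OpList returned by TF_GetOpList is owned by the library
// handle, while TF_GetAllOpList allocates a buffer the caller must free.
//
//   TF_Library* lib = TF_LoadLibrary("custom_ops.so", status);
//   if (TF_GetCode(status) == TF_OK) {
//     TF_Buffer op_list = TF_GetOpList(lib);   // serialized OpList, owned by lib
//     // ... parse op_list.data / op_list.length as a tensorflow.OpList ...
//     TF_DeleteLibraryHandle(lib);
//   }
//   TF_Buffer* all_ops = TF_GetAllOpList();    // every op registered in-process
//   TF_DeleteBuffer(all_ops);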
584 
585 // --------------------------------------------------------------------------
586 // ListDevices & SessionListDevices API
587 
588 void TF_DeleteDeviceList(TF_DeviceList* list) { delete list; }
589 
590 TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
591   TF_DeviceList* response = new TF_DeviceList;
592   status->status = session->session->ListDevices(&response->response);
593   return response;
594 }
595 
596 TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session,
597                                                TF_Status* status) {
598   TF_DeviceList* response = new TF_DeviceList;
599   status->status = session->session->ListDevices(&response->response);
600   return response;
601 }
602 
603 int TF_DeviceListCount(const TF_DeviceList* list) {
604   return list->response.size();
605 }
606 
607 #define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \
608   return_type method_name(const TF_DeviceList* list, const int index,     \
609                           TF_Status* status) {                            \
610     if (list == nullptr) {                                                \
611       status->status = InvalidArgument("list is null!");                  \
612       return err_val;                                                     \
613     }                                                                     \
614     if (index < 0 || index >= list->response.size()) {                    \
615       status->status = InvalidArgument("index out of bounds");            \
616       return err_val;                                                     \
617     }                                                                     \
618     status->status = Status::OK();                                        \
619     return list->response[index].accessor;                                \
620   }
621 
622 TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
623 TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
624                      nullptr);
625 TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);
626 TF_DEVICELIST_METHOD(uint64_t, TF_DeviceListIncarnation, incarnation(), 0);
627 
628 #undef TF_DEVICELIST_METHOD
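// Device enumeration sketch using the accessors generated above; `session` and
// `status` are assumed to already exist.
//
//   TF_DeviceList* devices = TF_SessionListDevices(session, status);
//   const int n = TF_DeviceListCount(devices);
//   for (int i = 0; i < n; ++i) {
//     const char* name = TF_DeviceListName(devices, i, status);
//     const char* type = TF_DeviceListType(devices, i, status);
//     int64_t mem = TF_DeviceListMemoryBytes(devices, i, status);
//     // ... inspect name/type/mem ...
//   }
//   TF_DeleteDeviceList(devices);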
629 
630 }  // end extern "C"
631 
632 // --------------------------------------------------------------------------
633 // New Graph and Session API
634 
635 // Helper functions -----------------------------------------------------------
636 
637 namespace {
638 
639 TF_Operation* ToOperation(Node* node) {
640   return static_cast<TF_Operation*>(static_cast<void*>(node));
641 }
642 
643 string OutputName(const TF_Output& output) {
644   return StrCat(output.oper->node.name(), ":", output.index);
645 }
646 
647 const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
648                                           const char* attr_name,
649                                           TF_Status* status) {
650   const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name);
651   if (attr == nullptr) {
652     status->status = InvalidArgument("Operation '", oper->node.name(),
653                                      "' has no attr named '", attr_name, "'.");
654   }
655   return attr;
656 }
657 
658 TensorId ToTensorId(const TF_Output& output) {
659   return TensorId(output.oper->node.name(), output.index);
660 }
661 
662 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
663 std::vector<tensorflow::Output> OutputsFromTFOutputs(TF_Output* tf_outputs,
664                                                      int n) {
665   std::vector<tensorflow::Output> outputs(n);
666   for (int i = 0; i < n; ++i) {
667     outputs[i] =
668         tensorflow::Output(&tf_outputs[i].oper->node, tf_outputs[i].index);
669   }
670   return outputs;
671 }
672 
673 void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
674                           TF_Output* tf_outputs) {
675   for (int i = 0; i < outputs.size(); i++) {
676     tf_outputs[i].oper = ToOperation(outputs[i].node());
677     tf_outputs[i].index = outputs[i].index();
678   }
679 }
680 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
681 
682 }  // namespace
683 
684 // Shape functions -----------------------------------------------------------
685 
686 void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
687                             const int64_t* dims, const int num_dims,
688                             TF_Status* status) {
689   Node* node = &output.oper->node;
690 
691   mutex_lock l(graph->mu);
692   tensorflow::shape_inference::InferenceContext* ic =
693       graph->refiner.GetContext(node);
694   if (ic == nullptr) {
695     status->status =
696         InvalidArgument("Node ", node->name(), " was not found in the graph");
697     return;
698   }
699   tensorflow::shape_inference::ShapeHandle new_shape =
700       tensorflow::ShapeHandleFromDims(ic, num_dims, dims);
701   status->status = graph->refiner.SetShape(node, output.index, new_shape);
702 }
703 
704 int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output,
705                              TF_Status* status) {
706   Node* node = &output.oper->node;
707 
708   mutex_lock l(graph->mu);
709   tensorflow::shape_inference::InferenceContext* ic =
710       graph->refiner.GetContext(node);
711   if (ic == nullptr) {
712     status->status =
713         InvalidArgument("Node ", node->name(), " was not found in the graph");
714     return -1;
715   }
716 
717   tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
718 
719   // Unknown rank means the number of dimensions is -1.
720   if (!ic->RankKnown(shape)) {
721     return -1;
722   }
723 
724   return ic->Rank(shape);
725 }
726 
727 void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims,
728                             int num_dims, TF_Status* status) {
729   Node* node = &output.oper->node;
730 
731   mutex_lock l(graph->mu);
732   tensorflow::shape_inference::InferenceContext* ic =
733       graph->refiner.GetContext(node);
734   if (ic == nullptr) {
735     status->status =
736         InvalidArgument("Node ", node->name(), " was not found in the graph");
737     return;
738   }
739 
740   tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
741 
742   int rank = -1;
743   if (ic->RankKnown(shape)) {
744     rank = ic->Rank(shape);
745   }
746 
747   if (num_dims != rank) {
748     status->status = InvalidArgument("Expected rank is ", num_dims,
749                                      " but actual rank is ", rank);
750     return;
751   }
752 
753   if (num_dims == 0) {
754     // Output shape is a scalar.
755     return;
756   }
757 
758   // Rank is greater than 0, so fill in the values, if known, and
759   // -1 for unknown values.
760   for (int i = 0; i < num_dims; ++i) {
761     auto dim = ic->Dim(shape, i);
762     tensorflow::int64 value = -1;
763     if (ic->ValueKnown(dim)) {
764       value = ic->Value(dim);
765     }
766     dims[i] = value;
767   }
768 }
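// Shape-query sketch: fetch the rank first so a correctly sized dims array can
// be passed; -1 entries mean the dimension is unknown. `some_op` is a
// hypothetical TF_Operation*.
//
//   TF_Output out = {some_op, 0};
//   int rank = TF_GraphGetTensorNumDims(graph, out, status);
//   if (rank > 0) {
//     std::vector<int64_t> dims(rank);
//     TF_GraphGetTensorShape(graph, out, dims.data(), rank, status);
//   }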
769 
770 // TF_OperationDescription functions ------------------------------------------
771 
772 extern "C" {
773 
774 static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
775                                                       const char* op_type,
776                                                       const char* oper_name)
777     EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
778   return new TF_OperationDescription(graph, op_type, oper_name);
779 }
780 
781 TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type,
782                                          const char* oper_name) {
783   mutex_lock l(graph->mu);
784   return TF_NewOperationLocked(graph, op_type, oper_name);
785 }
786 
787 void TF_SetDevice(TF_OperationDescription* desc, const char* device) {
788   desc->node_builder.Device(device);
789 }
790 
791 void TF_AddInput(TF_OperationDescription* desc, TF_Output input) {
792   desc->node_builder.Input(&input.oper->node, input.index);
793 }
794 
795 void TF_AddInputList(TF_OperationDescription* desc, const TF_Output* inputs,
796                      int num_inputs) {
797   std::vector<NodeBuilder::NodeOut> input_list;
798   input_list.reserve(num_inputs);
799   for (int i = 0; i < num_inputs; ++i) {
800     input_list.emplace_back(&inputs[i].oper->node, inputs[i].index);
801   }
802   desc->node_builder.Input(input_list);
803 }
804 
805 void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) {
806   desc->node_builder.ControlInput(&input->node);
807 }
808 
809 void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) {
810   desc->colocation_constraints.emplace(
811       StrCat(tensorflow::kColocationGroupPrefix, op->node.name()));
812 }
813 
814 void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name,
815                       const void* value, size_t length) {
816   tensorflow::StringPiece s(static_cast<const char*>(value), length);
817   desc->node_builder.Attr(attr_name, s);
818 }
819 
820 void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name,
821                           const void* const* values, const size_t* lengths,
822                           int num_values) {
823   if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
824     desc->colocation_constraints.clear();
825     for (int i = 0; i < num_values; ++i) {
826       desc->colocation_constraints.emplace(static_cast<const char*>(values[i]),
827                                            lengths[i]);
828     }
829   } else {
830     std::vector<tensorflow::StringPiece> v;
831     v.reserve(num_values);
832     for (int i = 0; i < num_values; ++i) {
833       v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
834     }
835     desc->node_builder.Attr(attr_name, v);
836   }
837 }
838 
839 void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name,
840                    int64_t value) {
841   static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
842                 "64-bit int types should match in size");
843   desc->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value));
844 }
845 
846 void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name,
847                        const int64_t* values, int num_values) {
848   static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
849                 "64-bit int types should match in size");
850   desc->node_builder.Attr(
851       attr_name,
852       ArraySlice<const tensorflow::int64>(
853           reinterpret_cast<const tensorflow::int64*>(values), num_values));
854 }
855 
856 void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name,
857                      float value) {
858   desc->node_builder.Attr(attr_name, value);
859 }
860 
861 void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name,
862                          const float* values, int num_values) {
863   desc->node_builder.Attr(attr_name,
864                           ArraySlice<const float>(values, num_values));
865 }
866 
867 void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name,
868                     unsigned char value) {
869   desc->node_builder.Attr(attr_name, static_cast<bool>(value));
870 }
871 
872 void TF_SetAttrBoolList(TF_OperationDescription* desc, const char* attr_name,
873                         const unsigned char* values, int num_values) {
874   std::unique_ptr<bool[]> b(new bool[num_values]);
875   for (int i = 0; i < num_values; ++i) {
876     b[i] = values[i];
877   }
878   desc->node_builder.Attr(attr_name,
879                           ArraySlice<const bool>(b.get(), num_values));
880 }
881 
882 void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name,
883                     TF_DataType value) {
884   desc->node_builder.Attr(attr_name, static_cast<DataType>(value));
885 }
886 
887 void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name,
888                         const TF_DataType* values, int num_values) {
889   desc->node_builder.Attr(
890       attr_name, ArraySlice<const DataType>(
891                      reinterpret_cast<const DataType*>(values), num_values));
892 }
893 
894 void TF_SetAttrPlaceholder(TF_OperationDescription* desc, const char* attr_name,
895                            const char* placeholder) {
896   tensorflow::AttrValue attr_value;
897   attr_value.set_placeholder(placeholder);
898   desc->node_builder.Attr(attr_name, attr_value);
899 }
900 
901 void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name,
902                         const char* value, size_t length) {
903   tensorflow::NameAttrList func_name;
904   func_name.set_name(string(value, value + length));
905   desc->node_builder.Attr(attr_name, func_name);
906 }
907 
908 void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name,
909                      const int64_t* dims, int num_dims) {
910   PartialTensorShape shape;
911   if (num_dims >= 0) {
912     static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
913                   "64-bit int types should match in size");
914     shape = PartialTensorShape(ArraySlice<tensorflow::int64>(
915         reinterpret_cast<const tensorflow::int64*>(dims), num_dims));
916   }
917   desc->node_builder.Attr(attr_name, shape);
918 }
919 
920 void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name,
921                          const int64_t* const* dims, const int* num_dims,
922                          int num_shapes) {
923   std::vector<PartialTensorShape> shapes;
924   shapes.reserve(num_shapes);
925   for (int i = 0; i < num_shapes; ++i) {
926     if (num_dims[i] < 0) {
927       shapes.emplace_back();
928     } else {
929       static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
930                     "64-bit int types should match in size");
931       shapes.emplace_back(ArraySlice<tensorflow::int64>(
932           reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i]));
933     }
934   }
935   desc->node_builder.Attr(attr_name, shapes);
936 }
937 
938 void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc,
939                                 const char* attr_name, const void* proto,
940                                 size_t proto_len, TF_Status* status) {
941   // shape.ParseFromArray takes an int as length, this function takes size_t,
942   // make sure there is no information loss.
943   if (proto_len > std::numeric_limits<int>::max()) {
944     status->status = InvalidArgument(
945         "proto_len (", proto_len,
946         " bytes) is too large to be parsed by the protocol buffer library");
947     return;
948   }
949   TensorShapeProto shape;
950   if (shape.ParseFromArray(proto, static_cast<int>(proto_len))) {
951     desc->node_builder.Attr(attr_name, shape);
952     status->status = Status::OK();
953   } else {
954     status->status = InvalidArgument("Unparseable TensorShapeProto");
955   }
956 }
957 
958 void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc,
959                                     const char* attr_name,
960                                     const void* const* protos,
961                                     const size_t* proto_lens, int num_shapes,
962                                     TF_Status* status) {
963   std::vector<TensorShapeProto> shapes;
964   shapes.resize(num_shapes);
965   for (int i = 0; i < num_shapes; ++i) {
966     if (proto_lens[i] > std::numeric_limits<int>::max()) {
967       status->status = InvalidArgument(
968           "length of element ", i, " in the list (", proto_lens[i],
969           " bytes) is too large to be parsed by the protocol buffer library");
970       return;
971     }
972     if (!shapes[i].ParseFromArray(protos[i], static_cast<int>(proto_lens[i]))) {
973       status->status =
974           InvalidArgument("Unparseable TensorShapeProto at index ", i);
975       return;
976     }
977   }
978   desc->node_builder.Attr(attr_name, shapes);
979   status->status = Status::OK();
980 }
981 
982 void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
983                       TF_Tensor* value, TF_Status* status) {
984   Tensor t;
985   status->status = TF_TensorToTensor(value, &t);
986   if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
987 }
988 
989 void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
990                           TF_Tensor* const* values, int num_values,
991                           TF_Status* status) {
992   status->status = Status::OK();
993   std::vector<Tensor> t;
994   t.reserve(num_values);
995 
996   for (int i = 0; i < num_values && status->status.ok(); ++i) {
997     Tensor v;
998     status->status = TF_TensorToTensor(values[i], &v);
999     t.emplace_back(v);
1000   }
1001 
1002   if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
1003 }
1004 
1005 void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
1006                           const void* proto, size_t proto_len,
1007                           TF_Status* status) {
1008   tensorflow::AttrValue attr_value;
1009   if (!attr_value.ParseFromArray(proto, proto_len)) {
1010     status->status = InvalidArgument("Unparseable AttrValue proto");
1011     return;
1012   }
1013 
1014   if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
1015     if (attr_value.value_case() != tensorflow::AttrValue::kList &&
1016         attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) {
1017       status->status =
1018           InvalidArgument("Expected \"list\" field for \"",
1019                           tensorflow::kColocationAttrName, "\" attribute");
1020       return;
1021     }
1022     desc->colocation_constraints.clear();
1023     for (const string& location : attr_value.list().s()) {
1024       desc->colocation_constraints.insert(location);
1025     }
1026   } else {
1027     desc->node_builder.Attr(attr_name, std::move(attr_value));
1028   }
1029 
1030   status->status = Status::OK();
1031 }
1032 
1033 static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
1034                                               TF_Status* status)
1035     EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
1036   Node* ret = nullptr;
1037 
1038   if (desc->graph->name_map.count(desc->node_builder.node_name())) {
1039     status->status = InvalidArgument("Duplicate node name in graph: '",
1040                                      desc->node_builder.node_name(), "'");
1041   } else {
1042     if (!desc->colocation_constraints.empty()) {
1043       desc->node_builder.Attr(
1044           tensorflow::kColocationAttrName,
1045           std::vector<string>(desc->colocation_constraints.begin(),
1046                               desc->colocation_constraints.end()));
1047     }
1048     status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret,
1049                                                  /*consume=*/true);
1050 
1051     if (status->status.ok()) {
1052       // Run shape inference function for newly added node.
1053       status->status = desc->graph->refiner.AddNode(ret);
1054     }
1055     if (status->status.ok()) {
1056       // Add the node to the name-to-node mapping.
1057       desc->graph->name_map[ret->name()] = ret;
1058     } else if (ret != nullptr) {
1059       desc->graph->graph.RemoveNode(ret);
1060       ret = nullptr;
1061     }
1062   }
1063 
1064   delete desc;
1065 
1066   return ToOperation(ret);
1067 }
1068 
1069 TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
1070                                  TF_Status* status) {
1071   mutex_lock l(desc->graph->mu);
1072   return TF_FinishOperationLocked(desc, status);
1073 }
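// Node-construction sketch using the description API above; the op type, node
// name, and attrs are hypothetical. TF_FinishOperation consumes the
// description whether or not it succeeds.
//
//   TF_OperationDescription* desc =
//       TF_NewOperation(graph, "Placeholder", "input");
//   TF_SetAttrType(desc, "dtype", TF_FLOAT);
//   int64_t dims[2] = {-1, 784};
//   TF_SetAttrShape(desc, "shape", dims, 2);
//   TF_Operation* op = TF_FinishOperation(desc, status);
//   if (TF_GetCode(status) != TF_OK) { /* op is nullptr */ }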
1074 
1075 // TF_Operation functions
1076 // ----------------------------------------------------------
1077 
1078 const char* TF_OperationName(TF_Operation* oper) {
1079   return oper->node.name().c_str();
1080 }
1081 
1082 const char* TF_OperationOpType(TF_Operation* oper) {
1083   return oper->node.type_string().c_str();
1084 }
1085 
1086 const char* TF_OperationDevice(TF_Operation* oper) {
1087   return oper->node.requested_device().c_str();
1088 }
1089 
1090 int TF_OperationNumOutputs(TF_Operation* oper) {
1091   return oper->node.num_outputs();
1092 }
1093 
1094 TF_DataType TF_OperationOutputType(TF_Output oper_out) {
1095   return static_cast<TF_DataType>(
1096       oper_out.oper->node.output_type(oper_out.index));
1097 }
1098 
1099 int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
1100                                  TF_Status* status) {
1101   NameRangeMap name_ranges;
1102   status->status =
1103       NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges);
1104   if (!status->status.ok()) return -1;
1105   auto iter = name_ranges.find(arg_name);
1106   if (iter == name_ranges.end()) {
1107     status->status = InvalidArgument("Output arg '", arg_name, "' not found");
1108     return -1;
1109   }
1110   return iter->second.second - iter->second.first;
1111 }
1112 
1113 int TF_OperationNumInputs(TF_Operation* oper) {
1114   return oper->node.num_inputs();
1115 }
1116 
1117 TF_DataType TF_OperationInputType(TF_Input oper_in) {
1118   return static_cast<TF_DataType>(oper_in.oper->node.input_type(oper_in.index));
1119 }
1120 
1121 int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
1122                                 TF_Status* status) {
1123   NameRangeMap name_ranges;
1124   status->status =
1125       NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr);
1126   if (!status->status.ok()) return -1;
1127   auto iter = name_ranges.find(arg_name);
1128   if (iter == name_ranges.end()) {
1129     status->status = InvalidArgument("Input arg '", arg_name, "' not found");
1130     return -1;
1131   }
1132   return iter->second.second - iter->second.first;
1133 }
1134 
1135 TF_Output TF_OperationInput(TF_Input oper_in) {
1136   const tensorflow::Edge* edge;
1137   Status s = oper_in.oper->node.input_edge(oper_in.index, &edge);
1138   if (!s.ok()) {
1139     return {nullptr, -1};
1140   }
1141 
1142   return {ToOperation(edge->src()), edge->src_output()};
1143 }
1144 
1145 void TF_OperationAllInputs(TF_Operation* oper, TF_Output* inputs,
1146                            int max_inputs) {
1147   for (auto* edge : oper->node.in_edges()) {
1148     if (edge->dst_input() >= 0 && edge->dst_input() < max_inputs) {
1149       inputs[edge->dst_input()] = {ToOperation(edge->src()),
1150                                    edge->src_output()};
1151     }
1152   }
1153 }
1154 
1155 int TF_OperationOutputNumConsumers(TF_Output oper_out) {
1156   int count = 0;
1157   for (const auto* edge : oper_out.oper->node.out_edges()) {
1158     if (edge->src_output() == oper_out.index) {
1159       ++count;
1160     }
1161   }
1162   return count;
1163 }
1164 
1165 int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers,
1166                                 int max_consumers) {
1167   int count = 0;
1168   for (const auto* edge : oper_out.oper->node.out_edges()) {
1169     if (edge->src_output() == oper_out.index) {
1170       if (count < max_consumers) {
1171         consumers[count] = {ToOperation(edge->dst()), edge->dst_input()};
1172       }
1173       ++count;
1174     }
1175   }
1176   return count;
1177 }
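// Sketch of the count-then-fill pattern used by the sized queries above; `op`
// is a hypothetical TF_Operation*.
//
//   TF_Output out = {op, 0};
//   int n = TF_OperationOutputNumConsumers(out);
//   std::vector<TF_Input> consumers(n);
//   TF_OperationOutputConsumers(out, consumers.data(), n);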
1178 
1179 int TF_OperationNumControlInputs(TF_Operation* oper) {
1180   int count = 0;
1181   for (const auto* edge : oper->node.in_edges()) {
1182     if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1183       ++count;
1184     }
1185   }
1186   return count;
1187 }
1188 
1189 int TF_OperationGetControlInputs(TF_Operation* oper,
1190                                  TF_Operation** control_inputs,
1191                                  int max_control_inputs) {
1192   int count = 0;
1193   for (const auto* edge : oper->node.in_edges()) {
1194     if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1195       if (count < max_control_inputs) {
1196         control_inputs[count] = ToOperation(edge->src());
1197       }
1198       ++count;
1199     }
1200   }
1201   return count;
1202 }
1203 
1204 int TF_OperationNumControlOutputs(TF_Operation* oper) {
1205   int count = 0;
1206   for (const auto* edge : oper->node.out_edges()) {
1207     if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1208       ++count;
1209     }
1210   }
1211   return count;
1212 }
1213 
1214 int TF_OperationGetControlOutputs(TF_Operation* oper,
1215                                   TF_Operation** control_outputs,
1216                                   int max_control_outputs) {
1217   int count = 0;
1218   for (const auto* edge : oper->node.out_edges()) {
1219     if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1220       if (count < max_control_outputs) {
1221         control_outputs[count] = ToOperation(edge->dst());
1222       }
1223       ++count;
1224     }
1225   }
1226   return count;
1227 }
1228 
1229 TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
1230                                             const char* attr_name,
1231                                             TF_Status* status) {
1232   TF_AttrMetadata metadata;
1233   const auto* attr = GetAttrValue(oper, attr_name, status);
1234   if (!status->status.ok()) return metadata;
1235   switch (attr->value_case()) {
1236 #define SINGLE_CASE(kK, attr_type, size_expr) \
1237   case tensorflow::AttrValue::kK:             \
1238     metadata.is_list = 0;                     \
1239     metadata.list_size = -1;                  \
1240     metadata.type = attr_type;                \
1241     metadata.total_size = size_expr;          \
1242     break;
1243 
1244     SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
1245     SINGLE_CASE(kI, TF_ATTR_INT, -1);
1246     SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
1247     SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
1248     SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
1249     SINGLE_CASE(kShape, TF_ATTR_SHAPE,
1250                 attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
1251     SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
1252 #undef SINGLE_CASE
1253 
1254     case tensorflow::AttrValue::kList:
1255       metadata.is_list = 1;
1256       metadata.list_size = 0;
1257       metadata.total_size = -1;
1258 #define LIST_CASE(field, attr_type, ...)              \
1259   if (attr->list().field##_size() > 0) {              \
1260     metadata.type = attr_type;                        \
1261     metadata.list_size = attr->list().field##_size(); \
1262     __VA_ARGS__;                                      \
1263     break;                                            \
1264   }
1265 
1266       LIST_CASE(
1267           s, TF_ATTR_STRING, metadata.total_size = 0;
1268           for (int i = 0; i < attr->list().s_size();
1269                ++i) { metadata.total_size += attr->list().s(i).size(); });
1270       LIST_CASE(i, TF_ATTR_INT);
1271       LIST_CASE(f, TF_ATTR_FLOAT);
1272       LIST_CASE(b, TF_ATTR_BOOL);
1273       LIST_CASE(type, TF_ATTR_TYPE);
1274       LIST_CASE(
1275           shape, TF_ATTR_SHAPE, metadata.total_size = 0;
1276           for (int i = 0; i < attr->list().shape_size(); ++i) {
1277             const auto& s = attr->list().shape(i);
1278             metadata.total_size += s.unknown_rank() ? 0 : s.dim_size();
1279           });
1280       LIST_CASE(tensor, TF_ATTR_TENSOR);
1281       LIST_CASE(tensor, TF_ATTR_FUNC);
1282 #undef LIST_CASE
1283       // All lists empty, determine the type from the OpDef.
1284       if (metadata.list_size == 0) {
1285         for (int i = 0; i < oper->node.op_def().attr_size(); ++i) {
1286           const auto& a = oper->node.op_def().attr(i);
1287           if (a.name() != attr_name) continue;
1288           const string& typestr = a.type();
1289           if (typestr == "list(string)") {
1290             metadata.type = TF_ATTR_STRING;
1291           } else if (typestr == "list(int)") {
1292             metadata.type = TF_ATTR_INT;
1293           } else if (typestr == "list(float)") {
1294             metadata.type = TF_ATTR_FLOAT;
1295           } else if (typestr == "list(bool)") {
1296             metadata.type = TF_ATTR_BOOL;
1297           } else if (typestr == "list(type)") {
1298             metadata.type = TF_ATTR_TYPE;
1299           } else if (typestr == "list(shape)") {
1300             metadata.type = TF_ATTR_SHAPE;
1301           } else if (typestr == "list(tensor)") {
1302             metadata.type = TF_ATTR_TENSOR;
1303           } else if (typestr == "list(func)") {
1304             metadata.type = TF_ATTR_FUNC;
1305           } else {
1306             status->status = InvalidArgument(
1307                 "Attribute '", attr_name,
1308                 "' has an empty value of an unrecognized type '", typestr, "'");
1309             return metadata;
1310           }
1311         }
1312       }
1313       break;
1314 
1315     case tensorflow::AttrValue::kPlaceholder:
1316       metadata.is_list = 0;
1317       metadata.list_size = -1;
1318       metadata.type = TF_ATTR_PLACEHOLDER;
1319       metadata.total_size = -1;
1320       break;
1321 
1322     case tensorflow::AttrValue::kFunc:
1323       metadata.is_list = 0;
1324       metadata.list_size = -1;
1325       metadata.type = TF_ATTR_FUNC;
1326       metadata.total_size = -1;
1327       break;
1328 
1329     case tensorflow::AttrValue::VALUE_NOT_SET:
1330       status->status =
1331           InvalidArgument("Attribute '", attr_name, "' has no value set");
1332       break;
1333   }
1334   return metadata;
1335 }
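
// Illustrative usage (editor's sketch, not part of the TensorFlow sources):
// TF_AttrMetadata tells a caller whether an attribute is a list, how many
// elements it has, and how much storage its payload needs, so it is the
// natural first call before any of the typed getters below. `op`, `status`
// and the attribute name "dtype" are assumed/placeholder.
//
//   TF_AttrMetadata m = TF_OperationGetAttrMetadata(op, "dtype", status);
//   if (TF_GetCode(status) == TF_OK && !m.is_list && m.type == TF_ATTR_TYPE) {
//     TF_DataType dtype;
//     TF_OperationGetAttrType(op, "dtype", &dtype, status);
//   }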

void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name,
                               void* value, size_t max_length,
                               TF_Status* status) {
  const auto* attr = GetAttrValue(oper, attr_name, status);
  if (!status->status.ok()) return;
  if (attr->value_case() != tensorflow::AttrValue::kS) {
    status->status =
        InvalidArgument("Attribute '", attr_name, "' is not a string");
    return;
  }
  if (max_length <= 0) {
    return;
  }
  const auto& s = attr->s();
  std::memcpy(value, s.data(), std::min<size_t>(s.length(), max_length));
}

void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
                                   void** values, size_t* lengths,
                                   int max_values, void* storage,
                                   size_t storage_size, TF_Status* status) {
  const auto* attr = GetAttrValue(oper, attr_name, status);
  if (!status->status.ok()) return;
  if (attr->value_case() != tensorflow::AttrValue::kList) {
    status->status =
        InvalidArgument("Value for '", attr_name, "' is not a list");
    return;
  }
  const auto len = std::min(max_values, attr->list().s_size());
  char* p = static_cast<char*>(storage);
  for (int i = 0; i < len; ++i) {
    const string& s = attr->list().s(i);
    values[i] = p;
    lengths[i] = s.size();
    if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
      status->status = InvalidArgument(
          "Not enough storage to hold the requested list of strings");
      return;
    }
    memcpy(values[i], s.data(), s.size());
    p += s.size();
  }
}
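
// Illustrative usage (editor's sketch, not part of the TensorFlow sources):
// list_size and total_size from the metadata query give exactly the array
// length and byte count this function expects, so a caller can allocate both
// buffers up front. `op` and `status` are assumed; "names" is a placeholder
// attribute name.
//
//   TF_AttrMetadata m = TF_OperationGetAttrMetadata(op, "names", status);
//   std::vector<void*> values(m.list_size);
//   std::vector<size_t> lengths(m.list_size);
//   std::vector<char> storage(m.total_size);
//   TF_OperationGetAttrStringList(op, "names", values.data(), lengths.data(),
//                                 static_cast<int>(m.list_size),
//                                 storage.data(), storage.size(), status);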

#define DEFINE_GETATTR(func, c_type, cpp_type, list_field)                   \
  void func(TF_Operation* oper, const char* attr_name, c_type* value,        \
            TF_Status* status) {                                             \
    cpp_type v;                                                              \
    status->status =                                                         \
        tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v);          \
    *value = static_cast<c_type>(v);                                         \
  }                                                                          \
  void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
                  int max_values, TF_Status* status) {                       \
    const auto* attr = GetAttrValue(oper, attr_name, status);                \
    if (!status->status.ok()) return;                                        \
    if (attr->value_case() != tensorflow::AttrValue::kList) {                \
      status->status =                                                       \
          InvalidArgument("Value for '", attr_name, "' is not a list.");     \
      return;                                                                \
    }                                                                        \
    const auto len = std::min(max_values, attr->list().list_field##_size()); \
    for (int i = 0; i < len; ++i) {                                          \
      values[i] = static_cast<c_type>(attr->list().list_field(i));           \
    }                                                                        \
  }
DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i);
DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f);
DEFINE_GETATTR(TF_OperationGetAttrBool, unsigned char, bool, b);
DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type);
#undef DEFINE_GETATTR

void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
                              int64_t* value, int num_dims, TF_Status* status) {
  PartialTensorShape shape;
  status->status =
      tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
  if (!status->status.ok()) return;
  auto len = std::min(shape.dims(), num_dims);
  for (int i = 0; i < len; ++i) {
    value[i] = shape.dim_size(i);
  }
}

void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
                                  int64_t** dims, int* num_dims, int num_shapes,
                                  int64_t* storage, int storage_size,
                                  TF_Status* status) {
  std::vector<PartialTensorShape> shapes;
  status->status =
      tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
  if (!status->status.ok()) return;
  auto len = std::min(static_cast<int>(shapes.size()), num_shapes);
  int64_t* p = storage;
  int storage_left = storage_size;
  for (int i = 0; i < len; ++i) {
    // shapes[i].dims() == -1 for shapes with an unknown rank.
    int64_t n = shapes[i].dims();
    num_dims[i] = n;
    dims[i] = p;
    if (n < 0) {
      continue;
    }
    if (storage_left < n) {
      status->status = InvalidArgument(
          "Not enough storage to hold the requested list of shapes");
      return;
    }
    storage_left -= n;
    for (int j = 0; j < n; ++j, ++p) {
      *p = shapes[i].dim_size(j);
    }
  }
}
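
// Illustrative usage (editor's sketch, not part of the TensorFlow sources):
// as with the string list above, metadata supplies both the number of shapes
// and the total number of dimension values that `storage` must hold. `op` and
// `status` are assumed; "shapes" is a placeholder attribute name.
//
//   TF_AttrMetadata m = TF_OperationGetAttrMetadata(op, "shapes", status);
//   std::vector<int64_t*> dims(m.list_size);
//   std::vector<int> num_dims(m.list_size);
//   std::vector<int64_t> storage(m.total_size > 0 ? m.total_size : 0);
//   TF_OperationGetAttrShapeList(op, "shapes", dims.data(), num_dims.data(),
//                                static_cast<int>(m.list_size), storage.data(),
//                                static_cast<int>(storage.size()), status);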

void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper,
                                         const char* attr_name,
                                         TF_Buffer* value, TF_Status* status) {
  const auto* attr = GetAttrValue(oper, attr_name, status);
  if (!status->status.ok()) return;
  if (attr->value_case() != tensorflow::AttrValue::kShape) {
    status->status =
        InvalidArgument("Value for '", attr_name, "' is not a shape.");
    return;
  }
  status->status = MessageToBuffer(attr->shape(), value);
}

void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
                                             const char* attr_name,
                                             TF_Buffer** values, int max_values,
                                             TF_Status* status) {
  const auto* attr = GetAttrValue(oper, attr_name, status);
  if (!status->status.ok()) return;
  if (attr->value_case() != tensorflow::AttrValue::kList) {
    status->status =
        InvalidArgument("Value for '", attr_name, "' is not a list");
    return;
  }
  const auto len = std::min(max_values, attr->list().shape_size());
  for (int i = 0; i < len; ++i) {
    values[i] = TF_NewBuffer();
    status->status = MessageToBuffer(attr->list().shape(i), values[i]);
    if (!status->status.ok()) {
      // Delete everything allocated so far; the operation has failed.
      for (int j = 0; j <= i; ++j) {
        TF_DeleteBuffer(values[j]);
      }
      return;
    }
  }
}

void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
                               TF_Tensor** value, TF_Status* status) {
  *value = nullptr;
  Tensor t;
  status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
  if (!status->status.ok()) return;
  *value = TF_TensorFromTensor(t, &status->status);
}

void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
                                   TF_Tensor** values, int max_values,
                                   TF_Status* status) {
  std::vector<Tensor> ts;
  status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
  if (!status->status.ok()) return;
  const auto len = std::min(max_values, static_cast<int>(ts.size()));
  for (int i = 0; i < len; ++i) {
    values[i] = TF_TensorFromTensor(ts[i], &status->status);
  }
}

void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
                                   TF_Buffer* output_attr_value,
                                   TF_Status* status) {
  const auto* attr = GetAttrValue(oper, attr_name, status);
  if (!status->status.ok()) return;
  status->status = MessageToBuffer(*attr, output_attr_value);
}

void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
                           TF_Status* status) {
  status->status = MessageToBuffer(oper->node.def(), output_node_def);
}

// TF_Graph functions ---------------------------------------------------------

TF_Graph::TF_Graph()
    : graph(tensorflow::OpRegistry::Global()),
      refiner(graph.versions().producer(), graph.op_registry()),
      delete_requested(false),
      parent(nullptr),
      parent_inputs(nullptr) {
  // Tell the shape refiner to also run shape inference on functions.
  refiner.set_function_library_for_shape_inference(&graph.flib_def());
}

TF_Graph* TF_NewGraph() { return new TF_Graph; }

void TF_DeleteGraph(TF_Graph* g) {
  if (g == nullptr) return;
  g->mu.lock();
  g->delete_requested = true;
  const bool del = g->sessions.empty();
  g->mu.unlock();
  if (del) delete g;
}

TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) {
  mutex_lock l(graph->mu);
  auto iter = graph->name_map.find(oper_name);
  if (iter == graph->name_map.end()) {
    return nullptr;
  } else {
    return ToOperation(iter->second);
  }
}

TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) {
  if (*pos == 0) {
    // Advance past the first sentinel nodes in every graph (the source & sink).
    *pos += 2;
  } else {
    // Advance to the next node.
    *pos += 1;
  }

  mutex_lock l(graph->mu);
  while (*pos < static_cast<size_t>(graph->graph.num_node_ids())) {
    Node* node = graph->graph.FindNodeId(*pos);
    // FindNodeId() returns nullptr for nodes that have been deleted.
    // We aren't currently allowing nodes to be deleted, but it is safer
    // to still check.
    if (node != nullptr) return ToOperation(node);
    *pos += 1;
  }

  // No more nodes.
  return nullptr;
}
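
// Illustrative usage (editor's sketch, not part of the TensorFlow sources):
// iterating over every operation in a graph. The position cookie starts at 0
// and is advanced by each call; a null return marks the end. `graph` is
// assumed to be a live TF_Graph*.
//
//   size_t pos = 0;
//   TF_Operation* op;
//   while ((op = TF_GraphNextOperation(graph, &pos)) != nullptr) {
//     printf("%s (%s)\n", TF_OperationName(op), TF_OperationOpType(op));
//   }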

void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
                        TF_Status* status) {
  GraphDef def;
  {
    mutex_lock l(graph->mu);
    graph->graph.ToGraphDef(&def);
  }
  status->status = MessageToBuffer(def, output_graph_def);
}

void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
                      TF_Buffer* output_op_def, TF_Status* status) {
  const OpDef* op_def;
  {
    mutex_lock l(graph->mu);
    status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def);
    if (!status->status.ok()) return;
  }
  status->status = MessageToBuffer(*op_def, output_op_def);
}

void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def,
                      TF_Status* status) {
  VersionDef versions;
  {
    mutex_lock l(graph->mu);
    versions = graph->graph.versions();
  }
  status->status = MessageToBuffer(versions, output_version_def);
}

TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
  return new TF_ImportGraphDefOptions;
}
void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) {
  delete opts;
}
void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
                                       const char* prefix) {
  opts->opts.prefix = prefix;
}
void TF_ImportGraphDefOptionsSetDefaultDevice(TF_ImportGraphDefOptions* opts,
                                              const char* device) {
  opts->opts.default_device = device;
}

void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts,
                                              unsigned char uniquify_names) {
  opts->opts.uniquify_names = uniquify_names;
}

void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts,
                                               unsigned char uniquify_prefix) {
  opts->opts.uniquify_prefix = uniquify_prefix;
}

void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
                                             const char* src_name,
                                             int src_index, TF_Output dst) {
  opts->tensor_id_data.push_back(src_name);
  const string& src_name_str = opts->tensor_id_data.back();
  // We don't need to store dst's name in tensor_id_data, since `dst` must
  // outlive the ImportGraphDef call.
  opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst);
}

void TF_ImportGraphDefOptionsRemapControlDependency(
    TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) {
  opts->opts.input_map[TensorId(src_name, tensorflow::Graph::kControlSlot)] =
      TensorId(dst->node.name(), tensorflow::Graph::kControlSlot);
}

extern void TF_ImportGraphDefOptionsAddControlDependency(
    TF_ImportGraphDefOptions* opts, TF_Operation* oper) {
  opts->opts.control_dependencies.push_back(oper->node.name());
}

void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts,
                                             const char* oper_name, int index) {
  opts->tensor_id_data.push_back(oper_name);
  const string& oper_name_str = opts->tensor_id_data.back();
  opts->opts.return_tensors.emplace_back(oper_name_str, index);
}

int TF_ImportGraphDefOptionsNumReturnOutputs(
    const TF_ImportGraphDefOptions* opts) {
  return opts->opts.return_tensors.size();
}

void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts,
                                                const char* oper_name) {
  opts->opts.return_nodes.push_back(oper_name);
}

int TF_ImportGraphDefOptionsNumReturnOperations(
    const TF_ImportGraphDefOptions* opts) {
  return opts->opts.return_nodes.size();
}

void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results,
                                           int* num_outputs,
                                           TF_Output** outputs) {
  *num_outputs = results->return_tensors.size();
  *outputs = results->return_tensors.data();
}

void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results,
                                              int* num_opers,
                                              TF_Operation*** opers) {
  *num_opers = results->return_nodes.size();
  *opers = results->return_nodes.data();
}

void TF_ImportGraphDefResultsMissingUnusedInputMappings(
    TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings,
    const char*** src_names, int** src_indexes) {
  *num_missing_unused_input_mappings = results->missing_unused_key_names.size();
  *src_names = results->missing_unused_key_names.data();
  *src_indexes = results->missing_unused_key_indexes.data();
}

void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) {
  delete results;
}

static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
                                      const TF_ImportGraphDefOptions* opts,
                                      TF_ImportGraphDefResults* tf_results,
                                      TF_Status* status)
    EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
  const int last_node_id = graph->graph.num_node_ids();
  tensorflow::ImportGraphDefResults results;
  status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
                                              &graph->refiner, &results);
  if (!status->status.ok()) return;

  // Add new nodes to name_map
  for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
    auto* node = graph->graph.FindNodeId(i);
    if (node != nullptr) graph->name_map[node->name()] = node;
  }

  // Populate return_tensors
  DCHECK(tf_results->return_tensors.empty());
  tf_results->return_tensors.resize(results.return_tensors.size());
  for (int i = 0; i < results.return_tensors.size(); ++i) {
    tf_results->return_tensors[i].oper =
        ToOperation(results.return_tensors[i].first);
    tf_results->return_tensors[i].index = results.return_tensors[i].second;
  }

  // Populate return_nodes
  DCHECK(tf_results->return_nodes.empty());
  tf_results->return_nodes.resize(results.return_nodes.size());
  for (int i = 0; i < results.return_nodes.size(); ++i) {
    tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]);
  }

  // Populate missing unused map keys
  DCHECK(tf_results->missing_unused_key_names.empty());
  DCHECK(tf_results->missing_unused_key_indexes.empty());
  DCHECK(tf_results->missing_unused_key_names_data.empty());

  size_t size = results.missing_unused_input_map_keys.size();
  tf_results->missing_unused_key_names.resize(size);
  tf_results->missing_unused_key_indexes.resize(size);

  for (int i = 0; i < size; ++i) {
    TensorId id = results.missing_unused_input_map_keys[i];
    tf_results->missing_unused_key_names_data.emplace_back(id.first);
    tf_results->missing_unused_key_names[i] =
        tf_results->missing_unused_key_names_data.back().c_str();
    tf_results->missing_unused_key_indexes[i] = id.second;
  }
}

TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
    TF_Graph* graph, const TF_Buffer* graph_def,
    const TF_ImportGraphDefOptions* options, TF_Status* status) {
  GraphDef def;
  if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
                                       graph_def->length)) {
    status->status = InvalidArgument("Invalid GraphDef");
    return nullptr;
  }
  auto results = new TF_ImportGraphDefResults();
  mutex_lock l(graph->mu);
  GraphImportGraphDefLocked(graph, def, options, results, status);
  if (!status->status.ok()) {
    delete results;
    return nullptr;
  }
  return results;
}

void TF_GraphImportGraphDefWithReturnOutputs(
    TF_Graph* graph, const TF_Buffer* graph_def,
    const TF_ImportGraphDefOptions* options, TF_Output* return_outputs,
    int num_return_outputs, TF_Status* status) {
  if (num_return_outputs != options->opts.return_tensors.size()) {
    status->status = InvalidArgument("Expected 'num_return_outputs' to be ",
                                     options->opts.return_tensors.size(),
                                     ", got ", num_return_outputs);
    return;
  }
  if (num_return_outputs > 0 && return_outputs == nullptr) {
    status->status = InvalidArgument(
        "'return_outputs' must be preallocated to length ", num_return_outputs);
    return;
  }
  GraphDef def;
  if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
                                       graph_def->length)) {
    status->status = InvalidArgument("Invalid GraphDef");
    return;
  }
  TF_ImportGraphDefResults results;
  mutex_lock l(graph->mu);
  GraphImportGraphDefLocked(graph, def, options, &results, status);
  DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
  memcpy(return_outputs, results.return_tensors.data(),
         num_return_outputs * sizeof(TF_Output));
}

void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
                            const TF_ImportGraphDefOptions* options,
                            TF_Status* status) {
  TF_ImportGraphDefResults* results =
      TF_GraphImportGraphDefWithResults(graph, graph_def, options, status);
  TF_DeleteImportGraphDefResults(results);
}
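
// Illustrative usage (editor's sketch, not part of the TensorFlow sources):
// importing a serialized GraphDef under a prefix and recovering one of the
// imported operations by name. `graph`, `graph_def` and `status` are assumed
// to exist; "sub" and "sub/out" are placeholder names.
//
//   TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
//   TF_ImportGraphDefOptionsSetPrefix(opts, "sub");
//   TF_GraphImportGraphDef(graph, graph_def, opts, status);
//   TF_DeleteImportGraphDefOptions(opts);
//   if (TF_GetCode(status) == TF_OK) {
//     TF_Operation* out = TF_GraphOperationByName(graph, "sub/out");
//   }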

// While loop functions -------------------------------------------------------

namespace {

#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)

// Creates a placeholder representing an input to the cond or body graph.
// TODO(skyewm): remove these from final graph
bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
                 TF_Output* input, TF_Status* status) {
  TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name);
  TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input));
  // TODO(skyewm): set placeholder shape
  TF_Operation* oper = TF_FinishOperation(desc, status);
  if (!status->status.ok()) return false;
  *input = {oper, 0};
  return true;
}

// Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input
// `src_inputs[i]` will have that input replaced with `dst_inputs[i]`.  `prefix`
// will be prepended to copied node names. `control_deps` are nodes in
// `dst_graph` that the copied `src_graph` nodes will have control dependencies
// on. `return_nodes` are nodes in `src_graph`, and the new corresponding nodes
// in `dst_graph` will be returned. `return_nodes` must be non-null.
Status CopyGraph(Graph* src_graph, Graph* dst_graph,
                 tensorflow::ShapeRefiner* dst_refiner,
                 const TF_Output* src_inputs,
                 const std::vector<tensorflow::Output>& dst_inputs,
                 const string& prefix,
                 const std::vector<tensorflow::Operation>& control_deps,
                 const TF_Output* nodes_to_return, int nreturn_nodes,
                 std::vector<tensorflow::Output>* return_nodes) {
  DCHECK(return_nodes != nullptr);
  GraphDef gdef;
  src_graph->ToGraphDef(&gdef);

  tensorflow::ImportGraphDefOptions opts;
  opts.prefix = prefix;

  for (int i = 0; i < dst_inputs.size(); ++i) {
    opts.input_map[ToTensorId(src_inputs[i])] =
        TensorId(dst_inputs[i].node()->name(), dst_inputs[i].index());
  }
  opts.skip_mapped_nodes = true;

  for (const tensorflow::Operation& op : control_deps) {
    opts.control_dependencies.push_back(op.node()->name());
  }

  for (int i = 0; i < nreturn_nodes; ++i) {
    opts.return_tensors.push_back(ToTensorId(nodes_to_return[i]));
  }

  // TODO(skyewm): change to OutputTensor
  tensorflow::ImportGraphDefResults results;
  TF_RETURN_IF_ERROR(
      ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results));

  for (const auto& pair : results.return_tensors) {
    return_nodes->emplace_back(pair.first, pair.second);
  }
  return Status::OK();
}

bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) {
  if (params.cond_graph == nullptr || params.body_graph == nullptr ||
      params.cond_graph->parent == nullptr ||
      params.cond_graph->parent != params.body_graph->parent ||
      params.cond_graph->parent_inputs != params.body_graph->parent_inputs ||
      params.ninputs <= 0 || params.cond_inputs == nullptr ||
      params.body_inputs == nullptr || params.body_outputs == nullptr) {
    s->status = InvalidArgument(
        "TF_WhileParams must be created by successful TF_NewWhile() call");
    return false;
  }
  return true;
}

bool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) {
  if (params.cond_output.oper == nullptr) {
    s->status = InvalidArgument("TF_WhileParams `cond_output` field isn't set");
    return false;
  }
  for (int i = 0; i < params.ninputs; ++i) {
    if (params.body_outputs[i].oper == nullptr) {
      s->status = InvalidArgument("TF_WhileParams `body_outputs[", i, "]` ",
                                  "field isn't set");
      return false;
    }
  }
  if (params.name == nullptr) {
    s->status = InvalidArgument("TF_WhileParams `name` field is null");
    return false;
  }
  return true;
}

#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)

void FreeWhileResources(const TF_WhileParams* params) {
  TF_DeleteGraph(params->cond_graph);
  TF_DeleteGraph(params->body_graph);
  delete[] params->cond_inputs;
  delete[] params->body_inputs;
  delete[] params->body_outputs;
}

TF_WhileParams EmptyWhileParams() {
  return {0,       nullptr, nullptr, {nullptr, 0},
          nullptr, nullptr, nullptr, nullptr};
}

}  // namespace

TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
                           TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Creating while loops is not supported on mobile. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you");
  return EmptyWhileParams();
#else
  if (ninputs == 0) {
    status->status =
        InvalidArgument("TF_NewWhile() must be passed at least one input");
    return EmptyWhileParams();
  }

  TF_Graph* cond_graph = TF_NewGraph();
  TF_Graph* body_graph = TF_NewGraph();
  cond_graph->parent = g;
  cond_graph->parent_inputs = inputs;
  body_graph->parent = g;
  body_graph->parent_inputs = inputs;

  TF_Output* cond_inputs = new TF_Output[ninputs];
  TF_Output cond_output = {nullptr, -1};
  TF_Output* body_inputs = new TF_Output[ninputs];
  TF_Output* body_outputs = new TF_Output[ninputs];
  for (int i = 0; i < ninputs; ++i) body_outputs[i] = {nullptr, -1};
  const char* name = nullptr;

  for (int i = 0; i < ninputs; ++i) {
    // TODO(skyewm): prefix names with underscore (requires some plumbing)
    if (!CreateInput(inputs[i], cond_graph, StrCat("cond_input", i).c_str(),
                     &cond_inputs[i], status)) {
      break;
    }
    if (!CreateInput(inputs[i], body_graph, StrCat("body_input", i).c_str(),
                     &body_inputs[i], status)) {
      break;
    }
  }

  TF_WhileParams params = {ninputs,    cond_graph,  cond_inputs,  cond_output,
                           body_graph, body_inputs, body_outputs, name};

  if (!status->status.ok()) {
    FreeWhileResources(&params);
    return EmptyWhileParams();
  }
  return params;
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
namespace {

// TODO(skyewm): make nodes in while loop unfetchable like in Python version
void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status,
                          TF_Output* outputs) {
  if (!ValidateInputWhileParams(*params, status)) return;

  TF_Graph* parent = params->cond_graph->parent;
  TF_Output* parent_inputs = params->cond_graph->parent_inputs;
  int num_loop_vars = params->ninputs;

  mutex_lock l(parent->mu);

  // 'cond_fn' copies the cond graph into the parent graph.
  tensorflow::ops::CondGraphBuilderFn cond_fn =
      [params, parent](const tensorflow::Scope& scope,
                       const std::vector<tensorflow::Output>& inputs,
                       tensorflow::Output* output) {
        DCHECK_EQ(scope.graph(), &parent->graph);
        std::vector<tensorflow::Output> cond_output;
        TF_RETURN_IF_ERROR(CopyGraph(
            &params->cond_graph->graph, &parent->graph, &parent->refiner,
            params->cond_inputs, inputs, scope.impl()->name(),
            scope.impl()->control_deps(), &params->cond_output,
            /* nreturn_nodes */ 1, &cond_output));
        *output = cond_output[0];
        return Status::OK();
      };

  // 'body_fn' copies the body graph into the parent graph.
  tensorflow::ops::BodyGraphBuilderFn body_fn =
      [params, parent, num_loop_vars](
          const tensorflow::Scope& scope,
          const std::vector<tensorflow::Output>& inputs,
          std::vector<tensorflow::Output>* outputs) {
        DCHECK_EQ(scope.graph(), &parent->graph);
        TF_RETURN_IF_ERROR(
            CopyGraph(&params->body_graph->graph, &parent->graph,
                      &parent->refiner, params->body_inputs, inputs,
                      scope.impl()->name(), scope.impl()->control_deps(),
                      params->body_outputs, num_loop_vars, outputs));
        return Status::OK();
      };

  // Create the while loop using an internal scope.
  tensorflow::Scope scope =
      NewInternalScope(&parent->graph, &status->status, &parent->refiner)
          .NewSubScope(params->name);

  const int first_new_node_id = parent->graph.num_node_ids();

  tensorflow::OutputList loop_outputs;
  status->status = tensorflow::ops::BuildWhileLoop(
      scope, OutputsFromTFOutputs(parent_inputs, num_loop_vars), cond_fn,
      body_fn, params->name, &loop_outputs);

  // Update name_map with newly-created ops.
  // TODO(skyewm): right now BuildWhileLoop() may alter the graph if it returns
  // a bad status. Once we fix this, we may want to return early instead of
  // executing the following code.
  for (int i = first_new_node_id; i < parent->graph.num_node_ids(); ++i) {
    Node* new_node = parent->graph.FindNodeId(i);
    if (new_node == nullptr) continue;
    parent->name_map[new_node->name()] = new_node;
  }

  // Populate 'outputs'.
  DCHECK_LE(loop_outputs.size(), num_loop_vars);
  for (int i = 0; i < loop_outputs.size(); ++i) {
    outputs[i] = {ToOperation(loop_outputs[i].node()), loop_outputs[i].index()};
  }
}

}  // namespace
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)

void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
                    TF_Output* outputs) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Creating while loops is not supported on mobile. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you");
#else
  // If it appears the caller created or modified `params`, don't free resources
  if (!ValidateConstWhileParams(*params, status)) return;
  TF_FinishWhileHelper(params, status, outputs);
  FreeWhileResources(params);
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
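
// Illustrative usage (editor's sketch, not part of the TensorFlow sources):
// callers get cond/body subgraphs back from TF_NewWhile(), build ops in them
// against params.cond_inputs / params.body_inputs, point params.cond_output
// and params.body_outputs at the results, then finish. On any error before
// TF_FinishWhile(), TF_AbortWhile() releases the params. `graph`, `loop_var`
// and `status` are assumed to exist; "loop" is a placeholder name.
//
//   TF_Output loop_vars[1] = {loop_var};
//   TF_WhileParams params = TF_NewWhile(graph, loop_vars, 1, status);
//   if (TF_GetCode(status) != TF_OK) return;
//   params.name = "loop";
//   // ... build the predicate in params.cond_graph from params.cond_inputs[0]
//   //     and assign it to params.cond_output ...
//   // ... build the next loop value in params.body_graph from
//   //     params.body_inputs[0] and assign it to params.body_outputs[0] ...
//   TF_Output outputs[1];
//   TF_FinishWhile(&params, status, outputs);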

void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
                     TF_Output* dx, TF_Status* status, TF_Output* dy) {
  TF_AddGradientsWithPrefix(g, nullptr, y, ny, x, nx, dx, status, dy);
}

void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
                               int ny, TF_Output* x, int nx, TF_Output* dx,
                               TF_Status* status, TF_Output* dy) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Adding gradients is not supported on mobile. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you");
#else
  std::vector<tensorflow::Output> y_arg = OutputsFromTFOutputs(y, ny);
  std::vector<tensorflow::Output> x_arg = OutputsFromTFOutputs(x, nx);
  std::vector<tensorflow::Output> dy_arg;

  {
    // We need to hold on to the lock while we have a scope that uses TF_Graph.
    mutex_lock graph_lock(g->mu);

    const int first_new_node_id = g->graph.num_node_ids();

    string prefix_cmp;
    const char* child_scope_name;
    if (prefix == nullptr) {
      child_scope_name = "gradients";
    } else {
      prefix_cmp = string(prefix) + "/";
      // The operation should fail if the provided name prefix has already been
      // used in this graph
      for (const auto& pair : g->name_map) {
        const string& name = pair.first;
        if ((name == prefix) || absl::StartsWith(name, prefix_cmp)) {
          status->status = InvalidArgument(
              "prefix [", prefix,
              "] conflicts with existing node in the graph named [", name, "]");
          return;
        }
      }
      child_scope_name = prefix;
    }
    tensorflow::Scope scope =
        NewInternalScope(&g->graph, &status->status, &g->refiner)
            .NewSubScope(child_scope_name);

    if (dx != nullptr) {
      std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny);
      status->status =
          AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg);
    } else {
      status->status = AddSymbolicGradients(scope, y_arg, x_arg, &dy_arg);
    }

    // Update g->name_map with the name_map from the scope, which will contain
    // the new gradient ops.
    for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
      Node* n = g->graph.FindNodeId(i);
      if (n == nullptr) continue;

      // Adding the gradients to the graph can alter the prefix to prevent
      // name collisions only if this prefix has not been provided explicitly
      // by the user. If it was provided, assert that it remained intact.
      if (prefix != nullptr && !absl::StartsWith(n->name(), prefix_cmp)) {
        status->status = tensorflow::errors::Internal(
            "BUG: The gradients prefix has been unexpectedly altered when "
            "adding the nodes to the graph. This is a bug. Please file an "
            "issue at https://github.com/tensorflow/tensorflow/issues.");
        return;
      }
      // We have a convoluted scheme here: Using the C++ graph construction API
      // to add potentially many nodes to the graph without running the checks
      // (such as uniqueness of the names of nodes) we run with other functions
      // that add a node to the graph (like TF_FinishOperation).
      if (!g->name_map.insert(std::make_pair(n->name(), n)).second) {
        status->status = tensorflow::errors::Internal(
            "BUG: The API allowed construction of a graph with duplicate node "
            "names (",
            n->name(),
            "). This is a bug. Please file an issue at "
            "https://github.com/tensorflow/tensorflow/issues.");
      }
    }
  }

  // Unpack the results from dy_arg into the caller's dy array.
  TFOutputsFromOutputs(dy_arg, dy);
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
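
// Illustrative usage (editor's sketch, not part of the TensorFlow sources):
// computing dy/dx for a single y and a single x, letting the library pick the
// default "gradients" name scope. `graph`, `y`, `x` and `status` are assumed
// to exist; passing dx == nullptr requests gradients seeded with ones.
//
//   TF_Output dy;
//   TF_AddGradients(graph, &y, 1, &x, 1, /*dx=*/nullptr, status, &dy);
//   if (TF_GetCode(status) == TF_OK) {
//     // `dy` can now be fetched with TF_SessionRun.
//   }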

// TF_Session functions ----------------------------------------------

TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g)
    : session(s), graph(g), last_num_graph_nodes(0), extend_before_run(true) {}

TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
                          TF_Status* status) {
  Session* session;
  status->status = NewSession(opt->options, &session);
  if (status->status.ok()) {
    TF_Session* new_session = new TF_Session(session, graph);
    if (graph != nullptr) {
      mutex_lock l(graph->mu);
      graph->sessions[new_session] = "";
    }
    return new_session;
  } else {
    DCHECK_EQ(nullptr, session);
    return nullptr;
  }
}

TF_Session* TF_LoadSessionFromSavedModel(
    const TF_SessionOptions* session_options, const TF_Buffer* run_options,
    const char* export_dir, const char* const* tags, int tags_len,
    TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) {
// TODO(sjr): Remove the IS_MOBILE_PLATFORM guard. This will require ensuring
// that the tensorflow/cc/saved_model:loader build target is mobile friendly.
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Loading a SavedModel is not supported on mobile. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you");
  return nullptr;
#else
  mutex_lock l(graph->mu);
  if (!graph->name_map.empty()) {
    status->status = InvalidArgument("Graph is non-empty.");
    return nullptr;
  }

  RunOptions run_options_proto;
  if (run_options != nullptr && !run_options_proto.ParseFromArray(
                                    run_options->data, run_options->length)) {
    status->status = InvalidArgument("Unparseable RunOptions proto");
    return nullptr;
  }

  std::unordered_set<string> tag_set;
  for (int i = 0; i < tags_len; i++) {
    tag_set.insert(string(tags[i]));
  }

  tensorflow::SavedModelBundle bundle;
  status->status =
      tensorflow::LoadSavedModel(session_options->options, run_options_proto,
                                 export_dir, tag_set, &bundle);
  if (!status->status.ok()) return nullptr;

  // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session
  // extends using GraphDefs. The Graph instance is different, but equivalent
  // to the one used to create the session.
  //
  // TODO(jhseu): When Session is modified to take Graphs instead of
  // GraphDefs, return the Graph generated in LoadSavedModel().
  TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions();
  TF_ImportGraphDefResults results;
  GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
                            import_opts, &results, status);
  TF_DeleteImportGraphDefOptions(import_opts);
  if (!status->status.ok()) return nullptr;

  if (meta_graph_def != nullptr) {
    status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def);
    if (!status->status.ok()) return nullptr;
  }

  TF_Session* session = new TF_Session(bundle.session.release(), graph);

  graph->sessions[session] = "";
  session->last_num_graph_nodes = graph->graph.num_node_ids();
  return session;
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
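
// Illustrative usage (editor's sketch, not part of the TensorFlow sources):
// loading a SavedModel exported with the conventional "serve" tag into a
// freshly created graph. The export path is a placeholder.
//
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   TF_Graph* graph = TF_NewGraph();
//   TF_Status* status = TF_NewStatus();
//   const char* tags[] = {"serve"};
//   TF_Session* session = TF_LoadSessionFromSavedModel(
//       opts, /*run_options=*/nullptr, "/path/to/saved_model", tags, 1, graph,
//       /*meta_graph_def=*/nullptr, status);
//   TF_DeleteSessionOptions(opts);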

void TF_CloseSession(TF_Session* s, TF_Status* status) {
  status->status = s->session->Close();
}

void TF_DeleteSession(TF_Session* s, TF_Status* status) {
  status->status = Status::OK();
  if (s == nullptr) return;
  TF_Graph* const graph = s->graph;
  if (graph != nullptr) {
    graph->mu.lock();
    graph->sessions.erase(s);
    const bool del = graph->delete_requested && graph->sessions.empty();
    graph->mu.unlock();
    if (del) delete graph;
  }
  delete s->session;
  delete s;
}

void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
                   const TF_Output* inputs, TF_Tensor* const* input_values,
                   int ninputs, const TF_Output* outputs,
                   TF_Tensor** output_values, int noutputs,
                   const TF_Operation* const* target_opers, int ntargets,
                   TF_Buffer* run_metadata, TF_Status* status) {
  // TODO(josh11b,mrry): Change Session to be able to use a Graph*
  // directly, instead of requiring us to serialize to a GraphDef and
  // call Session::Extend().
  if (session->extend_before_run &&
      !ExtendSessionGraphHelper(session, status)) {
    return;
  }

  TF_Run_Setup(noutputs, output_values, status);

  // Convert from TF_Output and TF_Tensor to a string and Tensor.
  std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
  if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
  for (int i = 0; i < ninputs; ++i) {
    input_pairs[i].first = OutputName(inputs[i]);
  }

  // Convert from TF_Output to string names.
  std::vector<string> output_names(noutputs);
  for (int i = 0; i < noutputs; ++i) {
    output_names[i] = OutputName(outputs[i]);
  }

  // Convert from TF_Operation* to string names.
  std::vector<string> target_names(ntargets);
  for (int i = 0; i < ntargets; ++i) {
    target_names[i] = target_opers[i]->node.name();
  }

  // Actually run.
  TF_Run_Helper(session->session, nullptr, run_options, input_pairs,
                output_names, output_values, target_names, run_metadata,
                status);
}
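
// Illustrative usage (editor's sketch, not part of the TensorFlow sources):
// feeding one input tensor and fetching one output. `session`, `input_op`,
// `output_op`, `input_tensor` and `status` are assumed to exist; the caller
// owns the fetched tensor and must delete it.
//
//   TF_Output feeds[] = {{input_op, 0}};
//   TF_Tensor* feed_values[] = {input_tensor};
//   TF_Output fetches[] = {{output_op, 0}};
//   TF_Tensor* fetch_values[1] = {nullptr};
//   TF_SessionRun(session, /*run_options=*/nullptr, feeds, feed_values, 1,
//                 fetches, fetch_values, 1, /*target_opers=*/nullptr, 0,
//                 /*run_metadata=*/nullptr, status);
//   if (TF_GetCode(status) == TF_OK) TF_DeleteTensor(fetch_values[0]);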
2308 
void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
                         int ninputs, const TF_Output* outputs, int noutputs,
                         const TF_Operation* const* target_opers, int ntargets,
                         const char** handle, TF_Status* status) {
  *handle = nullptr;

  if (session->extend_before_run &&
      !ExtendSessionGraphHelper(session, status)) {
    return;
  }

  std::vector<string> input_names(ninputs);
  for (int i = 0; i < ninputs; ++i) {
    input_names[i] = OutputName(inputs[i]);
  }

  std::vector<string> output_names(noutputs);
  for (int i = 0; i < noutputs; ++i) {
    output_names[i] = OutputName(outputs[i]);
  }

  std::vector<string> target_names(ntargets);
  for (int i = 0; i < ntargets; ++i) {
    target_names[i] = target_opers[i]->node.name();
  }

  string new_handle;
  status->status = session->session->PRunSetup(input_names, output_names,
                                               target_names, &new_handle);
  if (status->status.ok()) {
    char* buf = new char[new_handle.size() + 1];
    memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
    *handle = buf;
  }
}

void TF_DeletePRunHandle(const char* handle) {
  delete[] handle;
  // TODO(suharshs): Free up any resources held by the partial run state.
}

void TF_SessionPRun(TF_Session* session, const char* handle,
                    const TF_Output* inputs, TF_Tensor* const* input_values,
                    int ninputs, const TF_Output* outputs,
                    TF_Tensor** output_values, int noutputs,
                    const TF_Operation* const* target_opers, int ntargets,
                    TF_Status* status) {
  // TODO(josh11b,mrry): Change Session to be able to use a Graph*
  // directly, instead of requiring us to serialize to a GraphDef and
  // call Session::Extend().
  if (session->extend_before_run &&
      !ExtendSessionGraphHelper(session, status)) {
    return;
  }

  TF_Run_Setup(noutputs, output_values, status);

  // Convert from TF_Output and TF_Tensor to a string and Tensor.
  std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
  if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
  for (int i = 0; i < ninputs; ++i) {
    input_pairs[i].first = OutputName(inputs[i]);
  }

  // Convert from TF_Output to string names.
  std::vector<string> output_names(noutputs);
  for (int i = 0; i < noutputs; ++i) {
    output_names[i] = OutputName(outputs[i]);
  }

  // Convert from TF_Operation* to string names.
  std::vector<string> target_names(ntargets);
  for (int i = 0; i < ntargets; ++i) {
    target_names[i] = target_opers[i]->node.name();
  }

  TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names,
                output_values, target_names, nullptr, status);
}
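
// A minimal partial-run sketch (illustrative only, not part of the original
// file). It assumes the caller has already built `session`, a feed `TF_Output`
// named `feed`, a fetch `TF_Output` named `fetch`, and an input
// `TF_Tensor* in_tensor`:
//
//   TF_Status* status = TF_NewStatus();
//   const char* handle = nullptr;
//   TF_SessionPRunSetup(session, &feed, 1, &fetch, 1,
//                       /*target_opers=*/nullptr, /*ntargets=*/0, &handle,
//                       status);
//   if (TF_GetCode(status) == TF_OK) {
//     TF_Tensor* out_tensor = nullptr;
//     TF_SessionPRun(session, handle, &feed, &in_tensor, 1, &fetch,
//                    &out_tensor, 1, /*target_opers=*/nullptr,
//                    /*ntargets=*/0, status);
//     TF_DeleteTensor(out_tensor);
//   }
//   TF_DeletePRunHandle(handle);
//   TF_DeleteStatus(status);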

unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
                                     TF_Tensor** result, TF_Status* status) {
  *result = nullptr;
  mutex_lock l(graph->mu);
  OutputTensor tensor(&output.oper->node, output.index);
  bool evaluated;
  Tensor result_tensor;
  status->status = EvaluateConstantTensor(
      tensor, graph->refiner, *graph->graph.op_registry(),
      graph->graph.versions().producer(), &evaluated, &result_tensor);
  if (evaluated) {
    DCHECK(status->status.ok());
    *result = TF_TensorFromTensor(result_tensor, &status->status);
    if (!status->status.ok()) evaluated = false;
  }
  return evaluated;
}
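
// Usage sketch for TF_TryEvaluateConstant (illustrative only; `graph`, `out`
// and `status` are assumed to exist). The function returns non-zero and hands
// the caller an owned tensor only when the output could be folded to a
// constant:
//
//   TF_Tensor* value = nullptr;
//   if (TF_TryEvaluateConstant(graph, out, &value, status)) {
//     // `value` holds the constant result; inspect it, then release it.
//     TF_DeleteTensor(value);
//   }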

TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) {
  tensorflow::OpList op_list;
  if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) {
    status->status = InvalidArgument("Unparseable OpList");
    return nullptr;
  }
  status->status = Status::OK();
  return new TF_ApiDefMap(op_list);
}

void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; }

void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text,
                     size_t text_len, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "ApiDefMap is not supported on mobile.");
#else
  mutex_lock l(api_def_map->lock);
  if (api_def_map->update_docs_called) {
    status->status = FailedPrecondition(
        "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been "
        "called.");
    return;
  }
  string api_def_text(text, text_len);
  status->status = api_def_map->api_def_map.LoadApiDef(api_def_text);
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
                           size_t name_len, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "ApiDefMap is not supported on mobile.");
  return nullptr;
#else
  mutex_lock l(api_def_map->lock);
  if (!api_def_map->update_docs_called) {
    api_def_map->api_def_map.UpdateDocs();
    api_def_map->update_docs_called = true;
  }
  string name_str(name, name_len);
  const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str);
  if (api_def == nullptr) {
    return nullptr;
  }

  TF_Buffer* ret = TF_NewBuffer();
  status->status = MessageToBuffer(*api_def, ret);
  if (!status->status.ok()) {
    TF_DeleteBuffer(ret);
    return nullptr;
  }
  return ret;
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
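
// A minimal ApiDefMap flow (sketch, not part of the original file). It seeds
// the map with TF_GetAllOpList() from this same C API; `api_def_text` is
// assumed to be an ApiDef proto in text format supplied by the caller. Note
// that all TF_ApiDefMapPut calls must happen before the first
// TF_ApiDefMapGet:
//
//   TF_Status* status = TF_NewStatus();
//   TF_Buffer* op_list = TF_GetAllOpList();
//   TF_ApiDefMap* api_map = TF_NewApiDefMap(op_list, status);
//   TF_ApiDefMapPut(api_map, api_def_text, strlen(api_def_text), status);
//   TF_Buffer* api_def = TF_ApiDefMapGet(api_map, "MatMul", 6, status);
//   // ... parse `api_def` as an ApiDef proto, then clean up.
//   TF_DeleteBuffer(api_def);
//   TF_DeleteApiDefMap(api_map);
//   TF_DeleteBuffer(op_list);
//   TF_DeleteStatus(status);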

TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) {
  tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels();
  TF_Buffer* ret = TF_NewBuffer();
  status->status = MessageToBuffer(kernel_list, ret);
  if (!status->status.ok()) {
    TF_DeleteBuffer(ret);
    return nullptr;
  }
  return ret;
}

TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
  tensorflow::KernelList kernel_list =
      tensorflow::GetRegisteredKernelsForOp(name);
  TF_Buffer* ret = TF_NewBuffer();
  status->status = MessageToBuffer(kernel_list, ret);
  if (!status->status.ok()) {
    TF_DeleteBuffer(ret);
    return nullptr;
  }
  return ret;
}
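
// Both kernel queries above return a serialized KernelList proto that the
// caller owns. A sketch of the expected call pattern (illustrative only):
//
//   TF_Status* status = TF_NewStatus();
//   TF_Buffer* kernels = TF_GetRegisteredKernelsForOp("MatMul", status);
//   if (TF_GetCode(status) == TF_OK) {
//     // ... parse `kernels->data` (length `kernels->length`) as a KernelList.
//     TF_DeleteBuffer(kernels);
//   }
//   TF_DeleteStatus(status);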

// TF_Server functions ----------------------------------------------

#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
TF_Server::TF_Server(std::unique_ptr<tensorflow::ServerInterface> server)
    : target(server->target()), server(std::move(server)) {}
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)

TF_Server* TF_NewServer(const void* proto, size_t proto_len,
                        TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
  return nullptr;
#else
  tensorflow::ServerDef server_def;
  if (!server_def.ParseFromArray(proto, static_cast<int>(proto_len))) {
    status->status = InvalidArgument(
        "Could not parse provided bytes into a ServerDef protocol buffer");
    return nullptr;
  }

  std::unique_ptr<tensorflow::ServerInterface> out_server;
  status->status = tensorflow::NewServer(server_def, &out_server);
  if (!status->status.ok()) return nullptr;

  return new TF_Server(std::move(out_server));
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

void TF_ServerStart(TF_Server* server, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
#else
  status->status = server->server->Start();
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

void TF_ServerStop(TF_Server* server, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
#else
  status->status = server->server->Stop();
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

void TF_ServerJoin(TF_Server* server, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
#else
  status->status = server->server->Join();
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

const char* TF_ServerTarget(TF_Server* server) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  return nullptr;
#else
  return server->target.c_str();
#endif
}

void TF_DeleteServer(TF_Server* server) {
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
  delete server;
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
}
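
// An in-process server lifecycle sketch (illustrative only; `server_def_bytes`
// and `server_def_len` are assumed to hold a serialized tensorflow.ServerDef
// supplied by the caller):
//
//   TF_Status* status = TF_NewStatus();
//   TF_Server* server = TF_NewServer(server_def_bytes, server_def_len, status);
//   if (TF_GetCode(status) == TF_OK) {
//     TF_ServerStart(server, status);
//     // ... sessions may now connect to TF_ServerTarget(server) ...
//     TF_ServerStop(server, status);
//     TF_ServerJoin(server, status);
//   }
//   TF_DeleteServer(server);
//   TF_DeleteStatus(status);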

void TF_RegisterLogListener(void (*listener)(const char*)) {
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
  tensorflow::logging::RegisterListener(listener);
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
}

}  // end extern "C"